Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,112 @@ int32_t common_speculative_n_max(const common_params_speculative * spec) {
return n_max;
}

common_params common_base_params_to_speculative(const common_params & params) {
const bool has_draft = params.speculative.has_dft();

const auto & params_spec = params.speculative.draft;
common_params result = params;

if (has_draft) {
result.devices = params_spec.devices;
result.model = params_spec.mparams;
result.n_gpu_layers = params_spec.n_gpu_layers;
result.tensor_buft_overrides = params_spec.tensor_buft_overrides;

if (params_spec.cpuparams.n_threads > 0) {
result.cpuparams.n_threads = params_spec.cpuparams.n_threads;
result.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
}
}

result.cache_type_k = params_spec.cache_type_k;
result.cache_type_v = params_spec.cache_type_v;
result.n_outputs_max = params.n_parallel;

return result;
}

struct common_init_speculative_result::impl {
impl() = default;
~impl() = default;

// note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top
llama_model_ptr model;
llama_context_ptr context;
};

common_init_speculative_result::common_init_speculative_result(
common_params & params,
llama_model * model_tgt,
llama_context * ctx_tgt) :
pimpl(new impl{}) {
const bool has_draft = params.speculative.has_dft();
const bool spec_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
GGML_ASSERT(has_draft || spec_mtp);

auto mparams = common_model_params_to_llama(params);
auto cparams = common_context_params_to_llama(params);

if (spec_mtp) {
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
}

// note: for small models maybe we can set this to the maximum possible draft from all speculative types
// the extra memory for small models is likely negligible?
cparams.n_rs_seq = 0;
cparams.ctx_other = ctx_tgt;

std::string model_path;
if (has_draft) {
model_path = params.speculative.draft.mparams.path;

llama_model * model_dft = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model_dft == NULL) {
LOG_ERR("%s: failed to load draft model, '%s'\n", __func__, model_path.c_str());
return;
}

pimpl->model.reset(model_dft);

llama_context * ctx_dft = llama_init_from_model(model_dft, cparams);
if (ctx_dft == nullptr) {
LOG_ERR("%s: failed to create MTP context\n", __func__);
return;
}

pimpl->context.reset(ctx_dft);
} else if (spec_mtp) {
model_path = params.model.path;

LOG_INF("%s: creating MTP draft context against the target model '%s'\n",
__func__, model_path.c_str());

llama_context * ctx_dft = llama_init_from_model(model_tgt, cparams);
if (ctx_dft == nullptr) {
LOG_ERR("%s: failed to create MTP context\n", __func__);
return;
}

pimpl->context.reset(ctx_dft);
}
}

common_init_speculative_result::~common_init_speculative_result() = default;

llama_model * common_init_speculative_result::model() {
return pimpl->model.get();
}

llama_context * common_init_speculative_result::context() {
return pimpl->context.get();
}

common_init_speculative_result_ptr common_init_speculative(common_params & params, llama_model * model_tgt, llama_context * ctx_tgt) {
return std::make_unique<common_init_speculative_result>(params, model_tgt, ctx_tgt);
}

// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) {
Expand Down
18 changes: 18 additions & 0 deletions common/speculative.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ std::string common_speculative_type_to_str(enum common_speculative_type type);
// return the max number of draft tokens based on the speculative parameters
int32_t common_speculative_n_max(const common_params_speculative * spec);

common_params common_base_params_to_speculative(const common_params & params);

common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);

void common_speculative_free(common_speculative * spec);
Expand Down Expand Up @@ -80,3 +82,19 @@ struct common_speculative_deleter {
};

typedef std::unique_ptr<common_speculative, common_speculative_deleter> common_speculative_ptr;

struct common_init_speculative_result {
common_init_speculative_result(common_params & params, llama_model * model_tgt, llama_context * ctx_tgt);
~common_init_speculative_result();

llama_model * model();
llama_context * context();

private:
struct impl;
std::unique_ptr<impl> pimpl;
};

using common_init_speculative_result_ptr = std::unique_ptr<common_init_speculative_result>;

common_init_speculative_result_ptr common_init_speculative(common_params & params, llama_model * model_tgt, llama_context * ctx_tgt);
Loading