Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"DbrxForCausalLM": "dbrx",
"DeciLMForCausalLM": "deci",
"DeepseekForCausalLM": "deepseek",
"DeepseekOCRForCausalLM": "deepseek",
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
Expand Down Expand Up @@ -231,6 +232,7 @@
"UMT5ForConditionalGeneration": "t5",
"UMT5Model": "t5",
"UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VLlama3ForCausalLM": "llama",
"VoxtralForConditionalGeneration": "llama",
"WavTokenizerDec": "wavtokenizer",
Expand Down Expand Up @@ -296,6 +298,7 @@
"StepVLForConditionalGeneration": "step3",
"Step3p7ForConditionalGeneration": "step3",
"UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VoxtralForConditionalGeneration": "ultravox",
"YoutuVLForConditionalGeneration": "youtuvl",
}
Expand Down
12 changes: 10 additions & 2 deletions conversion/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .qwen import QwenModel


@ModelBase.register("DeepseekOCRForCausalLM")
@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -205,6 +205,8 @@ def prepare_tensors(self):
@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekOCRForCausalLM",
"UnlimitedOCRForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
Expand All @@ -224,7 +226,7 @@ def __init__(self, *args, **kwargs):
self.origin_hf_arch = hparams.get('architectures', [None])[0]

# special handling for Deepseek OCR
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"):
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
self.gguf_writer.add_architecture()
Expand Down Expand Up @@ -350,6 +352,12 @@ def set_gguf_parameters(self):

self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

# Unlimited-OCR R-SWA sliding window; the deepseek2-ocr decoder reads it
if is_ocr:
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
if sliding_window:
self.gguf_writer.add_sliding_window(sliding_window)

if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
Expand Down
5 changes: 4 additions & 1 deletion src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n
case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
case LLAMA_SWA_TYPE_REFERENCE: swa_type_str = "LLAMA_SWA_TYPE_REFERENCE"; break;
};

LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
Expand Down Expand Up @@ -2285,7 +2286,9 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);

{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
// REFERENCE masks within this single cache; other SWA types need iswa
GGML_ASSERT((hparams.swa_type == LLAMA_SWA_TYPE_NONE ||
hparams.swa_type == LLAMA_SWA_TYPE_REFERENCE) && "Use llama_kv_cache_iswa for SWA");

inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
Expand Down
14 changes: 13 additions & 1 deletion src/llama-hparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ enum llama_swa_type {
LLAMA_SWA_TYPE_STANDARD = 1,
LLAMA_SWA_TYPE_CHUNKED = 2,
LLAMA_SWA_TYPE_SYMMETRIC = 3,
LLAMA_SWA_TYPE_REFERENCE = 4, // R-SWA: always-visible prefix + window over the rest
};

// forward declaration; full definition in llama-graph.h
Expand Down Expand Up @@ -357,7 +358,8 @@ struct llama_hparams {
// note: inlined on purpose for performance reasons
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
// n_ref = R-SWA prefix length L_m (always-visible positions); < 0 = unlatched, full causal
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1, llama_pos n_ref = -1) {
assert(p0 >= 0 && p1 >= 0);

switch (swa_type) {
Expand Down Expand Up @@ -388,6 +390,16 @@ struct llama_hparams {
return true;
}
} break;
case LLAMA_SWA_TYPE_REFERENCE:
{
// visible iff unlatched (n_ref < 0), in the prefix (p0 < n_ref), or within the window (p1 - p0 < n_swa)
const bool windowed = p1 - p0 >= (int32_t) n_swa;
const bool in_prefix = n_ref < 0 || p0 < n_ref;

if (windowed && !in_prefix) {
return true;
}
} break;
}

return false;
Expand Down
99 changes: 98 additions & 1 deletion src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ llama_kv_cache::llama_kv_cache(
v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()),
v_cells(*v_cells_impl) {

n_ref.fill(-1);

// shared cells view the source cache's K/V tensors, so the cell count
// follows the source allocation: a fitted target can be smaller than the
// draft default and oversized views would overflow the source tensors
Expand Down Expand Up @@ -377,6 +379,8 @@ llama_kv_cache::llama_kv_cache(
}

void llama_kv_cache::clear(bool data) {
n_ref.fill(-1);

for (uint32_t s = 0; s < n_stream; ++s) {
v_cells[s].reset();
v_heads[s] = 0;
Expand Down Expand Up @@ -405,6 +409,15 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
p1 = std::numeric_limits<llama_pos>::max();
}

// dropping from pos 0 invalidates the latched prefix
if (p0 == 0) {
if (seq_id >= 0) {
n_ref[seq_id] = -1;
} else {
n_ref.fill(-1);
}
}

if (seq_id >= 0) {
auto & cells = v_cells[seq_to_stream[seq_id]];
auto & head = v_heads[seq_to_stream[seq_id]];
Expand Down Expand Up @@ -466,6 +479,9 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());

// copy inherits the latched prefix
n_ref[seq_id_dst] = n_ref[seq_id_src];

const auto s0 = seq_to_stream[seq_id_src];
const auto s1 = seq_to_stream[seq_id_dst];

Expand Down Expand Up @@ -557,6 +573,13 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {

GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());

// other seqs are purged -> drop their latched prefix
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
if ((llama_seq_id) s != seq_id) {
n_ref[s] = -1;
}
}

auto & cells = v_cells[seq_to_stream[seq_id]];
auto & head = v_heads[seq_to_stream[seq_id]];

Expand Down Expand Up @@ -607,6 +630,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
return;
}

// the prefix boundary is an absolute pos -> shift it with its cells
if (n_ref[seq_id] >= 0 && n_ref[seq_id] >= p0 && n_ref[seq_id] < p1) {
n_ref[seq_id] += shift;
}

for (uint32_t i = 0; i < cells.size(); ++i) {
if (!cells.pos_in(i, p0, p1)) {
continue;
Expand Down Expand Up @@ -654,6 +682,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in
return;
}

// the prefix boundary is an absolute pos -> divide it with its cells
if (n_ref[seq_id] >= 0 && n_ref[seq_id] >= p0 && n_ref[seq_id] < p1) {
n_ref[seq_id] /= d;
}

for (uint32_t i = 0; i < cells.size(); ++i) {
if (!cells.pos_in(i, p0, p1)) {
continue;
Expand Down Expand Up @@ -1109,6 +1142,22 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
return;
}

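// latch L_m at the prefill->decode boundary: the first ubatch that appends exactly one token
// to an already-populated seq sets n_ref to that token's pos; until then the mask is full
// causal. only the first seq id of each token (seq_id[i][0]) is considered.
// assumes single-token decode (mtmd-cli/server).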
if (swa_type == LLAMA_SWA_TYPE_REFERENCE) {
uint32_t n_tok_seq[LLAMA_MAX_SEQ] = { 0 };
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
n_tok_seq[ubatch.seq_id[i][0]]++;
}
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (n_ref[seq_id] < 0 && n_tok_seq[seq_id] == 1 &&
v_cells[seq_to_stream[seq_id]].seq_pos_max(seq_id) >= 0) {
n_ref[seq_id] = ubatch.pos[i];
}
}
}

// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
Expand Down Expand Up @@ -1519,6 +1568,9 @@ struct args_set_input_kq_mask {
uint32_t n_swa;
llama_swa_type swa_type;

// per-seq R-SWA prefix length L_m (-1 = unlatched), indexed by seq_id
const llama_pos * n_ref;

int64_t n_kv;
int64_t n_stream;
int64_t n_tps;
Expand Down Expand Up @@ -1654,7 +1706,7 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data

// apply SWA if any
if (swa) {
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1, args.n_ref[seq_id])) {
goto skip;
}
}
Expand Down Expand Up @@ -1734,6 +1786,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
/*.seq_to_stream =*/ seq_to_stream,
/*.n_swa =*/ n_swa,
/*.swa_type =*/ swa_type,
/*.n_ref =*/ n_ref.data(),
/*.n_kv =*/ n_kv,
/*.n_stream =*/ n_stream,
/*.n_tps =*/ n_tps,
Expand Down Expand Up @@ -1960,6 +2013,20 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla

io.write(&n_stream, sizeof(n_stream));

// persist n_ref; REFERENCE-guarded so other cache types' state format is unchanged.
// whole-cache case is count-prefixed to tolerate a different n_seq_max on restore.
if (swa_type == LLAMA_SWA_TYPE_REFERENCE) {
if (seq_id == -1) {
const uint32_t n_ref_count = n_seq_max;
io.write(&n_ref_count, sizeof(n_ref_count));
for (uint32_t i = 0; i < n_ref_count; ++i) {
io.write(&n_ref[i], sizeof(llama_pos));
}
} else {
io.write(&n_ref[seq_id], sizeof(llama_pos));
}
}

for (uint32_t s = 0; s < n_stream; ++s) {
cell_ranges_t cr { s, {} };

Expand Down Expand Up @@ -2036,6 +2103,25 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
throw std::runtime_error("n_stream mismatch");
}

// read n_ref now but apply after the restore below; clear()/seq_rm() would reset it
std::array<llama_pos, LLAMA_MAX_SEQ> n_ref_restored;
n_ref_restored.fill(-1);
if (swa_type == LLAMA_SWA_TYPE_REFERENCE) {
if (seq_id == -1) {
uint32_t n_ref_count = 0;
io.read(&n_ref_count, sizeof(n_ref_count));
for (uint32_t i = 0; i < n_ref_count; ++i) {
llama_pos v;
io.read(&v, sizeof(v));
if (i < n_seq_max) {
n_ref_restored[i] = v;
}
}
} else {
io.read(&n_ref_restored[seq_id], sizeof(llama_pos));
}
}

for (uint32_t s = 0; s < n_stream; ++s) {
uint32_t cell_count;
io.read(&cell_count, sizeof(cell_count));
Expand All @@ -2061,6 +2147,17 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
throw std::runtime_error("failed to restore kv cache");
}
}

// cells restored -> reinstate n_ref
if (swa_type == LLAMA_SWA_TYPE_REFERENCE) {
if (seq_id == -1) {
for (uint32_t i = 0; i < n_seq_max; ++i) {
n_ref[i] = n_ref_restored[i];
}
} else {
n_ref[seq_id] = n_ref_restored[seq_id];
}
}
}

void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
Expand Down
4 changes: 4 additions & 0 deletions src/llama-kv-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ class llama_kv_cache : public llama_memory_i {
// this is the SWA type of the cache - not to be confused with the model SWA type
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

// R-SWA per-seq prefix length L_m (-1 = unlatched -> full causal mask); latched at the
// prefill->decode boundary in apply_ubatch, read by set_input_kq_mask.
std::array<llama_pos, LLAMA_MAX_SEQ> n_ref;

// ggml contexts for the KV cache along with the allocated backend buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

Expand Down
31 changes: 31 additions & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,37 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
{
res = nullptr;
} break;
case LLM_ARCH_DEEPSEEK2OCR:
{
// R-SWA runs on one full cache - the REFERENCE mask keeps the
// prefix visible, so no eviction and no iswa.
//
// The V cache should be F32. This OCR decoder reads dense layout
// (e.g. tables) by attending over the always-visible visual
// prefix; an F16 V cache truncates those value vectors enough to
// garble the output (table headers come out as "&quot;"). The HF
// reference accumulates attention in F32, so match it here. A
// requested type_v of F16 (the default) is promoted to F32; any
// other type (e.g. -ctv q8_0) is used as-is. Note that an explicit
// F16 request cannot be told apart from the default here, so it is
// promoted as well. See PR #24975.
const ggml_type type_v = params.type_v == GGML_TYPE_F16 ? GGML_TYPE_F32 : params.type_v;
res = new llama_kv_cache(
*this,
hparams,
params.type_k,
type_v,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
nullptr,
nullptr,
nullptr);
} break;
case LLM_ARCH_DEEPSEEK32:
{
res = new llama_kv_cache_dsa(
Expand Down
6 changes: 6 additions & 0 deletions src/models/deepseek2ocr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ void llama_model_deepseek2ocr::load_arch_hparams(llama_model_loader & ml) {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
}

// Unlimited-OCR sets sliding_window -> R-SWA
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
if (hparams.n_swa > 0) {
hparams.swa_type = LLAMA_SWA_TYPE_REFERENCE;
}

switch (hparams.n_layer()) {
case 12: type = LLM_TYPE_3B; break;
default: type = LLM_TYPE_UNKNOWN;
Expand Down
Loading