llama + spec: MTP Support #22673

Open
am17an wants to merge 5 commits into ggml-org:master from am17an:mtp-clean

Conversation

@am17an
Contributor

@am17an am17an commented May 4, 2026

Overview

This PR adds support for MTP (Multi Token Prediction) heads. I tested it on Qwen3.6 27B and Qwen3.6 35BA3B, but in principle it should work for any MTP model. I've posted the detailed results below; typically I see a steady-state acceptance of around 75% with 3 draft tokens, which is more than a 2x speed-up over baseline. The design decisions I took to get to this stage are as follows:

Next Steps

Performance

A simple bench for testing various prompts is here: https://gist.github.com/am17an/228edfb84ed082aa88e3865d6fa27090. Posting the results below:
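Before the numbers, a quick back-of-the-envelope check (mine, not part of the PR) of how a per-position acceptance probability maps to tokens emitted per target verification pass; it assumes acceptance is independent at each draft position, which the measured per-prompt rates only approximate:

```cpp
// Back-of-the-envelope estimate (not from the PR): with n drafted tokens and a
// per-position acceptance probability p, each verification pass of the target
// model emits the accepted prefix plus one token it computed itself, so the
// expected tokens per pass is 1 + p + p^2 + ... + p^n.
#include <cstdio>

static double expected_tokens_per_pass(double p, int n_draft) {
    double e = 0.0, term = 1.0;
    for (int k = 0; k <= n_draft; ++k) {
        e    += term;   // term == p^k
        term *= p;
    }
    return e;
}

int main() {
    // ~75% steady-state acceptance with 3 draft tokens, as reported above
    std::printf("%.2f tokens per target pass\n", expected_tokens_per_pass(0.75, 3)); // ~2.73
    return 0;
}
```

With p = 0.75 and 3 drafts this gives roughly 2.7 tokens per pass, which lines up with the >2x end-to-end speed-ups below once the drafting overhead is accounted for.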

Performance on DGX Spark 🧵

No MTP (baseline)

./llama-server -m ../qwen3.6-q8_0.gguf -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}"

  code_python        pred= 192 draft=   0 acc=   0 rate=n/a tok/s=7.0
  code_cpp           pred= 192 draft=   0 acc=   0 rate=n/a tok/s=7.3
  explain_concept    pred= 192 draft=   0 acc=   0 rate=n/a tok/s=7.3
  summarize          pred=  53 draft=   0 acc=   0 rate=n/a tok/s=7.1
  qa_factual         pred= 177 draft=   0 acc=   0 rate=n/a tok/s=7.0
  translation        pred=  22 draft=   0 acc=   0 rate=n/a tok/s=7.7
  creative_short     pred= 192 draft=   0 acc=   0 rate=n/a tok/s=7.1
  stepwise_math      pred= 192 draft=   0 acc=   0 rate=n/a tok/s=7.2
  long_code_review   pred= 192 draft=   0 acc=   0 rate=n/a tok/s=7.0

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1404,
  "total_draft": 0,
  "total_draft_accepted": 0,
  "aggregate_accept_rate": null,
  "wall_s_total": 201.07
}

MTP --spec-draft-n-max 3

./llama-server -m ../qwen3.6-q8_0-mtp.gguf -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 3

  code_python        pred= 192 draft= 153 acc= 139 rate=0.908 tok/s=21.6
  code_cpp           pred= 192 draft= 176 acc= 132 rate=0.750 tok/s=18.7
  explain_concept    pred= 192 draft= 191 acc= 126 rate=0.660 tok/s=16.3
  summarize          pred=  55 draft=  51 acc=  37 rate=0.726 tok/s=17.9
  qa_factual         pred= 177 draft= 174 acc= 118 rate=0.678 tok/s=16.5
  translation        pred=  22 draft=  24 acc=  13 rate=0.542 tok/s=13.9
  creative_short     pred= 192 draft= 200 acc= 123 rate=0.615 tok/s=15.8
  stepwise_math      pred= 192 draft= 171 acc= 133 rate=0.778 tok/s=19.3
  long_code_review   pred= 192 draft= 179 acc= 131 rate=0.732 tok/s=18.0

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1406,
  "total_draft": 1319,
  "total_draft_accepted": 952,
  "aggregate_accept_rate": 0.7218,
  "wall_s_total": 83.8
}

MTP --spec-draft-n-max 2

./llama-server -m ../qwen3.6-q8_0-mtp.gguf -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 2

  code_python        pred= 192 draft= 134 acc= 123 rate=0.918 tok/s=17.4
  code_cpp           pred= 192 draft= 145 acc= 118 rate=0.814 tok/s=16.5
  explain_concept    pred= 192 draft= 148 acc= 116 rate=0.784 tok/s=16.1
  summarize          pred=  55 draft=  44 acc=  32 rate=0.727 tok/s=15.6
  qa_factual         pred= 192 draft= 132 acc= 125 rate=0.947 tok/s=18.2
  translation        pred=  22 draft=  18 acc=  12 rate=0.667 tok/s=15.2
  creative_short     pred= 192 draft= 149 acc= 116 rate=0.778 tok/s=16.1
  stepwise_math      pred= 192 draft= 139 acc= 121 rate=0.871 tok/s=17.2
  long_code_review   pred= 192 draft= 153 acc= 114 rate=0.745 tok/s=15.6

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1421,
  "total_draft": 1062,
  "total_draft_accepted": 877,
  "aggregate_accept_rate": 0.8258,
  "wall_s_total": 90.44
}

Draft model (Qwen3.5 0.8B), spec-draft-n-max 16, with partial rollback

llama-server -m ../qwen3.6/Qwen3.6-27B-Q8_0.gguf -hfd unsloth/Qwen3.5-0.8B-GGUF:Q8_0 --spec-draft-n-max 16 -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}"

  code_python        pred= 192 draft= 188 acc= 156 rate=0.830 tok/s=26.4
  code_cpp           pred= 192 draft= 201 acc= 126 rate=0.627 tok/s=16.8
  explain_concept    pred= 192 draft= 263 acc= 112 rate=0.426 tok/s=12.7
  summarize          pred=  57 draft=  63 acc=  39 rate=0.619 tok/s=16.9
  qa_factual         pred= 192 draft= 178 acc= 177 rate=0.994 tok/s=47.7
  translation        pred=  23 draft=  18 acc=  15 rate=0.833 tok/s=18.7
  creative_short     pred= 192 draft= 189 acc= 120 rate=0.635 tok/s=15.4
  stepwise_math      pred= 192 draft= 190 acc= 148 rate=0.779 tok/s=22.3
  long_code_review   pred= 192 draft= 207 acc= 120 rate=0.580 tok/s=14.5

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1424,
  "total_draft": 1497,
  "total_draft_accepted": 1013,
  "aggregate_accept_rate": 0.6767,
  "wall_s_total": 81.39
}

Master with draft model, spec-draft-n-max 64, no partial rollback

llama-server -m ../qwen3.6/Qwen3.6-27B-Q8_0.gguf -hfd unsloth/Qwen3.5-0.8B-GGUF:Q8_0 --spec-draft-n-max 64 -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}"

  code_python        pred= 192 draft= 174 acc= 159 rate=0.914 tok/s=27.2
  code_cpp           pred= 192 draft= 138 acc= 120 rate=0.870 tok/s=15.0
  explain_concept    pred= 192 draft= 170 acc= 101 rate=0.594 tok/s=11.4
  summarize          pred=  55 draft=  48 acc=  36 rate=0.750 tok/s=14.6
  qa_factual         pred= 177 draft= 126 acc= 106 rate=0.841 tok/s=13.9
  translation        pred=  22 draft=  13 acc=  13 rate=1.000 tok/s=16.5
  creative_short     pred= 192 draft= 136 acc= 104 rate=0.765 tok/s=12.8
  stepwise_math      pred= 192 draft= 172 acc= 147 rate=0.855 tok/s=22.0
  long_code_review   pred= 192 draft= 160 acc= 111 rate=0.694 tok/s=13.0

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1406,
  "total_draft": 1137,
  "total_draft_accepted": 897,
  "aggregate_accept_rate": 0.7889,
  "wall_s_total": 97.13
}

How to use

I've uploaded the GGUF, which I made using the convert_hf_to_gguf.py changes in this PR. Here is another GGUF for the MoE (35BA3B) model.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Yes, for debugging and reviewing. Also the convert_hf_to_gguf.py + model definitions. Writing bench for validation against vLLM.

@github-actions bot added the model, testing, Nvidia GPU, Vulkan, examples, python, server, and ggml labels on May 4, 2026
@ngxson
Contributor

ngxson commented May 4, 2026

Nice, I think this is a fresher start than my WIP #18886 (which I still haven't found the time to continue)

There were some other attempts to add MTP support, but they all heavily rely on host <--> device data copies. I assume you tried to address this, right? (Maybe there was a discussion somewhere that I wasn't aware of.)

Contributor

@ngxson ngxson left a comment

(not a review, but opening some discussions)

Comment thread src/llama-memory-recurrent.h Outdated
Comment thread src/models/qwen35.cpp Outdated

for (int il = 0; il < n_layer; ++il) {
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
const int n_transformer_layers = n_layer - (int)hparams.nextn_predict_layers;
Contributor

nit, but maybe call it n_main_layers, as technically the nextn layer is also a transformer layer

Comment thread tools/server/server-context.cpp Outdated
Comment on lines +811 to +823
//TODO: generalize if this is ok, we should load <arch_name>_mtp arch?
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) {
    SRV_INF("loading MTP head from '%s' (override_arch=qwen35_mtp)\n",
        params_base.model.path.c_str());

    auto mparams_mtp = common_model_params_to_llama(params_base);
    mparams_mtp.override_arch = "qwen35_mtp";

    model_mtp.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp));
    if (model_mtp == nullptr) {
        SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str());
        return false;
    }
Contributor

if you look at #18886, the better way is to move llama_graph_type to the public API, then load the context with the appropriate graph type

Contributor Author

Yes that seems like the correct way to do this if we want to support MTP in a generic way
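To make the proposal concrete, here is a rough C++ sketch of what loading a second context over the same model with an MTP graph type could look like; llm_graph_type is still internal, so the commented-out graph_type field and the LLAMA_GRAPH_TYPE_MTP value are hypothetical placeholders for the proposed public API, not code from this PR. Everything else uses the existing llama.cpp API.

```cpp
// Hypothetical sketch of graph-type based loading as discussed above. The
// llm_graph_type enum is currently internal to llama.cpp, so the graph_type
// field and LLAMA_GRAPH_TYPE_MTP do not exist yet; the rest is the public API.
#include "llama.h"

static bool load_main_and_mtp_contexts(const char * path,
                                       llama_model   ** out_model,
                                       llama_context ** out_ctx_main,
                                       llama_context ** out_ctx_mtp) {
    llama_model_params mparams = llama_model_default_params();

    llama_model * model = llama_model_load_from_file(path, mparams);
    if (model == nullptr) {
        return false;
    }

    // main decoding context, default (decoder) graph
    llama_context_params cparams_main = llama_context_default_params();
    llama_context * ctx_main = llama_init_from_model(model, cparams_main);

    // second context over the SAME model weights, built with the MTP graph
    llama_context_params cparams_mtp = llama_context_default_params();
    // cparams_mtp.graph_type = LLAMA_GRAPH_TYPE_MTP; // hypothetical, not yet public

    llama_context * ctx_mtp = llama_init_from_model(model, cparams_mtp);
    if (ctx_main == nullptr || ctx_mtp == nullptr) {
        return false;
    }

    *out_model    = model;
    *out_ctx_main = ctx_main;
    *out_ctx_mtp  = ctx_mtp;
    return true;
}
```

The point of that approach, as raised above, would be that both contexts share the same loaded weights instead of loading a second arch-overridden copy of the model file.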

@am17an
Contributor Author

am17an commented May 4, 2026

@ngxson yes, the h2d copy was discussed with GG; he's working on a refactor which will allow us to share tensors between two llama contexts

@pwilkin
Member

pwilkin commented May 4, 2026

Great work, this should massively bridge the TG gap with vLLM, or maybe even surpass it together with tensor-parallel.

@cmp-nct
Contributor

cmp-nct commented May 4, 2026

In my opinion Qwen 3.6 is the most important thing that has happened in open-source models in a long time, so this is going to be very valuable.
I wonder if this, once merged, could be combined with ngram drafting?
So MTP is used until ngram is triggered, switching to ngram until rejection and back to MTP.

ngram could be set to match only very strong and long candidates (for large repetitive paraphrasing),
and MTP fills the gap.

@Dampfinchen

Dampfinchen commented May 4, 2026

" idea is that MTP should automatically start and we shouldn't need to distribute the MTP gguf separately but also it has it's own context/kv-cache etc." -> Does this mean MTP needs additional resources (RAM/VRAM?)

If so, there should always be an option to remain to disable it. Right now on my system (6 GB VRAM, 32 GB RAM), speculative decoding just makes things much slower even on very small draft models because of that exact reason, they need own context and kv-cache. Such low to midrange systems already operate on the edge in terms of memory.

@mbednarek360

I'm getting garbage responses running this PR on the Vulkan backend with an R9700 using llama-server. I'm using the GGUF you linked above. Interestingly, draft acceptance is only 0.01282.

Prompt: "Hello!"
Response:

The from,

;::...

... on;srible威风to{ islitor

\ ...

• We
&eq和chn ***, on
Prompt (:
mouth

“ ? forM� P 

@am17an
Contributor Author

am17an commented May 4, 2026

@cmp-nct I'm not sure, but could be possible

@Dampfinchen as of right now it is opt-in via --spec-type mtp, but in terms of memory it should be < 10% of overall memory used (it's just a single layer transformer + kv cache, much lighter than draft models)

@mbednarek360 I've only tested this on a small number of CUDA devices so far; once it's ready for review I will have tested more devices/backends. In particular, this PR relies on #22400, which is not implemented for Vulkan for now; if you ask an LLM to add support for that you might get a little further. Vulkan and Metal are also tested now.

@nawoa

nawoa commented May 4, 2026

Might it be possible/useful to run the draft model on a second GPU? Given that the MTP weights are relatively small, this might provide a useful speedup on systems with a dedicated high-VRAM "AI" GPU and a cheaper low-VRAM "normal" GPU used for display output, etc., and possibly prevent some degree of resource contention.

@cturan

cturan commented May 4, 2026

Thank you, we are eagerly awaiting this to become stable. Here are automated test results from my machine:

Qwen3.6-27B Q6_K benchmark on llama.cpp b9025-10829dbcc / PR #22673 branch
Hardware: RTX 3090 24GB + RTX 3060 12GB
Runtime flags: -fa on -c 10000 -np 1 -ngl 99 --no-mmap --no-cache-prompt
Endpoint: /completion, raw text prompt
Prompt: 6978 tokens
Generation: 256 tokens
Runs: 3 measured runs after warmup

| mode | model | prefill tok/s avg | generation tok/s avg | MTP acceptance | loaded VRAM |
|---|---|---|---|---|---|
| MTP enabled | Qwen3.6-27B-MTP-Q6_K.gguf + --spec-type mtp --spec-draft-n-max 3 | 665.14 | 42.45 | 76.0% | 24.96 GiB |
| MTP disabled, same GGUF | Qwen3.6-27B-MTP-Q6_K.gguf, no spec | 1315.46 | 22.97 | n/a | 22.47 GiB |
| Existing non-MTP Q6 | Qwen3.6-27B-Q6_K.gguf, no spec | 1260.12 | 22.39 | n/a | 22.59 GiB |

Result:

  • MTP improves decode from 22.97 tok/s to 42.45 tok/s on the same GGUF: ~1.85x speedup.
  • Against the existing non-MTP Q6 file, decode improves from 22.39 tok/s to 42.45 tok/s: ~1.90x speedup.
  • Prefill is slower with MTP enabled in this PR path: 665 tok/s vs 1315 tok/s on the same GGUF (~0.51x).
  • MTP adds about 2.49 GiB loaded VRAM in this setup.

@am17an
Contributor Author

am17an commented May 4, 2026

@cturan Thanks for testing, I'm aware of the issue for the prefill and will work on a fix.

@iiLaurens

Might be a long shot, but any chance of supporting MTP with a reduced vocabulary? MTP layers are rather chonky and reducing token embeddings might help users with less VRAM by filtering out certain languages. Obviously the full model will still be able to produce those tokens if need be so it won't be gimped.

@nybblr

nybblr commented May 4, 2026

Working on taking this for a spin with the Q4_K_M quant of Qwen3.6-35BA3B. I was gonna try to start from unsloth's quant since they already perform really well, but of course they don't have any mtp layers.

@am17an Think it would work if I just "steal" the layers from your q8 quant and merge them into the unsloth quant? (add blk.40 and bump some top-level config like block_count and kv_count)

@volkermauel

only a quick test run, 1x 5090 qwen3.6-27b mtp 3, q4_0 quantized, kv also q4_0

slot launch_slot_: id  0 | task -1 | sampler chain: logits -> penalties -> ?dry -> ?top-n-sigma -> top-k -> ?typical -> top-p -> min-p -> ?xtc -> ?temp-ext -> dist
slot launch_slot_: id  0 | task 532 | processing task, is_child = 0
slot update_slots: id  0 | task 532 | new prompt, n_ctx_slot = 200192, n_keep = 0, task.n_tokens = 16
slot update_slots: id  0 | task 532 | n_past = 3, slot.prompt.tokens.size() = 1327, seq_id = 0, pos_min = 1326, n_swa = 0
slot update_slots: id  0 | task 532 | forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
slot update_slots: id  0 | task 532 | n_tokens = 0, memory_seq_rm [0, end)
srv  log_server_r: done request: POST /v1/chat/completions 192.168.178.49 200
slot update_slots: id  0 | task 532 | prompt processing progress, n_tokens = 12, batch.n_tokens = 12, progress = 0.750000
slot update_slots: id  0 | task 532 | n_tokens = 12, memory_seq_rm [12, end)
slot init_sampler: id  0 | task 532 | init sampler, took 0.01 ms, tokens: text = 16, total = 16
slot update_slots: id  0 | task 532 | prompt processing done, n_tokens = 16, batch.n_tokens = 4
slot print_timing: id  0 | task 532 |
prompt eval time =      63.16 ms /    16 tokens (    3.95 ms per token,   253.34 tokens per second)
       eval time =   56063.04 ms /  5913 tokens (    9.48 ms per token,   105.47 tokens per second)
      total time =   56126.20 ms /  5929 tokens
draft acceptance rate = 0.79728 ( 4169 accepted /  5229 generated)
statistics mtp: #calls(b,g,a) = 2 2272 1976, #gen drafts = 2272, #acc drafts = 1976, #gen tokens = 6816, #acc tokens = 4950, dur(b,g,a) = 0.007, 15393.656, 64.921 ms
slot      release: id  0 | task 532 | stop processing: n_tokens = 5928, truncated = 0
srv  update_slots: all slots are idle

same model, same config (except mtp)

slot update_slots: id  0 | task 0 | prompt processing done, n_tokens = 16, batch.n_tokens = 4
slot print_timing: id  0 | task 0 | 
prompt eval time =      91.85 ms /    16 tokens (    5.74 ms per token,   174.20 tokens per second)
       eval time =  103127.94 ms /  6571 tokens (   15.69 ms per token,    63.72 tokens per second)
      total time =  103219.79 ms /  6587 tokens
slot      release: id  0 | task 0 | stop processing: n_tokens = 6586, truncated = 0
srv  update_slots: all slots are idle

prompt „create a flappy bird clone“

(I‘m not creative, sorry)

Great Speedup!

@alexandrupetraru

This is a game changer. On Strix Halo with the Q8 Qwen 3.6 35BA3B, TG jumps from 40 to 70 at low context, and for the 27B from 12 to 25 TG (with a 50/50 layer split between a 7900 XTX and Strix Halo) for coding. We need this in master asap together with turbo4; it performs very well and without any issues. Good job

@GloballyUniquePlaceholder

On a 3060 Laptop 6GB vram + 64GB ram running your provided Qwen 3.6 35A3B gguf there is a reasonable speed up.

| spec-draft-n-max | average tok/s | wall_s_total | aggregate_accept_rate |
|---|---|---|---|
| n/a (no MTP) | 22.92 | 77.69 | n/a |
| 1 | 27.58 | 68.34 | 0.8835 |
| 2 | 29.39 | 66.00 | 0.815 |
| 3 | 27.78 | 67.96 | 0.7127 |
| 4 | 26.09 | 72.23 | 0.6421 |
raw results

spec-draft-n-max 4

llama.cpp\build\bin\Release\llama-server.exe -fa on -c 5000 -np 1 -fit on -m Qwen3.6-35BA3B-MTP.gguf --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 4

python mtp-bench.py
  code_python        pred= 192 draft= 180 acc= 146 rate=0.811 tok/s=31.3
  code_cpp           pred= 192 draft= 216 acc= 136 rate=0.630 tok/s=22.7
  explain_concept    pred= 192 draft= 224 acc= 134 rate=0.598 tok/s=22.3
  summarize          pred=  53 draft=  52 acc=  39 rate=0.750 tok/s=33.3
  qa_factual         pred= 192 draft= 196 acc= 141 rate=0.719 tok/s=29.2
  translation        pred=  22 draft=  32 acc=  13 rate=0.406 tok/s=19.4
  creative_short     pred= 192 draft= 264 acc= 124 rate=0.470 tok/s=20.7
  stepwise_math      pred= 192 draft= 192 acc= 143 rate=0.745 tok/s=30.7
  long_code_review   pred= 192 draft= 220 acc= 136 rate=0.618 tok/s=25.2

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1419,
  "total_draft": 1576,
  "total_draft_accepted": 1012,
  "aggregate_accept_rate": 0.6421,
  "wall_s_total": 72.23
}

spec-draft-n-max 3

llama.cpp\build\bin\Release\llama-server.exe -fa on -c 5000 -np 1 -fit on -m Qwen3.6-35BA3B-MTP.gguf --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 3

python mtp-bench.py
  code_python        pred= 192 draft= 165 acc= 136 rate=0.824 tok/s=30.2
  code_cpp           pred= 192 draft= 168 acc= 135 rate=0.804 tok/s=27.6
  explain_concept    pred= 192 draft= 189 acc= 128 rate=0.677 tok/s=25.3
  summarize          pred=  53 draft=  48 acc=  36 rate=0.750 tok/s=32.5
  qa_factual         pred= 192 draft= 180 acc= 131 rate=0.728 tok/s=29.2
  translation        pred=  22 draft=  24 acc=  13 rate=0.542 tok/s=24.5
  creative_short     pred= 192 draft= 210 acc= 120 rate=0.571 tok/s=23.2
  stepwise_math      pred= 192 draft= 174 acc= 133 rate=0.764 tok/s=30.5
  long_code_review   pred= 192 draft= 189 acc= 128 rate=0.677 tok/s=27.2

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1419,
  "total_draft": 1347,
  "total_draft_accepted": 960,
  "aggregate_accept_rate": 0.7127,
  "wall_s_total": 67.96
}

spec-draft-n-max 2

llama.cpp\build\bin\Release\llama-server.exe -fa on -c 5000 -np 1 -fit on -m Qwen3.6-35BA3B-MTP.gguf --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 2

python mtp-bench.py
  code_python        pred= 192 draft= 132 acc= 125 rate=0.947 tok/s=31.5
  code_cpp           pred= 192 draft= 140 acc= 120 rate=0.857 tok/s=27.0
  explain_concept    pred= 192 draft= 152 acc= 114 rate=0.750 tok/s=25.6
  summarize          pred=  53 draft=  40 acc=  32 rate=0.800 tok/s=32.2
  qa_factual         pred= 192 draft= 144 acc= 119 rate=0.826 tok/s=31.1
  translation        pred=  22 draft=  16 acc=  13 rate=0.812 tok/s=30.8
  creative_short     pred= 192 draft= 156 acc= 113 rate=0.724 tok/s=25.9
  stepwise_math      pred= 192 draft= 144 acc= 119 rate=0.826 tok/s=31.3
  long_code_review   pred= 192 draft= 146 acc= 117 rate=0.801 tok/s=29.1

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1419,
  "total_draft": 1070,
  "total_draft_accepted": 872,
  "aggregate_accept_rate": 0.815,
  "wall_s_total": 66.0
}

spec-draft-n-max 1

llama.cpp\build\bin\Release\llama-server.exe -fa on -c 5000 -np 1 -fit on -m Qwen3.6-35BA3B-MTP.gguf --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 1

python mtp-bench.py
  code_python        pred= 192 draft=  96 acc=  94 rate=0.979 tok/s=28.3
  code_cpp           pred= 192 draft= 100 acc=  90 rate=0.900 tok/s=26.2
  explain_concept    pred= 192 draft= 102 acc=  89 rate=0.873 tok/s=25.9
  summarize          pred=  56 draft=  29 acc=  26 rate=0.897 tok/s=30.6
  qa_factual         pred= 192 draft= 100 acc=  90 rate=0.900 tok/s=28.5
  translation        pred=  22 draft=  12 acc=   9 rate=0.750 tok/s=27.0
  creative_short     pred= 192 draft= 104 acc=  86 rate=0.827 tok/s=24.9
  stepwise_math      pred= 192 draft= 102 acc=  88 rate=0.863 tok/s=28.7
  long_code_review   pred= 192 draft= 102 acc=  88 rate=0.863 tok/s=28.1

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1422,
  "total_draft": 747,
  "total_draft_accepted": 660,
  "aggregate_accept_rate": 0.8835,
  "wall_s_total": 68.34
}

no mtp

llama.cpp\build\bin\Release\llama-server.exe -fa on -c 5000 -np 1 -fit on -m Qwen3.6-35BA3B-MTP.gguf --chat-template-kwargs "{\"preserve_thinking\": true}"

python mtp-bench.py
  code_python        pred= 192 draft=   0 acc=   0 rate=n/a tok/s=22.2
  code_cpp           pred= 192 draft=   0 acc=   0 rate=n/a tok/s=22.1
  explain_concept    pred= 192 draft=   0 acc=   0 rate=n/a tok/s=22.1
  summarize          pred=  53 draft=   0 acc=   0 rate=n/a tok/s=25.9
  qa_factual         pred= 192 draft=   0 acc=   0 rate=n/a tok/s=22.1
  translation        pred=  22 draft=   0 acc=   0 rate=n/a tok/s=22.3
  creative_short     pred= 192 draft=   0 acc=   0 rate=n/a tok/s=21.4
  stepwise_math      pred= 192 draft=   0 acc=   0 rate=n/a tok/s=24.0
  long_code_review   pred= 192 draft=   0 acc=   0 rate=n/a tok/s=24.2

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1419,
  "total_draft": 0,
  "total_draft_accepted": 0,
  "aggregate_accept_rate": null,
  "wall_s_total": 77.69
}

@ninjas28

ninjas28 commented May 5, 2026

Crashes when using -sm tensor with llama-server launch command args -hf am17an/Qwen3.6-27B-MTP-GGUF:Q8_0 -sm tensor -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 3. Using -sm tensor without MTP works fine. This is on a triple GPU setup using ROCm.

srv  params_from_: Chat format: peg-native
slot get_availabl: id  0 | task -1 | selected slot by LRU, t_last = -1
srv  get_availabl: updating prompt cache
srv          load:  - looking for better prompt, base f_keep = -1.000, sim = 0.000
srv        update:  - cache state: 0 prompts, 0.000 MiB (limits: 8192.000 MiB, 262144 tokens, 8589934592 est)
srv  get_availabl: prompt cache update took 0.01 ms
slot launch_slot_: id  0 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> top-k -> ?typical -> top-p -> min-p -> ?xtc -> ?temp-ext -> dist 
slot launch_slot_: id  0 | task 0 | processing task, is_child = 0
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 262144, n_keep = 0, task.n_tokens = 356
slot update_slots: id  0 | task 0 | n_tokens = 0, memory_seq_rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_tokens = 352, batch.n_tokens = 352, progress = 0.988764
/root/llama.cpp/ggml/src/ggml-backend-meta.cpp:1013: GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis] == sum * tensor->ne[split_state.axis]) failed
/root/llama.cpp/build/bin/libggml-base.so.0(+0x1b25b)[0x74b4b4ca925b]
/root/llama.cpp/build/bin/libggml-base.so.0(ggml_print_backtrace+0x21f)[0x74b4b4ca96df]
/root/llama.cpp/build/bin/libggml-base.so.0(ggml_abort+0x152)[0x74b4b4ca98b2]
/root/llama.cpp/build/bin/libggml-base.so.0(+0x41506)[0x74b4b4ccf506]
/root/llama.cpp/build/bin/libggml-base.so.0(+0x3d579)[0x74b4b4ccb579]
/root/llama.cpp/build/bin/libggml-base.so.0(+0x41adb)[0x74b4b4ccfadb]
/root/llama.cpp/build/bin/libggml-base.so.0(ggml_gallocr_alloc_graph+0x474)[0x74b4b4cbff54]
/root/llama.cpp/build/bin/libggml-base.so.0(ggml_backend_sched_alloc_graph+0x111)[0x74b4b4cc6351]
/root/llama.cpp/build/bin/libllama.so.0(_ZN13llama_context14process_ubatchERK12llama_ubatch14llm_graph_typeP22llama_memory_context_iR11ggml_status+0xe8)[0x74b4b44dac08]
/root/llama.cpp/build/bin/libllama.so.0(_ZN13llama_context6decodeERK11llama_batch+0x37b)[0x74b4b44d912b]
/root/llama.cpp/build/bin/libllama.so.0(llama_decode+0x10)[0x74b4b44da780]
/root/llama.cpp/build/bin/libllama.so.0(_ZN13llama_context21handle_mtp_for_ubatchEiPKiS1_P11ggml_tensor+0x20d)[0x74b4b44da9bd]
/root/llama.cpp/build/bin/libllama.so.0(_ZN13llama_context14process_ubatchERK12llama_ubatch14llm_graph_typeP22llama_memory_context_iR11ggml_status+0x142)[0x74b4b44dac62]
/root/llama.cpp/build/bin/libllama.so.0(_ZN13llama_context6decodeERK11llama_batch+0x37b)[0x74b4b44d912b]
/root/llama.cpp/build/bin/libllama.so.0(llama_decode+0x10)[0x74b4b44da780]
llama-server(+0xf846e)[0x63c5e42c046e]
llama-server(+0x172971)[0x63c5e433a971]
llama-server(+0x5842c)[0x63c5e422042c]
/lib/x86_64-linux-gnu/libc.so.6(+0x29d90)[0x74b4b3c29d90]
/lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0x80)[0x74b4b3c29e40]
llama-server(+0x58cd5)[0x63c5e4220cd5]
Aborted

@superjamie

Tested on 3x RTX3060 12Gb. Sorry I don't have the VRAM for your Q8, I used RDson/Qwen3.6-27B-MTP-Q4_K_M-GGUF which was quantized with ik_llama's MTP.

Prompt: "Write a simple minimal hash table implementation in C99."

Three runs with no MTP, avg generation 18.51 tok/sec:

llama-server --model /models/RDson/Qwen3.6-27B-MTP-Q4_K_M-GGUF/Qwen3.6-27B-MTP-Q4_K_M.gguf \
 --port 8080 --host 0.0.0.0 --n-gpu-layers 999 --flash-attn on --ctx-size $((16*1024)) \
 --temp 0.6 --top-p 0.95 --presence-penalty 0.0 --top-k 20 --min-p 0.0 --repeat_penalty 1.0 \
 --no-mmproj --chat-template-kwargs '{"enable_thinking":false}'

prompt eval time =     177.62 ms /    24 tokens (    7.40 ms per token,   135.12 tokens per second)
       eval time =   99331.08 ms /  1837 tokens (   54.07 ms per token,    18.49 tokens per second)
      total time =   99508.70 ms /  1861 tokens

prompt eval time =     159.10 ms /    24 tokens (    6.63 ms per token,   150.85 tokens per second)
       eval time =  107505.42 ms /  1988 tokens (   54.08 ms per token,    18.49 tokens per second)
      total time =  107664.52 ms /  2012 tokens

prompt eval time =     158.43 ms /    24 tokens (    6.60 ms per token,   151.49 tokens per second)
       eval time =   48263.07 ms /   895 tokens (   53.93 ms per token,    18.54 tokens per second)
      total time =   48421.51 ms /   919 tokens

Three runs with MTP, avg generation 32.24 tok/sec:

llama-server --model /models/RDson/Qwen3.6-27B-MTP-Q4_K_M-GGUF/Qwen3.6-27B-MTP-Q4_K_M.gguf \
 --port 8080 --host 0.0.0.0 --n-gpu-layers 999 --flash-attn on --ctx-size $((16*1024)) \
 --temp 0.6 --top-p 0.95 --presence-penalty 0.0 --top-k 20 --min-p 0.0 --repeat_penalty 1.0 \
 --no-mmproj --chat-template-kwargs '{"enable_thinking":false}' \
 --spec-type mtp --spec-draft-n-max 3 --parallel 1

prompt eval time =     232.24 ms /    24 tokens (    9.68 ms per token,   103.34 tokens per second)
       eval time =   34610.94 ms /  1110 tokens (   31.18 ms per token,    32.07 tokens per second)
      total time =   34843.18 ms /  1134 tokens 
      
prompt eval time =     207.99 ms /    24 tokens (    8.67 ms per token,   115.39 tokens per second)
       eval time =   32110.05 ms /  1064 tokens (   30.18 ms per token,    33.14 tokens per second)
      total time =   32318.03 ms /  1088 tokens
      
prompt eval time =     208.50 ms /    24 tokens (    8.69 ms per token,   115.11 tokens per second)
       eval time =   39029.34 ms /  1230 tokens (   31.73 ms per token,    31.51 tokens per second)
      total time =   39237.84 ms /  1254 tokens 

Result: 74% speedup. Wow!

Thank you for your work. You will make many users happy with this. What an exciting PR!

One small hiccup. On my initial attempt I got the error message:

load_model: MTP currently supports only n_parallel=1; got 4

Adding --parallel 1 fixed that.

@curvedinf

curvedinf commented May 12, 2026

I am trying to recreate your Qwen3.6-35B-A3B Q4_K_M results on 1x Radeon Pro W7900 (gfx1100) but am not seeing as much of an uplift; I am getting at best 140 tok/s -> 174 tok/s with -ngl 999 -fa on -fit off --backend-sampling -sm none --spec-type mtp --spec-draft-n-max 2 -np 1 --spec-draft-ngl 99 on Vulkan

Note that I used --spec-draft-n-max 3. Compile options were those from the guide.

Comment thread tools/server/server-context.cpp
Comment thread convert_hf_to_gguf.py Outdated
Comment on lines +5574 to +5580
# Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
if name.startswith("model.language_model."):
    name = "model." + name[len("model.language_model."):]
elif name.startswith("language_model."):
    name = name[len("language_model."):]

# Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
Member

Suggested change
# Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
if name.startswith("model.language_model."):
name = "model." + name[len("model.language_model."):]
elif name.startswith("language_model."):
name = name[len("language_model."):]
# Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
# Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.

No longer necessary (or maybe it is at Mixin stage, if so add filter_tensors?).

Contributor Author

@CISC can you review am17an#9 instead?

@rumgewieselt

Hardware: 3x NVIDIA GTX 1080 Ti (Pascal sm_61, 11 GiB each, no NVLink, NCCL 2.22.3)

Build flags (critical for Pascal):

-DCMAKE_CUDA_ARCHITECTURES=61
-DGGML_CUDA=ON
-DGGML_CUDA_NCCL=ON
-DGGML_CUDA_NO_PEER_COPY=ON
-DGGML_CUDA_FORCE_MMQ=ON
-DGGML_CUDA_COMPRESSION_MODE=size
-DGGML_CUDA_GRAPHS=OFF
-DGGML_NATIVE=ON

Branch: mtp-clean (b9117, same as this PR)

Models: Qwen3.6-27B-MTP-UD-Q4_K_XL (16.8 GiB, dense, row-split)

Results (32K ctx, T=0):
Before:
27B -> 20.3 t/s

After:
27B -> 32.6 t/s +61%

Quality:

  • NIAH at 10/25/50/75/90% depth: all correct
  • Tool calling: correct function + args
  • Greedy determinism: 8/8 prompts byte-identical across 3 runs
  • No looping on stress tests
  • MTP draft acceptance: 60-80% on structured outputs

Notes:

  • --ctx-checkpoints 0 required on multi-GPU
  • --spec-draft-n-max 1, -np 1 mandatory
  • Flash attention not available on Pascal with this branch

MTP works well even on decade-old hardware.
Thanks a lot and great work on this. Mindblowing!

@janvitos

janvitos commented May 12, 2026

FYI

I pulled the latest mtp-clean this morning and built from source. I then ran llama.cpp with the following parameters:

llama-server -m Qwen3.6-35B-A3B-UD-Q4_K_XL.gguf -fitt 1536 -c 131072 -n 32768 -fa on -np 1 -ctk q8_0 -ctv q8_0 -ctkd q8_0 -ctvd q8_0 -ctxcp 64 --no-mmap --mlock --no-warmup --spec-type mtp --spec-draft-n-max 3 --temp 0.6 --top-p 0.95 --top-k 20 --min-p 0.0 --presence-penalty 0.0 --repeat-penalty 1.0

I then ran mtp-bench.py and got these results:

❯ ./mtp-bench.py
  code_python        pred= 192 draft= 128 acc= 128 rate=1.000 tok/s=75.8
  code_cpp           pred=  59 draft=  42 acc=  41 rate=0.976 tok/s=76.6
  explain_concept    pred= 192 draft= 126 acc= 119 rate=0.944 tok/s=61.5
  summarize          pred=  54 draft=  37 acc=  36 rate=0.973 tok/s=59.0
  qa_factual         pred= 192 draft= 116 acc= 116 rate=1.000 tok/s=66.4
  translation        pred=  22 draft=  13 acc=  13 rate=1.000 tok/s=67.5
  creative_short     pred= 192 draft= 118 acc= 114 rate=0.966 tok/s=68.8
  stepwise_math      pred= 192 draft= 143 acc= 132 rate=0.923 tok/s=71.6
  long_code_review   pred= 192 draft= 123 acc= 116 rate=0.943 tok/s=64.6

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1287,
  "total_draft": 846,
  "total_draft_accepted": 815,
  "aggregate_accept_rate": 0.9634,
  "wall_s_total": 21.6
}

With the previous mtp-clean-old, these were my mtp-bench.py results:

❯ ./mtp-bench.py
  code_python        pred= 192 draft= 159 acc= 138 rate=0.868 tok/s=80.8
  code_cpp           pred=  58 draft=  48 acc=  41 rate=0.854 tok/s=79.8
  explain_concept    pred= 192 draft= 189 acc= 127 rate=0.672 tok/s=70.8
  summarize          pred=  53 draft=  51 acc=  35 rate=0.686 tok/s=72.0
  qa_factual         pred= 192 draft= 174 acc= 133 rate=0.764 tok/s=77.8
  translation        pred=  22 draft=  24 acc=  13 rate=0.542 tok/s=67.1
  creative_short     pred= 192 draft= 213 acc= 119 rate=0.559 tok/s=63.2
  stepwise_math      pred= 192 draft= 171 acc= 134 rate=0.784 tok/s=78.9
  long_code_review   pred= 192 draft= 177 acc= 132 rate=0.746 tok/s=72.7

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1285,
  "total_draft": 1206,
  "total_draft_accepted": 872,
  "aggregate_accept_rate": 0.7231,
  "wall_s_total": 20.22
}

So it's as if --spec-draft-n-max 3 is having no effect on the new mtp-clean.

EDIT: After a lot of testing, I ended up using --spec-draft-n-max 5 as the best flag.

@am17an
Contributor Author

am17an commented May 12, 2026

@janvitos the new branch stops early if the draft tokens are of low quality

@kubakomu

kubakomu commented May 12, 2026

(EDIT: sorry if spamming the perf results here is annoying, I saw people posting some and thought it could be useful.)

I tried running the branch mtp-clean (ebe4fca) with Qwen3.6-27B-Q5_K_M.gguf from Unsloth on 1x and 2x R9700 on Linux/Vulkan.

Server invocation:

  -c 32000 \
  -ngl 99 \
  -fa on \
  -np 1 \
  --device Vulkan1 [,Vulkan2 ] \
  --split-mode layer \
  -b 8192 \
  -ub 512 \
  --temp 0.6 \
  --top_p 0.95 \
  --top_k 20 \
  --min_p 0.0 \
  --presence_penalty 0.0 \
  --repeat_penalty 1.0 \
  --no-context-shift \
  [--spec-type mtp --spec-draft-n-max 3]

Benchmark command:
uvx llama-benchy --tg 512 --pp 4096 --base-url http://localhost:8000/v1 --model unsloth/Qwen3.6-27B --tokenizer unsloth/Qwen3.6-27B.

No-MTP, 1xR9700

| model | test | t/s | peak t/s | ttfr (ms) | est_ppt (ms) | e2e_ttft (ms) |
|---|---|---|---|---|---|---|
| unsloth/Qwen3.6-27B | pp4096 | 749.98 ± 6.95 | | 5463.10 ± 50.06 | 5462.85 ± 50.06 | 5463.10 ± 50.06 |
| unsloth/Qwen3.6-27B | tg512 | 20.29 ± 0.12 | 22.00 ± 1.41 | | | |

No-MTP, 2xR9700

| model | test | t/s | peak t/s | ttfr (ms) | est_ppt (ms) | e2e_ttft (ms) |
|---|---|---|---|---|---|---|
| unsloth/Qwen3.6-27B | pp4096 | 1037.44 ± 10.16 | | 3949.56 ± 39.19 | 3949.20 ± 39.19 | 3949.56 ± 39.19 |
| unsloth/Qwen3.6-27B | tg512 | 20.80 ± 0.04 | 22.00 ± 0.00 | | | |

MTP (--spec-draft-n-max 3), 1xR9700

| model | test | t/s | peak t/s | ttfr (ms) | est_ppt (ms) | e2e_ttft (ms) |
|---|---|---|---|---|---|---|
| unsloth/Qwen3.6-27B | pp4096 | 590.89 ± 18.38 | | 6940.69 ± 220.78 | 6940.42 ± 220.78 | 6940.69 ± 220.78 |
| unsloth/Qwen3.6-27B | tg512 | 36.65 ± 1.22 | 51.67 ± 2.49 | | | |

MTP (--spec-draft-n-max 3), 2xR9700

| model | test | t/s | peak t/s | ttfr (ms) | est_ppt (ms) | e2e_ttft (ms) |
|---|---|---|---|---|---|---|
| unsloth/Qwen3.6-27B | pp4096 | 513.46 ± 8.00 | | 7981.43 ± 125.69 | 7981.15 ± 125.69 | 7981.43 ± 125.69 |
| unsloth/Qwen3.6-27B | tg512 | 35.69 ± 1.47 | 48.33 ± 0.47 | | | |

Observations:

  • 50% increase in tp in test.
  • Significant PP degradation compared to the baseline. Same exe was used, except different starting command.
  • PP degraded even further with 2 cards.

I unfortunately do not know the implementation well enough to say whether this PP slowdown is to be expected. I was especially sad to see the degradation deepen with 2 cards. Perhaps I would need to adjust the configuration to utilize the MTP feature better?

Thank you very much for your work.

@janvitos

@janvitos the new branch stops early if the draft tokens are of low quality

Got it, thanks. So is --spec-draft-n-max now only a hard upper limit, with an internal confidence/quality early-stop deciding the actual draft length?

In my case, the new early-stop raises acceptance from ~72% to ~96%, but total throughput is slightly worse because the draft count drops from 1206 to 846. Is there currently a flag to tune or disable that early-stop threshold so users can choose between higher acceptance rate and higher draft aggressiveness?

@am17an
Contributor Author

am17an commented May 12, 2026

@janvitos Yes you can play around with --spec-draft-p-min for this (please don't post your results on this thread though)
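For anyone tuning this, here is a rough sketch of how a p-min cutoff interacts with the n-max cap during drafting; this reflects my reading of the flags rather than the PR's actual code, and draft_one() is a placeholder for a single MTP forward pass.

```cpp
// Sketch only: --spec-draft-n-max caps how many tokens are drafted, while a
// p-min style threshold stops drafting early once the draft head's confidence
// in its own next token falls below the cutoff. draft_one() is a placeholder
// for one MTP forward pass returning the sampled token and its probability.
#include <cstdint>
#include <vector>

struct draft_step {
    int32_t token;
    float   prob;
};

template <typename DraftFn>
std::vector<int32_t> build_draft(DraftFn draft_one, int n_max, float p_min) {
    std::vector<int32_t> draft;
    for (int i = 0; i < n_max; ++i) {            // hard upper bound (--spec-draft-n-max)
        const draft_step s = draft_one(draft);   // placeholder MTP step
        if (s.prob < p_min) {
            break;                               // low confidence: stop early, keep acceptance high
        }
        draft.push_back(s.token);
    }
    return draft;
}
```

Raising the cutoff trades raw draft count for acceptance rate, which is the ~72% vs ~96% trade-off discussed above.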

@NickM-27

am I correct in understanding that Gemma4 MTP support will require further changes in a subsequent PR?

@exander77

am I correct in understanding that Gemma4 MTP support will require further changes in a subsequent PR?

Yes, but I already have a working implementation. I can open a PR so that at least people can test it.

Gemma 4 and Qwen 3.6 have different architecture.

@am17an
Contributor Author

am17an commented May 12, 2026

Yes Gemma 4 should follow soon after this is merged

@demonhater

For the mmproj issue... since the server already accepts .n_max per-request, would it be doable to automatically set n_max=0 when the request contains anything multimodal? That way MTP stays active for text-only turns and then falls back to autoregressive for vision turns, same server instance, no restart needed. Seems like it could be a low cost fix while the proper kernel-level support is worked out.

@beginor

beginor commented May 12, 2026

waiting for mtp release

@candrews

--spec-type mtp with a non-MTP model shouldn't work, of course. But, it should produce a reasonable error.

Instead, attempting to enable MTP with a non-MTP model results in a core dump:

$ llama-cli --spec-type mtp --prompt '/exit' -hf unsloth/Qwen3.6-35B-A3B-GGUF:Q4_K_M

Loading model... //var/tmp/portage/sci-misc/llama-cpp-9999/work/llama-cpp-9999/src/models/qwen35moe-mtp.cpp:10: GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE_MTP requires nextn_predict_layers > 0") failed
[New LWP 390121]
[New LWP 390120]
[New LWP 390119]
[New LWP 390118]
[New LWP 390117]
[New LWP 390116]
[New LWP 390115]
[New LWP 390063]
[New LWP 390062]
[New LWP 390057]
[New LWP 390056]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib64/libthread_db.so.1".
0x00007f9bdb2bd332 in __syscall_cancel_arch () from /usr/lib64/libc.so.6
#0  0x00007f9bdb2bd332 in __syscall_cancel_arch () from /usr/lib64/libc.so.6
#1  0x00007f9bdb2b0ff8 in __internal_syscall_cancel () from /usr/lib64/libc.so.6
#2  0x00007f9bdb2b1041 in __syscall_cancel () from /usr/lib64/libc.so.6
#3  0x00007f9bdb3130db in wait4 () from /usr/lib64/libc.so.6
#4  0x00007f9bdbfe3aab in ggml_print_backtrace () from /usr/lib64/llama.cpp/libggml-base.so.0
#5  0x00007f9bdbfe3c76 in ggml_abort () from /usr/lib64/llama.cpp/libggml-base.so.0
#6  0x00007f9bdba21a2b in llama_model_qwen35moe_mtp::load_arch_hparams(llama_model_loader&) () from /usr/lib64/llama.cpp/libllama.so.0
#7  0x00007f9bdb9385bb in llama_model_base::load_hparams(llama_model_loader&) () from /usr/lib64/llama.cpp/libllama.so.0
#8  0x00007f9bdb890ec6 in llama_model_load(gguf_context*, void (*)(ggml_tensor*, void*), void*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >&, _IO_FILE*, llama_model_params&) () from /usr/lib64/llama.cpp/libllama.so.0
#9  0x00007f9bdb8920aa in llama_model_load_from_file_impl(gguf_context*, void (*)(ggml_tensor*, void*), void*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >&, _IO_FILE*, llama_model_params) () from /usr/lib64/llama.cpp/libllama.so.0
#10 0x00007f9bdb892465 in llama_model_load_from_file () from /usr/lib64/llama.cpp/libllama.so.0
#11 0x000055802d2cf806 in server_context_impl::load_model(common_params&) ()
#12 0x000055802d260b73 in main ()
[Inferior 1 (process 390051) detached]
Aborted                    (core dumped) llama-cli --spec-type mtp --prompt '/exit' -hf unsloth/Qwen3.6-35B-A3B-GGUF:Q4_K_M

@illsk1lls

illsk1lls commented May 12, 2026

I lose vision and have to exclude my mmproj file when I build this... otherwise...

1xTesla T4 (16gb), Dual Xeon Platinum 8268, 1.5tb 2666 (Dell R640)

I'm getting my fastest speeds on a single CPU with the below settings

ExecStart=/usr/bin/numactl --cpunodebind=0 --membind=0 \
/home/user/llama.cpp-mtp/build/bin/llama-server-mtp \
-m /models/Qwen3.6-35B-A3B-uncensored-heretic-Native-MTP-Preserved-Q8_0.gguf \
-fitt 4096 \
-c 262144 \
-t 16 \
--host 0.0.0.0 \
--port 8081 \
-fa on \
-np 1 \
-ctk f16 \
-ctv f16 \
--no-mmap \
--mlock \
--spec-type mtp \
--spec-draft-n-max 2 \
--chat-template-kwargs '{"preserve_thinking": false}' \
--temp 0.6 \
--top-p 0.95 \
--top-k 20 \
--min-p 0.0 \
--repeat-penalty 1.0

117t/s prompt
40t/s output

@rfairburn

I tested ROCm/HIP builds on dual RX 7900 XTX / gfx1100 using the same build flags and runtime configuration.

Model:
froggeric/Qwen3.6-27B-MTP-GGUF:Q8_0

Runtime:
--split-mode tensor --tensor-split 1,1 --flash-attn on --ctx-size 131072 --batch-size 4096 --ubatch-size 1024 --ctx-checkpoints 16 --spec-type mtp --spec-draft-n-max 3 --spec-draft-ngl 99

Original MTP PR commit from May 7:
5d5f1b4

Result:

  • prompt eval: 666.92 tok/s
  • eval: 57.54 tok/s
  • total time: 145.12 s
  • draft acceptance: 65.53% (1384 accepted / 2112 generated)

Current mtp-clean / PR head:
ebe4fca

Result:

  • prompt eval: 669.89 tok/s
  • eval: 22.24 tok/s
  • total time: 226.58 s
  • draft acceptance: 97.40% (1574 accepted / 1616 generated)

Prompt processing is essentially unchanged, but token generation is much slower on mtp-clean despite the much higher draft acceptance rate. Am I doing something wrong, or is there some kind of known or otherwise unknown regression? The updated PR's performance is significantly slower than without MTP at all.

@StarWingOwl

@rfairburn @candrews I'm having the same issue on a 7900 GRE (gfx1100). I tried both the Vulkan and ROCm backends on both mtp-clean and pr-22673. As mentioned in my previous message here (granted, this is a very long thread), I had to severely limit -ngl and -n-cpu-moe to even get the first message, and I still ended up getting a fraction of the speeds I normally would, or it would just core dump. Any ideas on what could be happening?

@ZisIsNotZis

ZisIsNotZis commented May 13, 2026

I think something just broke? This used to work using the instructions at https://huggingface.co/havenoammo/Qwen3.6-35B-A3B-MTP-GGUF, but I just re-applied those instructions on another host with a 4090, and the program fails:

Details
z@z3:~$ llama-server -m ~/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf --spec-type mtp --spec-draft-n-max 3 -np 1 -c 131072
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 24080 MiB):
  Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes, VRAM: 24080 MiB
build_info: b2507-0fa27bb1
system_info: n_threads = 6 (n_threads_batch = 6) / 12 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | FA_ALL_QUANTS = 1 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 
init: using 11 threads for HTTP server
start: binding port with default address family
main: loading model
srv    load_model: loading model '/home/z/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf'
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
common_params_fit_impl: getting device memory data for initial parameters:
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23605 + (24770 = 21654 +    2622 +     493) +      -24295 |
common_memory_breakdown_print: |   - Host               |                    779 =   515 +       0 +     264                |
common_params_fit_impl: projected to use 24770 MiB of device memory vs. 23605 MiB of free device memory
common_params_fit_impl: cannot meet free memory target of 1024 MiB, need to reduce device memory by 2188 MiB
common_params_fit_impl: context size set by user to 131072 -> no change
common_params_fit_impl: getting device memory data with all MoE tensors moved to system memory:
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23605 + ( 5194 =  2078 +    2622 +     493) +       -4719 |
common_memory_breakdown_print: |   - Host               |                  20355 = 20091 +       0 +     264                |
common_params_fit_impl: with only dense weights in device memory there is a total surplus of 17387 MiB
common_params_fit_impl: id=0, target=22581 MiB
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23669 + ( 1012 =     0 +       0 +    1012) +        -601 |
common_memory_breakdown_print: |   - Host               |                  25068 = 22169 +    2622 +     276                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer= 0, n_part= 0, overflow_type=4, mem=  1012 MiB
common_params_fit_impl: filling dense-only layers back-to-front:
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23605 + ( 5658 =  2542 +    2622 +     493) +       -5183 |
common_memory_breakdown_print: |   - Host               |                  19891 = 19627 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=41, overflow_type=4, mem=  5658 MiB
common_params_fit_impl: set ngl_per_device[0].n_layer=42
common_params_fit_impl:   - CUDA0 (NVIDIA GeForce RTX 4090): 42 layers,   5658 MiB used,  17947 MiB free
common_params_fit_impl: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23605 + (24770 = 21654 +    2622 +     493) +      -24295 |
common_memory_breakdown_print: |   - Host               |                    779 =   515 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part= 0, overflow_type=4, mem= 24770 MiB
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23605 + (22631 = 19378 +    2622 +     630) +      -22156 |
common_memory_breakdown_print: |   - Host               |                   3055 =  2791 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part= 5, overflow_type=4, mem= 22631 MiB
common_params_fit_impl: set ngl_per_device_high[0].(n_layer, n_part)=(42, 5), id_dense_start_high=0
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23605 + (22167 = 18914 +    2622 +     630) +      -21692 |
common_memory_breakdown_print: |   - Host               |                   3519 =  3255 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part= 6, overflow_type=4, mem= 22167 MiB
common_params_fit_impl: set ngl_per_device[0].(n_layer, n_part)=(42, 6), id_dense_start=0
common_params_fit_impl: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23605 + (22307 = 19054 +    2622 +     630) +      -21832 |
common_memory_breakdown_print: |   - Host               |                   3379 =  3115 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part= 6, overflow_type=2, mem= 22307 MiB
common_params_fit_impl: set ngl_per_device[0].(n_layer, n_part, overflow_type)=(42, 6, UP), id_dense_start=0
common_params_fit_impl: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE
common_memory_breakdown_print: | memory breakdown [MiB] | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 4090)   | 24080 = 23605 + (22454 = 19201 +    2622 +     630) +      -21979 |
common_memory_breakdown_print: |   - Host               |                   3232 =  2968 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part= 6, overflow_type=3, mem= 22454 MiB
common_params_fit_impl: set ngl_per_device[0].(n_layer, n_part, overflow_type)=(42, 6, GATE), id_dense_start=0
common_params_fit_impl:   - CUDA0 (NVIDIA GeForce RTX 4090): 42 layers ( 6 overflowing),  22454 MiB used,   1151 MiB free
common_fit_params: successfully fit params to free device memory
common_fit_params: fitting params to free memory took 6.87 seconds
llama_model_loader: loaded meta data with 55 key-value pairs and 753 tensors from /home/z/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35moe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 1.000000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   7:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   8:                         general.size_label str              = 35B-A3B
llama_model_loader: - kv   9:                            general.license str              = apache-2.0
llama_model_loader: - kv  10:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  11:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv  12:                   general.base_model.count u32              = 1
llama_model_loader: - kv  13:                  general.base_model.0.name str              = Qwen3.6 35B A3B
llama_model_loader: - kv  14:          general.base_model.0.organization str              = Qwen
llama_model_loader: - kv  15:              general.base_model.0.repo_url str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  16:                               general.tags arr[str,3]       = ["qwen3_5_moe", "qwen", "image-text-t...
llama_model_loader: - kv  17:                   qwen35moe.context_length u32              = 262144
llama_model_loader: - kv  18:                 qwen35moe.embedding_length u32              = 2048
llama_model_loader: - kv  19:             qwen35moe.attention.head_count u32              = 16
llama_model_loader: - kv  20:          qwen35moe.attention.head_count_kv u32              = 2
llama_model_loader: - kv  21:          qwen35moe.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  22:                   qwen35moe.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  23: qwen35moe.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  24:                     qwen35moe.expert_count u32              = 256
llama_model_loader: - kv  25:                qwen35moe.expert_used_count u32              = 8
llama_model_loader: - kv  26:             qwen35moe.attention.key_length u32              = 256
llama_model_loader: - kv  27:           qwen35moe.attention.value_length u32              = 256
llama_model_loader: - kv  28:       qwen35moe.expert_feed_forward_length u32              = 512
llama_model_loader: - kv  29: qwen35moe.expert_shared_feed_forward_length u32              = 512
llama_model_loader: - kv  30:                  qwen35moe.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  31:                   qwen35moe.ssm.state_size u32              = 128
llama_model_loader: - kv  32:                  qwen35moe.ssm.group_count u32              = 16
llama_model_loader: - kv  33:               qwen35moe.ssm.time_step_rank u32              = 32
llama_model_loader: - kv  34:                   qwen35moe.ssm.inner_size u32              = 4096
llama_model_loader: - kv  35:          qwen35moe.full_attention_interval u32              = 4
llama_model_loader: - kv  36:             qwen35moe.rope.dimension_count u32              = 64
llama_model_loader: - kv  37:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  38:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  39:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  40:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  41:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  42:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  43:            tokenizer.ggml.padding_token_id u32              = 248055
llama_model_loader: - kv  44:                tokenizer.ggml.bos_token_id u32              = 248044
llama_model_loader: - kv  45:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  46:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - kv  47:               general.quantization_version u32              = 2
llama_model_loader: - kv  48:                          general.file_type u32              = 15
llama_model_loader: - kv  49:                      quantize.imatrix.file str              = Qwen3.6-35B-A3B-GGUF/imatrix_unsloth....
llama_model_loader: - kv  50:                   quantize.imatrix.dataset str              = unsloth_calibration_Qwen3.6-35B-A3B.txt
llama_model_loader: - kv  51:             quantize.imatrix.entries_count u32              = 510
llama_model_loader: - kv  52:              quantize.imatrix.chunks_count u32              = 76
llama_model_loader: - kv  53:                      qwen35moe.block_count u32              = 41
llama_model_loader: - kv  54:             qwen35moe.nextn_predict_layers u32              = 1
llama_model_loader: - type  f32:  368 tensors
llama_model_loader: - type q8_0:  265 tensors
llama_model_loader: - type q4_K:   78 tensors
llama_model_loader: - type q5_K:   38 tensors
llama_model_loader: - type q6_K:    4 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Medium
print_info: file size   = 21.65 GiB (5.24 BPW) 
llama_prepare_model_devices: using device CUDA0 (NVIDIA GeForce RTX 4090) (0000:4c:00.0) - 23669 MiB free
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch                  = qwen35moe
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 41
print_info: n_head                = 16
print_info: n_head_kv             = 2
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 256
print_info: n_embd_head_v         = 256
print_info: n_gqa                 = 8
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-06
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: f_attn_value_scale    = 0.0000
print_info: n_ff                  = 0
print_info: n_expert              = 256
print_info: n_expert_used         = 8
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 40
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: mrope sections        = [11, 11, 10, 0]
print_info: ssm_d_conv            = 4
print_info: ssm_d_inner           = 4096
print_info: ssm_d_state           = 128
print_info: ssm_dt_rank           = 32
print_info: ssm_n_group           = 16
print_info: ssm_dt_b_c_rms        = 0
print_info: model type            = 35B.A3B
print_info: model params          = 35.51 B
print_info: general.name          = Qwen3.6-35B-A3B
print_info: vocab type            = BPE
print_info: n_vocab               = 248320
print_info: n_merges              = 247587
print_info: BOS token             = 248044 '<|endoftext|>'
print_info: EOS token             = 248046 '<|im_end|>'
print_info: EOT token             = 248046 '<|im_end|>'
print_info: PAD token             = 248055 '<|vision_pad|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
print_info: FIM MID token         = 248061 '<|fim_middle|>'
print_info: FIM PAD token         = 248063 '<|fim_pad|>'
print_info: FIM REP token         = 248064 '<|repo_name|>'
print_info: FIM SEP token         = 248065 '<|file_sep|>'
print_info: EOG token             = 248044 '<|endoftext|>'
print_info: EOG token             = 248046 '<|im_end|>'
print_info: EOG token             = 248063 '<|fim_pad|>'
print_info: EOG token             = 248064 '<|repo_name|>'
print_info: EOG token             = 248065 '<|file_sep|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
llama_model_loader: tensor overrides to CPU are used with mmap enabled - consider using --no-mmap for better performance
load_tensors: offloading output layer to GPU
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloaded 42/42 layers to GPU
load_tensors:   CPU_Mapped model buffer size = 21623.30 MiB
load_tensors:        CUDA0 model buffer size = 19201.62 MiB
.................................................................................................
common_init_result: added <|endoftext|> logit bias = -inf
common_init_result: added <|im_end|> logit bias = -inf
common_init_result: added <|fim_pad|> logit bias = -inf
common_init_result: added <|repo_name|> logit bias = -inf
common_init_result: added <|file_sep|> logit bias = -inf
llama_init_from_model: model default pooling_type is [-1], but [3] was specified
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 131072
llama_context: n_ctx_seq     = 131072
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (131072) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.95 MiB
llama_kv_cache:      CUDA0 KV buffer size =  2560.00 MiB
llama_kv_cache: size = 2560.00 MiB (131072 cells,  10 layers,  1/1 seqs), K (f16): 1280.00 MiB, V (f16): 1280.00 MiB
llama_kv_cache: attn_rot_k = 0, n_embd_head_k_all = 256
llama_kv_cache: attn_rot_v = 0, n_embd_head_k_all = 256
llama_memory_recurrent:      CUDA0 RS buffer size =    62.81 MiB
llama_memory_recurrent: size =   62.81 MiB (     1 cells,  41 layers,  1 seqs), R (f32):    2.81 MiB, S (f32):   60.00 MiB
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:      CUDA0 compute buffer size =   630.02 MiB
sched_reserve:  CUDA_Host compute buffer size =   264.02 MiB
sched_reserve: graph nodes  = 3729
sched_reserve: graph splits = 13 (with bs=512), 12 (with bs=1)
sched_reserve: reserve took 101.26 ms, sched copies = 1
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
srv    load_model: loading MTP head from '/home/z/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf' (override_arch=qwen35moe_mtp)
llama_model_loader: loaded meta data with 55 key-value pairs and 753 tensors from /home/z/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35moe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 1.000000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   7:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   8:                         general.size_label str              = 35B-A3B
llama_model_loader: - kv   9:                            general.license str              = apache-2.0
llama_model_loader: - kv  10:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  11:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv  12:                   general.base_model.count u32              = 1
llama_model_loader: - kv  13:                  general.base_model.0.name str              = Qwen3.6 35B A3B
llama_model_loader: - kv  14:          general.base_model.0.organization str              = Qwen
llama_model_loader: - kv  15:              general.base_model.0.repo_url str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  16:                               general.tags arr[str,3]       = ["qwen3_5_moe", "qwen", "image-text-t...
llama_model_loader: - kv  17:                   qwen35moe.context_length u32              = 262144
llama_model_loader: - kv  18:                 qwen35moe.embedding_length u32              = 2048
llama_model_loader: - kv  19:             qwen35moe.attention.head_count u32              = 16
llama_model_loader: - kv  20:          qwen35moe.attention.head_count_kv u32              = 2
llama_model_loader: - kv  21:          qwen35moe.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  22:                   qwen35moe.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  23: qwen35moe.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  24:                     qwen35moe.expert_count u32              = 256
llama_model_loader: - kv  25:                qwen35moe.expert_used_count u32              = 8
llama_model_loader: - kv  26:             qwen35moe.attention.key_length u32              = 256
llama_model_loader: - kv  27:           qwen35moe.attention.value_length u32              = 256
llama_model_loader: - kv  28:       qwen35moe.expert_feed_forward_length u32              = 512
llama_model_loader: - kv  29: qwen35moe.expert_shared_feed_forward_length u32              = 512
llama_model_loader: - kv  30:                  qwen35moe.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  31:                   qwen35moe.ssm.state_size u32              = 128
llama_model_loader: - kv  32:                  qwen35moe.ssm.group_count u32              = 16
llama_model_loader: - kv  33:               qwen35moe.ssm.time_step_rank u32              = 32
llama_model_loader: - kv  34:                   qwen35moe.ssm.inner_size u32              = 4096
llama_model_loader: - kv  35:          qwen35moe.full_attention_interval u32              = 4
llama_model_loader: - kv  36:             qwen35moe.rope.dimension_count u32              = 64
llama_model_loader: - kv  37:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  38:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  39:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  40:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  41:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  42:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  43:            tokenizer.ggml.padding_token_id u32              = 248055
llama_model_loader: - kv  44:                tokenizer.ggml.bos_token_id u32              = 248044
llama_model_loader: - kv  45:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  46:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - kv  47:               general.quantization_version u32              = 2
llama_model_loader: - kv  48:                          general.file_type u32              = 15
llama_model_loader: - kv  49:                      quantize.imatrix.file str              = Qwen3.6-35B-A3B-GGUF/imatrix_unsloth....
llama_model_loader: - kv  50:                   quantize.imatrix.dataset str              = unsloth_calibration_Qwen3.6-35B-A3B.txt
llama_model_loader: - kv  51:             quantize.imatrix.entries_count u32              = 510
llama_model_loader: - kv  52:              quantize.imatrix.chunks_count u32              = 76
llama_model_loader: - kv  53:                      qwen35moe.block_count u32              = 41
llama_model_loader: - kv  54:             qwen35moe.nextn_predict_layers u32              = 1
llama_model_loader: - type  f32:  368 tensors
llama_model_loader: - type q8_0:  265 tensors
llama_model_loader: - type q4_K:   78 tensors
llama_model_loader: - type q5_K:   38 tensors
llama_model_loader: - type q6_K:    4 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Medium
print_info: file size   = 21.65 GiB (5.24 BPW) 
llama_model_create: overriding architecture qwen35moe -> qwen35moe_mtp
llama_prepare_model_devices: using device CUDA0 (NVIDIA GeForce RTX 4090) (0000:4c:00.0) - 1167 MiB free
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch                  = qwen35moe_mtp
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 41
print_info: n_head                = 16
print_info: n_head_kv             = 2
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 256
print_info: n_embd_head_v         = 256
print_info: n_gqa                 = 8
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-06
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: f_attn_value_scale    = 0.0000
print_info: n_ff                  = 0
print_info: n_expert              = 256
print_info: n_expert_used         = 8
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 40
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: mrope sections        = [11, 11, 10, 0]
print_info: model type            = ?B
print_info: model params          = 35.51 B
print_info: general.name          = Qwen3.6-35B-A3B
print_info: vocab type            = BPE
print_info: n_vocab               = 248320
print_info: n_merges              = 247587
print_info: BOS token             = 248044 '<|endoftext|>'
print_info: EOS token             = 248046 '<|im_end|>'
print_info: EOT token             = 248046 '<|im_end|>'
print_info: PAD token             = 248055 '<|vision_pad|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
print_info: FIM MID token         = 248061 '<|fim_middle|>'
print_info: FIM PAD token         = 248063 '<|fim_pad|>'
print_info: FIM REP token         = 248064 '<|repo_name|>'
print_info: FIM SEP token         = 248065 '<|file_sep|>'
print_info: EOG token             = 248044 '<|endoftext|>'
print_info: EOG token             = 248046 '<|im_end|>'
print_info: EOG token             = 248063 '<|fim_pad|>'
print_info: EOG token             = 248064 '<|repo_name|>'
print_info: EOG token             = 248065 '<|file_sep|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
done_getting_tensors: partial load — used 23 of 753 tensors in the file (rest belong to a sibling model on the same .gguf)
load_tensors: offloading output layer to GPU
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloaded 42/42 layers to GPU
load_tensors:   CPU_Mapped model buffer size = 21623.30 MiB
load_tensors:        CUDA0 model buffer size =   555.21 MiB
.....llama_init_from_model: model default pooling_type is [-1], but [3] was specified
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 131072
llama_context: n_ctx_seq     = 131072
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (131072) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.95 MiB
llama_kv_cache:      CUDA0 KV buffer size =   256.00 MiB
llama_kv_cache: size =  256.00 MiB (131072 cells,   1 layers,  1/1 seqs), K (f16):  128.00 MiB, V (f16):  128.00 MiB
llama_kv_cache: attn_rot_k = 0, n_embd_head_k_all = 256
llama_kv_cache: attn_rot_v = 0, n_embd_head_k_all = 256
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 497.00 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n_impl: failed to allocate CUDA0 buffer of size 521142272
graph_reserve: failed to allocate compute buffers
llama_init_from_model: failed to initialize the context: failed to allocate compute pp buffers
srv    load_model: failed to create MTP context
srv    operator(): operator(): cleaning up before exit...
main: exiting due to model loading error

Then I retried those instructions on my RTX 2000 Ada Laptop:

(z) z@z2:~/llama.cpp$ llama-server -m ~/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf --spec-type mtp --spec-draft-n-max 3 -np 1 -c 131072
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 7807 MiB):
  Device 0: NVIDIA RTX 2000 Ada Generation Laptop GPU, compute capability 8.9, VMM: yes, VRAM: 7807 MiB
build_info: b963-d7267a0
system_info: n_threads = 8 (n_threads_batch = 8) / 28 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | FA_ALL_QUANTS = 1 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 
init: using 27 threads for HTTP server
start: binding port with default address family
main: loading model
srv    load_model: loading model '/home/z/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf'
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
common_params_fit_impl: getting device memory data for initial parameters:
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7627 + (24770 = 21654 +    2622 +     493) +      -24590 |
common_memory_breakdown_print: |   - Host                                       |                   779 =   515 +       0 +     264                |
common_params_fit_impl: projected to use 24770 MiB of device memory vs. 7627 MiB of free device memory
common_params_fit_impl: cannot meet free memory target of 1024 MiB, need to reduce device memory by 18167 MiB
common_params_fit_impl: context size set by user to 131072 -> no change
common_params_fit_impl: getting device memory data with all MoE tensors moved to system memory:
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7627 + ( 5194 =  2078 +    2622 +     493) +       -5014 |
common_memory_breakdown_print: |   - Host                                       |                 20355 = 20091 +       0 +     264                |
common_params_fit_impl: with only dense weights in device memory there is a total surplus of 1408 MiB
common_params_fit_impl: id=0, target=6603 MiB
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7691 + ( 1012 =     0 +       0 +    1012) +        -895 |
common_memory_breakdown_print: |   - Host                                       |                 25068 = 22169 +    2622 +     276                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer= 0, n_part= 0, overflow_type=4, mem=  1012 MiB
common_params_fit_impl: filling dense-only layers back-to-front:
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7627 + ( 5658 =  2542 +    2622 +     493) +       -5478 |
common_memory_breakdown_print: |   - Host                                       |                 19891 = 19627 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=41, overflow_type=4, mem=  5658 MiB
common_params_fit_impl: set ngl_per_device[0].n_layer=42
common_params_fit_impl:   - CUDA0 (NVIDIA RTX 2000 Ada Generation Laptop GPU): 42 layers,   5658 MiB used,   1968 MiB free
common_params_fit_impl: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7627 + (24770 = 21654 +    2622 +     493) +      -24590 |
common_memory_breakdown_print: |   - Host                                       |                   779 =   515 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part= 0, overflow_type=4, mem= 24770 MiB
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7627 + ( 6821 =  3568 +    2622 +     630) +       -6641 |
common_memory_breakdown_print: |   - Host                                       |                 18865 = 18601 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=39, overflow_type=4, mem=  6821 MiB
common_params_fit_impl: set ngl_per_device_high[0].(n_layer, n_part)=(42, 39), id_dense_start_high=0
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7627 + ( 6220 =  3104 +    2622 +     493) +       -6040 |
common_memory_breakdown_print: |   - Host                                       |                 19329 = 19065 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=40, overflow_type=4, mem=  6220 MiB
common_params_fit_impl: set ngl_per_device[0].(n_layer, n_part)=(42, 40), id_dense_start=0
common_params_fit_impl: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7627 + ( 6504 =  3244 +    2622 +     637) +       -6324 |
common_memory_breakdown_print: |   - Host                                       |                 19189 = 18925 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=40, overflow_type=2, mem=  6504 MiB
common_params_fit_impl: set ngl_per_device[0].(n_layer, n_part, overflow_type)=(42, 40, UP), id_dense_start=0
common_params_fit_impl: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE
common_memory_breakdown_print: | memory breakdown [MiB]                         | total   free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - CUDA0 (RTX 2000 Ada Generation Laptop GPU) |  7807 = 7627 + ( 6647 =  3391 +    2622 +     633) +       -6467 |
common_memory_breakdown_print: |   - Host                                       |                 19042 = 18778 +       0 +     264                |
common_params_fit_impl: memory for test allocation by device:
common_params_fit_impl: id=0, n_layer=42, n_part=40, overflow_type=3, mem=  6647 MiB
common_params_fit_impl:   - CUDA0 (NVIDIA RTX 2000 Ada Generation Laptop GPU): 42 layers (40 overflowing),   6504 MiB used,   1122 MiB free
common_fit_params: successfully fit params to free device memory
common_fit_params: fitting params to free memory took 4.31 seconds
llama_model_loader: loaded meta data with 55 key-value pairs and 753 tensors from /home/z/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35moe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 1.000000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   7:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   8:                         general.size_label str              = 35B-A3B
llama_model_loader: - kv   9:                            general.license str              = apache-2.0
llama_model_loader: - kv  10:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  11:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv  12:                   general.base_model.count u32              = 1
llama_model_loader: - kv  13:                  general.base_model.0.name str              = Qwen3.6 35B A3B
llama_model_loader: - kv  14:          general.base_model.0.organization str              = Qwen
llama_model_loader: - kv  15:              general.base_model.0.repo_url str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  16:                               general.tags arr[str,3]       = ["qwen3_5_moe", "qwen", "image-text-t...
llama_model_loader: - kv  17:                   qwen35moe.context_length u32              = 262144
llama_model_loader: - kv  18:                 qwen35moe.embedding_length u32              = 2048
llama_model_loader: - kv  19:             qwen35moe.attention.head_count u32              = 16
llama_model_loader: - kv  20:          qwen35moe.attention.head_count_kv u32              = 2
llama_model_loader: - kv  21:          qwen35moe.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  22:                   qwen35moe.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  23: qwen35moe.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  24:                     qwen35moe.expert_count u32              = 256
llama_model_loader: - kv  25:                qwen35moe.expert_used_count u32              = 8
llama_model_loader: - kv  26:             qwen35moe.attention.key_length u32              = 256
llama_model_loader: - kv  27:           qwen35moe.attention.value_length u32              = 256
llama_model_loader: - kv  28:       qwen35moe.expert_feed_forward_length u32              = 512
llama_model_loader: - kv  29: qwen35moe.expert_shared_feed_forward_length u32              = 512
llama_model_loader: - kv  30:                  qwen35moe.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  31:                   qwen35moe.ssm.state_size u32              = 128
llama_model_loader: - kv  32:                  qwen35moe.ssm.group_count u32              = 16
llama_model_loader: - kv  33:               qwen35moe.ssm.time_step_rank u32              = 32
llama_model_loader: - kv  34:                   qwen35moe.ssm.inner_size u32              = 4096
llama_model_loader: - kv  35:          qwen35moe.full_attention_interval u32              = 4
llama_model_loader: - kv  36:             qwen35moe.rope.dimension_count u32              = 64
llama_model_loader: - kv  37:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  38:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  39:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  40:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  41:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  42:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  43:            tokenizer.ggml.padding_token_id u32              = 248055
llama_model_loader: - kv  44:                tokenizer.ggml.bos_token_id u32              = 248044
llama_model_loader: - kv  45:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  46:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - kv  47:               general.quantization_version u32              = 2
llama_model_loader: - kv  48:                          general.file_type u32              = 15
llama_model_loader: - kv  49:                      quantize.imatrix.file str              = Qwen3.6-35B-A3B-GGUF/imatrix_unsloth....
llama_model_loader: - kv  50:                   quantize.imatrix.dataset str              = unsloth_calibration_Qwen3.6-35B-A3B.txt
llama_model_loader: - kv  51:             quantize.imatrix.entries_count u32              = 510
llama_model_loader: - kv  52:              quantize.imatrix.chunks_count u32              = 76
llama_model_loader: - kv  53:                      qwen35moe.block_count u32              = 41
llama_model_loader: - kv  54:             qwen35moe.nextn_predict_layers u32              = 1
llama_model_loader: - type  f32:  368 tensors
llama_model_loader: - type q8_0:  265 tensors
llama_model_loader: - type q4_K:   78 tensors
llama_model_loader: - type q5_K:   38 tensors
llama_model_loader: - type q6_K:    4 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Medium
print_info: file size   = 21.65 GiB (5.24 BPW) 
llama_prepare_model_devices: using device CUDA0 (NVIDIA RTX 2000 Ada Generation Laptop GPU) (0000:01:00.0) - 7691 MiB free
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch                  = qwen35moe
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 41
print_info: n_head                = 16
print_info: n_head_kv             = 2
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 256
print_info: n_embd_head_v         = 256
print_info: n_gqa                 = 8
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-06
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: f_attn_value_scale    = 0.0000
print_info: n_ff                  = 0
print_info: n_expert              = 256
print_info: n_expert_used         = 8
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 40
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: mrope sections        = [11, 11, 10, 0]
print_info: ssm_d_conv            = 4
print_info: ssm_d_inner           = 4096
print_info: ssm_d_state           = 128
print_info: ssm_dt_rank           = 32
print_info: ssm_n_group           = 16
print_info: ssm_dt_b_c_rms        = 0
print_info: model type            = 35B.A3B
print_info: model params          = 35.51 B
print_info: general.name          = Qwen3.6-35B-A3B
print_info: vocab type            = BPE
print_info: n_vocab               = 248320
print_info: n_merges              = 247587
print_info: BOS token             = 248044 '<|endoftext|>'
print_info: EOS token             = 248046 '<|im_end|>'
print_info: EOT token             = 248046 '<|im_end|>'
print_info: PAD token             = 248055 '<|vision_pad|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
print_info: FIM MID token         = 248061 '<|fim_middle|>'
print_info: FIM PAD token         = 248063 '<|fim_pad|>'
print_info: FIM REP token         = 248064 '<|repo_name|>'
print_info: FIM SEP token         = 248065 '<|file_sep|>'
print_info: EOG token             = 248044 '<|endoftext|>'
print_info: EOG token             = 248046 '<|im_end|>'
print_info: EOG token             = 248063 '<|fim_pad|>'
print_info: EOG token             = 248064 '<|repo_name|>'
print_info: EOG token             = 248065 '<|file_sep|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
llama_model_loader: tensor overrides to CPU are used with mmap enabled - consider using --no-mmap for better performance
load_tensors: offloading output layer to GPU
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloaded 42/42 layers to GPU
load_tensors:   CPU_Mapped model buffer size = 21623.30 MiB
load_tensors:        CUDA0 model buffer size =  3244.55 MiB
.................................................................................................
common_init_result: added <|endoftext|> logit bias = -inf
common_init_result: added <|im_end|> logit bias = -inf
common_init_result: added <|fim_pad|> logit bias = -inf
common_init_result: added <|repo_name|> logit bias = -inf
common_init_result: added <|file_sep|> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 131072
llama_context: n_ctx_seq     = 131072
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (131072) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.95 MiB
llama_kv_cache:      CUDA0 KV buffer size =  2560.00 MiB
llama_kv_cache: size = 2560.00 MiB (131072 cells,  10 layers,  1/1 seqs), K (f16): 1280.00 MiB, V (f16): 1280.00 MiB
llama_kv_cache: attn_rot_k = 0, n_embd_head_k_all = 256
llama_kv_cache: attn_rot_v = 0, n_embd_head_k_all = 256
llama_memory_recurrent:      CUDA0 RS buffer size =    62.81 MiB
llama_memory_recurrent: size =   62.81 MiB (     1 cells,  41 layers,  1 seqs), R (f32):    2.81 MiB, S (f32):   60.00 MiB
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:      CUDA0 compute buffer size =   637.06 MiB
sched_reserve:  CUDA_Host compute buffer size =   264.02 MiB
sched_reserve: graph nodes  = 3729
sched_reserve: graph splits = 119 (with bs=512), 82 (with bs=1)
sched_reserve: reserve took 104.87 ms, sched copies = 1
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
srv    load_model: loading MTP head from '/home/z/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf' (override_arch=qwen35moe_mtp)
llama_model_loader: loaded meta data with 55 key-value pairs and 753 tensors from /home/z/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35moe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 1.000000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.6-35B-A3B
llama_model_loader: - kv   7:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   8:                         general.size_label str              = 35B-A3B
llama_model_loader: - kv   9:                            general.license str              = apache-2.0
llama_model_loader: - kv  10:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  11:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv  12:                   general.base_model.count u32              = 1
llama_model_loader: - kv  13:                  general.base_model.0.name str              = Qwen3.6 35B A3B
llama_model_loader: - kv  14:          general.base_model.0.organization str              = Qwen
llama_model_loader: - kv  15:              general.base_model.0.repo_url str              = https://huggingface.co/Qwen/Qwen3.6-3...
llama_model_loader: - kv  16:                               general.tags arr[str,3]       = ["qwen3_5_moe", "qwen", "image-text-t...
llama_model_loader: - kv  17:                   qwen35moe.context_length u32              = 262144
llama_model_loader: - kv  18:                 qwen35moe.embedding_length u32              = 2048
llama_model_loader: - kv  19:             qwen35moe.attention.head_count u32              = 16
llama_model_loader: - kv  20:          qwen35moe.attention.head_count_kv u32              = 2
llama_model_loader: - kv  21:          qwen35moe.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  22:                   qwen35moe.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  23: qwen35moe.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  24:                     qwen35moe.expert_count u32              = 256
llama_model_loader: - kv  25:                qwen35moe.expert_used_count u32              = 8
llama_model_loader: - kv  26:             qwen35moe.attention.key_length u32              = 256
llama_model_loader: - kv  27:           qwen35moe.attention.value_length u32              = 256
llama_model_loader: - kv  28:       qwen35moe.expert_feed_forward_length u32              = 512
llama_model_loader: - kv  29: qwen35moe.expert_shared_feed_forward_length u32              = 512
llama_model_loader: - kv  30:                  qwen35moe.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  31:                   qwen35moe.ssm.state_size u32              = 128
llama_model_loader: - kv  32:                  qwen35moe.ssm.group_count u32              = 16
llama_model_loader: - kv  33:               qwen35moe.ssm.time_step_rank u32              = 32
llama_model_loader: - kv  34:                   qwen35moe.ssm.inner_size u32              = 4096
llama_model_loader: - kv  35:          qwen35moe.full_attention_interval u32              = 4
llama_model_loader: - kv  36:             qwen35moe.rope.dimension_count u32              = 64
llama_model_loader: - kv  37:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  38:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  39:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  40:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  41:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  42:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  43:            tokenizer.ggml.padding_token_id u32              = 248055
llama_model_loader: - kv  44:                tokenizer.ggml.bos_token_id u32              = 248044
llama_model_loader: - kv  45:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  46:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - kv  47:               general.quantization_version u32              = 2
llama_model_loader: - kv  48:                          general.file_type u32              = 15
llama_model_loader: - kv  49:                      quantize.imatrix.file str              = Qwen3.6-35B-A3B-GGUF/imatrix_unsloth....
llama_model_loader: - kv  50:                   quantize.imatrix.dataset str              = unsloth_calibration_Qwen3.6-35B-A3B.txt
llama_model_loader: - kv  51:             quantize.imatrix.entries_count u32              = 510
llama_model_loader: - kv  52:              quantize.imatrix.chunks_count u32              = 76
llama_model_loader: - kv  53:                      qwen35moe.block_count u32              = 41
llama_model_loader: - kv  54:             qwen35moe.nextn_predict_layers u32              = 1
llama_model_loader: - type  f32:  368 tensors
llama_model_loader: - type q8_0:  265 tensors
llama_model_loader: - type q4_K:   78 tensors
llama_model_loader: - type q5_K:   38 tensors
llama_model_loader: - type q6_K:    4 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Medium
print_info: file size   = 21.65 GiB (5.24 BPW) 
llama_model_create: overriding architecture qwen35moe -> qwen35moe_mtp
llama_prepare_model_devices: using device CUDA0 (NVIDIA RTX 2000 Ada Generation Laptop GPU) (0000:01:00.0) - 1173 MiB free
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch                  = qwen35moe_mtp
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 2048
print_info: n_embd_inp            = 2048
print_info: n_layer               = 41
print_info: n_head                = 16
print_info: n_head_kv             = 2
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 256
print_info: n_embd_head_v         = 256
print_info: n_gqa                 = 8
print_info: n_embd_k_gqa          = 512
print_info: n_embd_v_gqa          = 512
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-06
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: f_attn_value_scale    = 0.0000
print_info: n_ff                  = 0
print_info: n_expert              = 256
print_info: n_expert_used         = 8
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 40
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: mrope sections        = [11, 11, 10, 0]
print_info: model type            = ?B
print_info: model params          = 35.51 B
print_info: general.name          = Qwen3.6-35B-A3B
print_info: vocab type            = BPE
print_info: n_vocab               = 248320
print_info: n_merges              = 247587
print_info: BOS token             = 248044 '<|endoftext|>'
print_info: EOS token             = 248046 '<|im_end|>'
print_info: EOT token             = 248046 '<|im_end|>'
print_info: PAD token             = 248055 '<|vision_pad|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
print_info: FIM MID token         = 248061 '<|fim_middle|>'
print_info: FIM PAD token         = 248063 '<|fim_pad|>'
print_info: FIM REP token         = 248064 '<|repo_name|>'
print_info: FIM SEP token         = 248065 '<|file_sep|>'
print_info: EOG token             = 248044 '<|endoftext|>'
print_info: EOG token             = 248046 '<|im_end|>'
print_info: EOG token             = 248063 '<|fim_pad|>'
print_info: EOG token             = 248064 '<|repo_name|>'
print_info: EOG token             = 248065 '<|file_sep|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
done_getting_tensors: partial load — used 23 of 753 tensors in the file (rest belong to a sibling model on the same .gguf)
load_tensors: offloading output layer to GPU
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloaded 42/42 layers to GPU
load_tensors:   CPU_Mapped model buffer size = 21623.30 MiB
load_tensors:        CUDA0 model buffer size =   555.21 MiB
.....llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 131072
llama_context: n_ctx_seq     = 131072
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (131072) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.95 MiB
llama_kv_cache:      CUDA0 KV buffer size =   256.00 MiB
llama_kv_cache: size =  256.00 MiB (131072 cells,   1 layers,  1/1 seqs), K (f16):  128.00 MiB, V (f16):  128.00 MiB
llama_kv_cache: attn_rot_k = 0, n_embd_head_k_all = 256
llama_kv_cache: attn_rot_v = 0, n_embd_head_k_all = 256
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 497.00 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n_impl: failed to allocate CUDA0 buffer of size 521142272
graph_reserve: failed to allocate compute buffers
llama_init_from_model: failed to initialize the context: failed to allocate compute pp buffers
srv    load_model: failed to create MTP context
srv    operator(): operator(): cleaning up before exit...
main: exiting due to model loading error

It definitely worked before for whatever reason (though with weirdly high CPU usage when --cpu-moe was not specified), and it doesn't work now. Since the base commit did not change, maybe something in this PR changed and broke it?

@ZisIsNotZis

Also, I think there was a merge conflict in the last working version and there isn't one anymore, and rebasing on the latest master doesn't work either.

am17an and others added 5 commits May 13, 2026 11:13
* MTP: clean-up

* review: use llama_context_type instead of llama_graph_type

* review: remove llama_model_has_mtp

* review: fix convert issues

* convert: fix pycheck

* review: formatting

* use `mtp-` for identifying mtp models

* convert: fix mtp conversion
@am17an
Contributor Author

am17an commented May 13, 2026

With the latest changes you can now combine MTP with other speculative techniques such as ngram-mod, e.g. `--spec-type mtp,ngram-mod --spec-draft-max 64`. The partial-rollback changes (which are useful for Qwen hybrid attention) may be added later.
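For illustration, reusing the model path from the run above, a combined invocation would look roughly like this (a sketch based on the flags quoted in this comment, not a verified command):

llama-server -m ~/hf/Qwen3.6-35B-A3B-MTP-UD-Q4_K_XL.gguf -np 1 --spec-type mtp,ngram-mod --spec-draft-max 64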


Labels

Apple Metal, examples, ggml, model, Nvidia GPU, python, server, testing, Vulkan
