Skip to content

Conversation

@ryan-mangeno
Copy link
Contributor

@ryan-mangeno ryan-mangeno commented Aug 28, 2025

adding support to run granite embedding small, and it primarily pulls the modern bert architecture - https://huggingface.co/ibm-granite/granite-embedding-small-english-r2, currently working on it still, havent figured out the pre-tokenizer type or if I need to impliment it, also for the ubatch size the assert fails in llama-graph.cpp, hacked it to accept ubatch size of 1 for testing, but it seems to keep failing there and not sure why,

if I comment out of the line in llama-graph.cpp

assert(!ubatch.equal_seqs());

then it works

@ryan-mangeno ryan-mangeno marked this pull request as draft August 28, 2025 17:05
@ryan-mangeno
Copy link
Contributor Author

ryan-mangeno commented Aug 28, 2025

@gabe-l-hart thanks in advance :)

@ryan-mangeno
Copy link
Contributor Author

@gabe-l-hart thanks in advance :)

also realizing this a little late haha, but should I be changing all of the modern bert stuff to a granite embedding macro like LLM_ARCH_GRANITE_EMBD or keep it as is

@CISC
Copy link
Collaborator

CISC commented Aug 28, 2025

You may want to check out an earlier attempt at ModernBert in #14014

@gabe-l-hart
Copy link
Collaborator

Thanks for getting this together @ryan-mangeno and thanks for pointing out the previous work @CISC. Ryan, let me know if/when you've looked over that PR and found anything to fix and I'll take a pass at review.

@gabe-l-hart
Copy link
Collaborator

also realizing this a little late haha, but should I be changing all of the modern bert stuff to a granite embedding macro like LLM_ARCH_GRANITE_EMBD or keep it as is

In general, we want to keep things as generic as possible, so since this uses the ModernBertModel architecture from transformers, it's best to keep the implementation here similarly robust unless there's a concrete reason to subset the transformers architecture to just work for granite (eg there's some non-trivial code path in the transformers version that would make sense as a separate architecture).

@github-actions github-actions bot added the python python script changes label Aug 28, 2025
@ryan-mangeno
Copy link
Contributor Author

Thanks for getting this together @ryan-mangeno and thanks for pointing out the previous work @CISC. Ryan, let me know if/when you've looked over that PR and found anything to fix and I'll take a pass at review.

will do

@ryan-mangeno
Copy link
Contributor Author

ryan-mangeno commented Sep 3, 2025

@gabe-l-hart im looking into modern berts research paper, I cant find a mention of symmetric sliding window attention but rather local sliding window attention so I am going to opt to use LLAMA_SWA_TYPE_LOCAL versus LLAMA_SWA_TYPE_SYMMETRIC used in the previous attempt. It also uses global attention every third layer so I am going to implement this stuff and then it should be ready for a review :)

@gabe-l-hart
Copy link
Collaborator

@ryan-mangeno That sounds good! I haven't unpacked any of those mechanics myself, but can try to get into it if you get stuck.

… per previous attempt, added local sliding window attention that alternates every third layer
@ryan-mangeno
Copy link
Contributor Author

@ryan-mangeno That sounds good! I haven't unpacked any of those mechanics myself, but can try to get into it if you get stuck.

ok 👍 , made some changes but not sure if its fully ready yet, I will ping you when I think its ready if thats ok

@ryan-mangeno
Copy link
Contributor Author

ryan-mangeno commented Sep 4, 2025

status update - I found out that modern bert uses an alternating rope method , per https://arxiv.org/pdf/2412.13663

In ModernBERT, every third layer employs global
attention with a RoPE theta of 160,000 and the
remaining layers use a 128 token, local sliding window attention with a RoPE theta of 10,000.

I am currently figuring out how to implement this

@ryan-mangeno
Copy link
Contributor Author

@gabe-l-hart I believe this should be ready for review again, I added a new hparam -> n_swa_pattern which now works in the conversion script and can be pulled during model loading rather than it being hardcoded. Let me know of any changes, finals are coming up so I might be a bit slow for the next week just fyi

@gabe-l-hart
Copy link
Collaborator

@ryan-mangeno it looks like there are some conflicts that need resolving. Can you merge in origin/master? Once the conflicts are clean, I'll dig into review (and we can hopefully get it moved forward!)

@gabe-l-hart
Copy link
Collaborator

It looks like these failing ubuntu tests are all hitting something along these lines:

2025-12-06T00:49:31.4614055Z 38: Test command: /home/runner/work/llama.cpp/llama.cpp/build/bin/test-quantize-perf
2025-12-06T00:49:31.4614706Z 38: Working Directory: .
2025-12-06T00:49:31.4615062Z 38: Test timeout computed to be: 900
2025-12-06T00:49:31.6862442Z 36/39 Test #38: test-quantize-perf ................***Exception: Illegal  0.22 sec
2025-12-06T00:49:31.6863602Z test 39
2: Test command: /home/runner/work/llama.cpp/llama.cpp/build/bin/test-tokenizer-0 "/home/runner/work/llama.cpp/llama.cpp/models/ggml-vocab-command-r.gguf"
2: Working Directory: .
2: Test timeout computed to be: 900
2: main : reading vocab from: '/home/runner/work/llama.cpp/llama.cpp/models/ggml-vocab-command-r.gguf'
 2/38 Test  #2: test-tokenizer-0-command-r ........***Exception: Illegal  0.23 sec
test 3

I'm not clear if these are real issues, but I haven't seen these tests fail randomly on other PRs.

@CISC
Copy link
Collaborator

CISC commented Dec 10, 2025

I'm not clear if these are real issues, but I haven't seen these tests fail randomly on other PRs.

They are not real, it's a ccache issue, somehow the cache gets poisoned by incompatible binaries, not sure what is going on, but the only way to solve it is to find and delete the affected cache in Actions->Caches (or wait for it to get purged), then rerun job if in doubt.

Edit: Keep in mind that a PR branch job will pick the default branch cache if no other is available. I do routinely make sure to delete corrupt caches on master though.

@gabe-l-hart
Copy link
Collaborator

There are only two hard things in Computer Science...

@ryan-mangeno
Copy link
Contributor Author

I'm not clear if these are real issues, but I haven't seen these tests fail randomly on other PRs.

They are not real, it's a ccache issue, somehow the cache gets poisoned by incompatible binaries, not sure what is going on, but the only way to solve it is to find and delete the affected cache in Actions->Caches (or wait for it to get purged), then rerun job if in doubt.

Edit: Keep in mind that a PR branch job will pick the default branch cache if no other is available. I do routinely make sure to delete corrupt caches on master though.

I just deleted cache in actions and rerunning now, is that all that needs to be done as of now regarding that job?

@CISC
Copy link
Collaborator

CISC commented Dec 15, 2025

I just deleted cache in actions and rerunning now, is that all that needs to be done as of now regarding that job?

The cache is long gone by now, it gets purged on a daily basis (hourly when exceeding 10GB).

@ryan-mangeno
Copy link
Contributor Author

I just deleted cache in actions and rerunning now, is that all that needs to be done as of now regarding that job?

The cache is long gone by now, it gets purged on a daily basis (hourly when exceeding 10GB).

hmm, it seems to be failing again (ubuntu-cpu-cmake (x64, ubuntu-22.04) here is a excert from the logs

2025-12-15T17:18:22.9944131Z 19: Failed to parse up to error: [json.exception.parse_error.101] parse error at line 2, column 48: syntax error while parsing value - invalid string: missing closing quote; last read: '"special': <<<[
2025-12-15T17:18:22.9945354Z 19:     {"tool_call_id": "0", "tool_name": "special>>>

is this related to my implementation and if so any idea where to start debugging? thanks!

@CISC
Copy link
Collaborator

CISC commented Dec 15, 2025

hmm, it seems to be failing again (ubuntu-cpu-cmake (x64, ubuntu-22.04) here is a excert from the logs

2025-12-15T17:18:22.9944131Z 19: Failed to parse up to error: [json.exception.parse_error.101] parse error at line 2, column 48: syntax error while parsing value - invalid string: missing closing quote; last read: '"special': <<<[
2025-12-15T17:18:22.9945354Z 19:     {"tool_call_id": "0", "tool_name": "special>>>

is this related to my implementation and if so any idea where to start debugging? thanks!

Nope, those are expected failures, what is actually failing are these:

2025-12-15T17:18:36.1932270Z The following tests FAILED:
2025-12-15T17:18:36.1932750Z 	 29 - test-thread-safety (ILLEGAL)                      main
2025-12-15T17:18:36.1933333Z 	 31 - test-opt (ILLEGAL)                                main
2025-12-15T17:18:36.1933883Z 	 36 - test-barrier (ILLEGAL)                            main
2025-12-15T17:18:36.1934463Z 	 37 - test-quantize-fns (ILLEGAL)                       main
2025-12-15T17:18:36.1935043Z 	 38 - test-quantize-perf (ILLEGAL)                      main
2025-12-15T17:18:36.1935719Z 	 42 - test-eval-callback (ILLEGAL)                      curl eval-callback

..but it has nothing to do with your PR, just ignore it.

@ryan-mangeno
Copy link
Contributor Author

hmm, it seems to be failing again (ubuntu-cpu-cmake (x64, ubuntu-22.04) here is a excert from the logs

2025-12-15T17:18:22.9944131Z 19: Failed to parse up to error: [json.exception.parse_error.101] parse error at line 2, column 48: syntax error while parsing value - invalid string: missing closing quote; last read: '"special': <<<[
2025-12-15T17:18:22.9945354Z 19:     {"tool_call_id": "0", "tool_name": "special>>>

is this related to my implementation and if so any idea where to start debugging? thanks!

Nope, those are expected failures, what is actually failing are these:

2025-12-15T17:18:36.1932270Z The following tests FAILED:
2025-12-15T17:18:36.1932750Z 	 29 - test-thread-safety (ILLEGAL)                      main
2025-12-15T17:18:36.1933333Z 	 31 - test-opt (ILLEGAL)                                main
2025-12-15T17:18:36.1933883Z 	 36 - test-barrier (ILLEGAL)                            main
2025-12-15T17:18:36.1934463Z 	 37 - test-quantize-fns (ILLEGAL)                       main
2025-12-15T17:18:36.1935043Z 	 38 - test-quantize-perf (ILLEGAL)                      main
2025-12-15T17:18:36.1935719Z 	 42 - test-eval-callback (ILLEGAL)                      curl eval-callback

..but it has nothing to do with your PR, just ignore it.

ok thanks! will be on the lookout for review fixes

@gabe-l-hart
Copy link
Collaborator

One note as I'm testing this: When I run convert_hf_to_gguf.py with my local version of transformers (c67ec2c4c1) which is past the 5.0 version change, I see errors about KeyError: 'global_rope_theta'. It seems that this has something to do with changes in the config structure for the model architecture that happened during the 5.0 version (we've seen a lot of this kind of break for other Granite models), so it's not an issue to hold up this PR, but it likely means that this won't work with transformers>4 until that is fixed.

@gabe-l-hart
Copy link
Collaborator

After some dependency flailing, I was able to get my comparison script running again, and the side-by-side results with sentence-transformers look good!

granite_embed.py
from sentence_transformers import SentenceTransformer
import numpy as np
import subprocess
import shlex

model_path = "/Users/ghart/models/ibm-granite/granite-embedding-small-english-r2/"
model = SentenceTransformer(model_path)

input_queries = [
    "hello world",
    "tell me a story about a developer and their dog",
    "123sfg this is a r@nd0m t35t",
]


def cosine_similarity(vector_a: np.ndarray, vector_b: np.ndarray) -> float:
    vector_a = np.asarray(vector_a)
    vector_b = np.asarray(vector_b)
    numerator = np.dot(vector_a, vector_b)
    denominator_a = np.linalg.norm(vector_a)
    denominator_b = np.linalg.norm(vector_b)
    if denominator_a == 0 or denominator_b == 0: return 0.0
    cosine_sim = numerator / (denominator_a * denominator_b)
    return cosine_sim


for query in input_queries:
    print("### BASELINE ###")
    embedding = model.encode([query])
    print("Embedding shape:", embedding.shape)
    print("Embedding vector:", embedding[:, :8])

    print("### llama.cpp ###")
    lcpp_exe = "/Users/ghart/Projects/github/ggml-org/llama.cpp/build/bin/llama-embedding"
    lcpp_model = f"{model_path}/granite-embedding-small-english-r2-BF16.gguf"
    cmd = f"{lcpp_exe} -m {lcpp_model} -p \"{query}\" --temp 0 --embd-normalize -1"
    print(f"llama.cpp command: {cmd}")
    proc = subprocess.Popen(
        shlex.split(cmd),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, _ = proc.communicate()
    vals = out.decode("utf-8").split(":")[-1]
    vals = [
        float(v) for v in vals.split()
        if v.strip()
    ]
    lcpp_emb = np.array(vals)
    print("llama.cpp Embedding shape:", lcpp_emb.shape)
    print("llama.cpp Embedding vector:", lcpp_emb[:8])
    print()
    cos_sim = cosine_similarity(embedding, lcpp_emb)
    print(f"COSINE SIMILARITY: {cos_sim}")
    print("--------------------------------")
    print()
### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[ 0.47021735 -0.08181904 -0.9702138   0.101168   -0.16487186 -0.41284022
  -0.28690544 -0.6374487 ]]
### llama.cpp ###
llama.cpp command: /Users/ghart/Projects/github/ggml-org/llama.cpp/build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-small-english-r2//granite-embedding-small-english-r2-BF16.gguf -p "hello world" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [ 0.470165 -0.081821 -0.969903  0.101372 -0.164882 -0.41247  -0.287036
 -0.63773 ]

COSINE SIMILARITY: [0.99999991]
--------------------------------

### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[ 1.2659577   0.05745669 -0.12995665  1.3856934   0.06200508 -1.2863929
  -0.2949096   1.1680874 ]]
### llama.cpp ###
llama.cpp command: /Users/ghart/Projects/github/ggml-org/llama.cpp/build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-small-english-r2//granite-embedding-small-english-r2-BF16.gguf -p "tell me a story about a developer and their dog" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [ 1.261374  0.054074 -0.127809  1.384826  0.063593 -1.281999 -0.29321
  1.165062]

COSINE SIMILARITY: [0.99999489]
--------------------------------

### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[ 0.46219376 -0.22369052 -1.0632571   0.92421246  0.8207397  -0.04330811
  -0.4359354  -0.04913762]]
### llama.cpp ###
llama.cpp command: /Users/ghart/Projects/github/ggml-org/llama.cpp/build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-small-english-r2//granite-embedding-small-english-r2-BF16.gguf -p "123sfg this is a r@nd0m t35t" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [ 0.458999 -0.229462 -1.061105  0.923894  0.817016 -0.048115 -0.43095
 -0.04808 ]

COSINE SIMILARITY: [0.99999438]
--------------------------------

@CISC
Copy link
Collaborator

CISC commented Dec 17, 2025

One note as I'm testing this: When I run convert_hf_to_gguf.py with my local version of transformers (c67ec2c4c1) which is past the 5.0 version change, I see errors about KeyError: 'global_rope_theta'. It seems that this has something to do with changes in the config structure for the model architecture that happened during the 5.0 version (we've seen a lot of this kind of break for other Granite models), so it's not an issue to hold up this PR, but it likely means that this won't work with transformers>4 until that is fixed.

This is actually handled now, but requires a little more care for local_rope_theta.

Copy link
Collaborator

@gabe-l-hart gabe-l-hart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One small naming NIT, but that aside, I think this is ready to go! Thank you @ryan-mangeno

SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
DENSE_EVERY_N_LAYERS = "{arch}.attention.dense_every_n_layers"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we now have several different ways to indicate the layer type for different hybrid patterns. This DENSE_EVERY_N_LAYERS implies a strictly periodic pattern. The above SLIDING_WINDOW_PATTERN appears to be for interleaved sliding window hybrid architectures, but strangely I don't see the corresponding key anywhere on the c++ side (here). We also use the fact that head_count_kv can be either a scalar or a list of scalars to differentiate recurrent vs attention layers for GraniteHybrid (here).

Since this is already fairly confusing and potentially redundent, I don't think we need to hold up this PR, but I'm curious if others can think of a clean way to accomplish the goal of layer type designation without a net-new hparam.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think sliding_window_pattern and dense_every_n_layers are basically the same thing, but is indeed missing support (probably forgotten at some stage), all models seem to just hardcode the hparams.set_swa_pattern.

I wonder if we can just reuse the sliding_window_pattern key instead and differentiate between loading a single integer and an array of bools instead of introducing a new key?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I think what we can do is make an overload of llama_model_loader::get_key_or_arr that has uin32_t as result (and no n) which simply fails if the metadata is an array.

That way we can ml.get_key_or_arr the sliding_window_pattern into an uint32_t (that has the previously hardcoded value as default) we feed to hparams.set_swa_pattern.

ryan-mangeno and others added 3 commits December 18, 2025 21:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

model Model specific python python script changes

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants