
Commit 0e5d5e5: Fix formatting
Parent: 5ef2189

8 files changed: +89, -44 lines

src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java

Lines changed: 2 additions & 4 deletions
@@ -152,12 +152,10 @@ protected GGMLTensorEntry getOutputWeight(Map<String, GGMLTensorEntry> tensorEnt
     /**
      * Create standard (CPU) weights.
      */
-    protected abstract Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight);
+    protected abstract Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight);
 
     /**
      * Create TornadoVM (GPU) weights.
      */
-    protected abstract Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight);
+    protected abstract Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight);
 }
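
For orientation, these two hooks sit at the tail of AbstractModelLoader's template-method flow. The orchestration itself is outside this diff, so the following is only a hypothetical sketch of how the pieces plug together; the method name load, the vocabulary parameter, and the "token_embd.weight" key are assumptions, not code from this repository:

    // Hypothetical sketch only -- not part of this commit.
    public Model load(Map<String, GGMLTensorEntry> tensorEntries, Map<String, Object> metadata, Vocabulary vocabulary, boolean useTornadoVM) {
        C config = createConfiguration(metadata);                                 // per-architecture hyperparameters
        Pair<float[], float[]> ropeFreqs = precomputeRopeFrequencies(config);     // cos/sin tables for RoPE
        GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); // assumed GGUF tensor name
        GGMLTensorEntry outputWeight = getOutputWeight(tensorEntries);
        Weights weights = useTornadoVM
                ? createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight)
                : createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
        Tokenizer tokenizer = createTokenizer(metadata, vocabulary);
        return createModel(config, tokenizer, weights);
    }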

src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java

Lines changed: 7 additions & 2 deletions
@@ -42,6 +42,7 @@ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary voc
         return new LlamaTokenizer(metadata, vocabulary);
     }
 
+    // @formatter:off
     @Override
     protected LlamaConfiguration createConfiguration(Map<String, Object> metadata) {
         int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
@@ -59,18 +60,19 @@ protected LlamaConfiguration createConfiguration(Map<String, Object> metadata) {
                 (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
                 (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength);
     }
+    // @formatter:on
 
     @Override
     protected Pair<float[], float[]> precomputeRopeFrequencies(LlamaConfiguration config) {
-        return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength()
-        );
+        return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength());
     }
 
     @Override
     protected Llama createModel(LlamaConfiguration config, Tokenizer tokenizer, Weights weights) {
         return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
     }
 
+    // @formatter:off
     @Override
     protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
@@ -94,7 +96,9 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntri
                 loadTensor(outputWeight),
                 outputWeight.ggmlType());
     }
+    // @formatter:on
 
+    // @formatter:off
     @Override
     protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries,
             LlamaConfiguration config,
@@ -133,4 +137,5 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntr
                 ggmlType
         );
     }
+    // @formatter:on
 }
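
The // @formatter:off and // @formatter:on pairs added throughout this commit are formatter markers honored by IntelliJ IDEA (and Eclipse): the IDE skips reformatting for any code between them, which preserves the hand-wrapped argument lists above. Note the markers only take effect when formatter-control comments are enabled in the IDE's code-style settings. A minimal illustration:

    // @formatter:off
    int[][] identity = {
            { 1, 0, 0 },
            { 0, 1, 0 },
            { 0, 0, 1 },
    };
    // @formatter:on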

src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java

Lines changed: 32 additions & 11 deletions
@@ -40,6 +40,7 @@ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary voc
         return new MistralTokenizer(metadata, vocabulary);
     }
 
+    // @formatter:off
     @Override
     protected MistralConfiguration createConfiguration(Map<String, Object> metadata) {
         int modelContextLength = (int) metadata.get("llama.context_length");
@@ -48,29 +49,47 @@ protected MistralConfiguration createConfiguration(Map<String, Object> metadata)
         // Get vocabulary size from metadata
         int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
 
-        return new MistralConfiguration((int) metadata.get("llama.embedding_length"), (int) metadata.get("llama.feed_forward_length"), (int) metadata.get("llama.block_count"),
+        return new MistralConfiguration(
+                (int) metadata.get("llama.embedding_length"),
+                (int) metadata.get("llama.feed_forward_length"),
+                (int) metadata.get("llama.block_count"),
                 (int) metadata.get("llama.attention.head_count"),
-
-                metadata.containsKey("llama.attention.head_count_kv") ? (int) metadata.get("llama.attention.head_count_kv") : (int) metadata.get("llama.attention.head_count"),
-
-                vocabSize, finalContextLength, false, (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
-                (float) metadata.getOrDefault("llama.rope.freq_base", 10000f));
+                metadata.containsKey("llama.attention.head_count_kv") ?
+                        (int) metadata.get("llama.attention.head_count_kv")
+                        : (int) metadata.get("llama.attention.head_count"),
+                vocabSize,
+                finalContextLength,
+                false,
+                (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+                (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
+        );
     }
+    // @formatter:on
 
+    // @formatter:off
     @Override
     protected Pair<float[], float[]> precomputeRopeFrequencies(MistralConfiguration config) {
-        return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength()
+        return RoPE.precomputeFreqsCis(
+                config.contextLength(),
+                config.dim() / config.numberOfHeads(),
+                config.ropeTheta(),
+                false,
+                1.0f,
+                1.0f,
+                1.0f,
+                config.contextLength()
         );
     }
+    // @formatter:on
 
     @Override
     protected Mistral createModel(MistralConfiguration config, Tokenizer tokenizer, Weights weights) {
         return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
     }
 
+    // @formatter:off
     @Override
-    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
+    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
 
         final int nl = config.numberOfLayers();
 
@@ -91,10 +110,11 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntri
                 loadTensor(outputWeight),
                 outputWeight.ggmlType());
     }
+    // @formatter:off
 
+    // @formatter:off
     @Override
-    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
+    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
         GGMLType ggmlType = outputWeight.ggmlType();
 
         if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
@@ -127,4 +147,5 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntr
                 ggmlType
         );
     }
+    // @formatter:on
 }
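
RoPE.precomputeFreqsCis appears in every loader above with the signature (contextLength, headSize, theta, neoxStyle, plus four scaling parameters). The RoPE class itself is not part of this diff, so the following is only a sketch of what a standard RoPE precomputation conventionally produces: per-position, per-frequency cosine and sine tables. The body below ignores the scaling parameters and assumes the project's Pair type wraps the two arrays; both are assumptions for illustration:

    // Sketch of standard (unscaled) RoPE table precomputation -- an assumption,
    // not this repository's actual implementation.
    static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, float theta) {
        int half = headSize / 2;
        float[] real = new float[contextLength * half]; // cos table (freq_cis_real)
        float[] imag = new float[contextLength * half]; // sin table (freq_cis_imag)
        for (int pos = 0; pos < contextLength; pos++) {
            for (int i = 0; i < half; i++) {
                double freq = 1.0 / Math.pow(theta, (2.0 * i) / headSize); // theta^(-2i/headSize)
                real[pos * half + i] = (float) Math.cos(pos * freq);
                imag[pos * half + i] = (float) Math.sin(pos * freq);
            }
        }
        return new Pair<>(real, imag);
    }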

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 3 additions & 5 deletions
@@ -61,17 +61,16 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
             } else if (lowerName.contains("phi3") || lowerName.contains("phi-3")) {
                 return ModelType.PHI_3;
             }
-
         }
 
         return ModelType.UNKNOWN;
     }
 
     /**
      * Loads the language model based on the given options.
-     * <p>
-     * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. Otherwise, loads the model from the specified path using the model loader.
-     * </p>
+     *
+     * <p>If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model.
+     * Otherwise, loads the model from the specified path using the model loader.
      *
      * @param options the parsed CLI options containing model path and max token limit
      * @return the loaded {@link Model} instance
@@ -279,5 +278,4 @@ public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) {
             default -> throw new UnsupportedOperationException("Conversion to " + ggmlType);
         };
     }
-
 }
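
The detectModelType fragment above shows only the Phi-3 branch; the rest of the method is elided by the diff. A minimal sketch of the detection pattern, assuming the model name is read from the GGUF general.name metadata key (the actual key used here is not shown):

    // Sketch of the name-based detection pattern; "general.name" is an assumption.
    private static ModelType detectModelType(Map<String, Object> metadata) {
        String name = (String) metadata.get("general.name");
        if (name != null) {
            String lowerName = name.toLowerCase();
            if (lowerName.contains("phi3") || lowerName.contains("phi-3")) {
                return ModelType.PHI_3;
            }
            // ... branches for the other architectures elided here ...
        }
        return ModelType.UNKNOWN;
    }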

src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java

Lines changed: 25 additions & 14 deletions
@@ -46,6 +46,7 @@ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary voc
         return new Phi3Tokenizer(metadata, vocabulary);
     }
 
+    // @formatter:off
     @Override
     protected Phi3Configuration createConfiguration(Map<String, Object> metadata) {
         final String modelPrefix = "phi3.";
@@ -67,18 +68,26 @@ protected Phi3Configuration createConfiguration(Map<String, Object> metadata) {
         );
         return config;
     }
+    // @formatter:off
 
+    // @formatter:off
     @Override
     protected Pair<float[], float[]> precomputeRopeFrequencies(Phi3Configuration config) {
         // Calculate head size from dim and numberOfHeads
         int headSize = config.dim() / config.numberOfHeads();
 
-        return RoPE.precomputeFreqsCis(modelContextLength, // Use model context length for RoPE precomputation
-                headSize, // Calculated head size
-                config.ropeTheta(), false, // Phi3 uses standard RoPE, not neox-style based on reference
-                8, 1, 3, 8192 // Additional RoPE parameters from reference
+        return RoPE.precomputeFreqsCis(
+                modelContextLength, // Use model context length for RoPE precomputation
+                headSize, // Calculated head size
+                config.ropeTheta(),
+                false, // Phi3 uses standard RoPE, not neox-style based on reference
+                8,
+                1,
+                3,
+                8192 // Additional RoPE parameters from reference
         );
     }
+    // @formatter:off
 
     @Override
     protected Phi3 createModel(Phi3Configuration config, Tokenizer tokenizer, Weights weights) {
@@ -88,33 +97,34 @@ protected Phi3 createModel(Phi3Configuration config, Tokenizer tokenizer, Weight
         return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
     }
 
+    // @formatter:off
     @Override
-    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
+    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
         float[] ropeFreqsReal = ropeFreqs.first();
         float[] ropeFreqsImag = ropeFreqs.second();
 
         final int nl = config.numberOfLayers();
 
         return new Phi3StandardWeights(
-                loadTensor(tokenEmbeddings), // token_embedding_table
+                loadTensor(tokenEmbeddings),                                                        // token_embedding_table
                 loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),   // rms_att_weight (as FloatTensor[])
                 loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")),    // wqkv (combined)
                 loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
                 loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),    // rms_ffn_weight (as FloatTensor[])
                 loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),    // wDown
                 loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),      // wUp (separate, not combined)
-                loadTensor(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor)
-                new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real
-                new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag
-                loadTensor(outputWeight), // wcls
-                outputWeight.ggmlType() // weightType
+                loadTensor(tensorEntries.get("output_norm.weight")),                                // rms_final_weight (as FloatTensor)
+                new ArrayFloatTensor(ropeFreqsReal),                                                // freq_cis_real
+                new ArrayFloatTensor(ropeFreqsImag),                                                // freq_cis_imag
+                loadTensor(outputWeight),                                                           // wcls
+                outputWeight.ggmlType()                                                             // weightType
         );
     }
+    // @formatter:on
 
+    // @formatter:off
     @Override
-    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
+    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
        GGMLType ggmlType = outputWeight.ggmlType();
 
         if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
@@ -144,4 +154,5 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntr
                 ggmlType
         );
     }
+    // @formatter:on
 }
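
All of these loaders read hyperparameters from the GGUF metadata map with two idioms: a direct cast for required keys, which fails fast when the key is missing, and getOrDefault for optional ones. A small self-contained sketch of the pattern; the specific phi3.* field names are assumptions based on GGUF's <architecture>.<field> key convention, with only modelPrefix taken from the diff above:

    import java.util.Map;

    final class MetadataLookupSketch {
        static void readConfig(Map<String, Object> metadata) {
            final String modelPrefix = "phi3.";
            // Required: unboxing the cast throws a NullPointerException if the key is absent.
            int dim = (int) metadata.get(modelPrefix + "embedding_length");
            int layers = (int) metadata.get(modelPrefix + "block_count");
            // Optional: falls back to a conventional default.
            float rmsNormEps = (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f);
            System.out.printf("dim=%d layers=%d eps=%g%n", dim, layers, rmsNormEps);
        }
    }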

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,7 @@ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary voc
         return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
     }
 
+    // @formatter:off
     @Override
     protected Qwen2Configuration createConfiguration(Map<String, Object> metadata) {
         int modelContextLength = (int) metadata.get("qwen2.context_length");
@@ -68,12 +69,14 @@ protected Qwen2Configuration createConfiguration(Map<String, Object> metadata) {
                 (float) metadata.get("qwen2.rope.freq_base")
         );
     }
+    // @formatter:on
 
     @Override
     protected Pair<float[], float[]> precomputeRopeFrequencies(Qwen2Configuration config) {
         return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.headSize(), config.ropeTheta(), false, 8, 1, 3, 8192);
     }
 
+    // @formatter:off
     @Override
     protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weights weights) {
         Map<String, Object> metadata = gguf.getMetadata();
@@ -83,7 +86,9 @@ protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weig
                 : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
         return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
     }
+    // @formatter:on
 
+    // @formatter:off
     @Override
     protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
@@ -111,7 +116,9 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntri
                 outputWeight.ggmlType()
         );
     }
+    // @formatter:on
 
+    // @formatter:off
     @Override
     protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
@@ -152,4 +159,5 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntr
         );
 
     }
+    // @formatter:off
 }

src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java

Lines changed: 8 additions & 0 deletions
@@ -43,6 +43,7 @@ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary voc
         return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
     }
 
+    // @formatter:off
     @Override
     protected Qwen3Configuration createConfiguration(Map<String, Object> metadata) {
         int modelContextLength = (int) metadata.get("qwen3.context_length");
@@ -70,12 +71,14 @@ protected Qwen3Configuration createConfiguration(Map<String, Object> metadata) {
                 (float) metadata.get("qwen3.rope.freq_base")
         );
     }
+    // @formatter:on
 
     @Override
     protected Pair<float[], float[]> precomputeRopeFrequencies(Qwen3Configuration config) {
         return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.numberOfHeadsKey(), config.ropeTheta(), false, 0, 0, 0, 0);
     }
 
+    // @formatter:off
     @Override
     protected Qwen3 createModel(Qwen3Configuration config, Tokenizer tokenizer, Weights weights) {
         Map<String, Object> metadata = gguf.getMetadata();
@@ -85,7 +88,9 @@ protected Qwen3 createModel(Qwen3Configuration config, Tokenizer tokenizer, Weig
                 : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
         return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
     }
+    // @formatter:off
 
+    // @formatter:off
     @Override
     protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
@@ -116,7 +121,9 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntri
                 null
         );
     }
+    // @formatter:on
 
+    // @formatter:off
     @Override
     protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config,
             Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
@@ -151,4 +158,5 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntr
         );
 
     }
+    // @formatter:on
 }
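
Both Qwen loaders share the chat-token selection shown above: DeepSeek-R1-Distill-Qwen variants reuse the Qwen3Tokenizer but need different chat tokens. The DeepSeek branch of the ternary is elided by these hunks, so this sketch leaves it as a parameter rather than inventing its token strings; only the else-branch literals come from the diff:

    // Sketch of the selection pattern; deepSeekTokens stands in for the branch
    // not shown in this diff.
    static ChatTokens selectChatTokens(boolean isDeepSeekR1DistillQwen, ChatTokens deepSeekTokens) {
        return isDeepSeekR1DistillQwen
                ? deepSeekTokens
                : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
    }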
