diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index cc065936b..486215dad 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -4,7 +4,10 @@
-
+
+
+
+
@@ -18,21 +21,21 @@
- {
+ "lastFilter": {
+ "state": "OPEN",
+ "assignee": "momochen"
}
-}]]>
-
+ {
+ "selectedUrlAndAccountId": {
+ "url": "https://github.com/momochen/Liger-Kernel",
+ "accountId": "639f3e12-86db-4b12-a409-51cc017415fb"
}
-}]]>
-
+}
+ {
+ "associatedIndex": 5
+}
@@ -42,9 +45,21 @@
"keyToString": {
"RunOnceActivity.ShowReadmeOnStart": "true",
"git-widget-placeholder": "ref__unsloth",
- "last_opened_file_path": "/Users/ychen/workspace/github/Liger-Kernel"
+ "go.import.settings.migrated": "true",
+ "last_opened_file_path": "/Users/ychen/workspace/github/Liger-Kernel",
+ "node.js.detected.package.eslint": "true",
+ "node.js.detected.package.tslint": "true",
+ "node.js.selected.package.eslint": "(autodetect)",
+ "node.js.selected.package.tslint": "(autodetect)"
}
}]]>
+
+
+
+
+
+
+
diff --git a/src/liger_kernel/transformers/model/gemma3.py b/src/liger_kernel/transformers/model/gemma3.py
index 317dc8ebd..6b369a0d3 100644
--- a/src/liger_kernel/transformers/model/gemma3.py
+++ b/src/liger_kernel/transformers/model/gemma3.py
@@ -208,7 +208,7 @@ def multimodal_forward(
is_training = token_type_ids is not None and labels is not None
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
- if input_ids is not None and self.config.image_token_index >= self.vocab_size:
+ if input_ids is not None and self.config.image_token_index >= self.config.text_config.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
@@ -250,17 +250,17 @@ def multimodal_forward(
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
- if labels is not None and self.pad_token_id in labels:
+ if labels is not None and self.config.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
- causal_mask = self._update_causal_mask(
+ causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
- outputs = self.language_model.model(
+ outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -304,9 +304,9 @@ def multimodal_forward(
shift_labels = shift_labels.view(-1).to(hidden_device)
lce = LigerFusedLinearCrossEntropyLoss()
- loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+ loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
- logits = self.language_model.lm_head(hidden_states)
+ logits = self.lm_head(hidden_states)
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()