
add support for gemma4 model#1655

Open
n1ck-guo wants to merge 2 commits into main from hengguo/support_for_gemma4

Conversation

Contributor

@n1ck-guo n1ck-guo commented Apr 3, 2026

Description

Please briefly describe your main changes and the motivation behind them.

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring
  • Other (please specify):

Related Issues

Fixes or relates to #

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.

Signed-off-by: n1ck-guo <heng.guo@intel.com>
Copilot AI review requested due to automatic review settings April 3, 2026 07:25
Contributor

Copilot AI left a comment


Pull request overview

Adds runtime handling for the gemma4 model type by patching Gemma4 decoder layers to avoid shape mismatches during auto-round block-wise quantization.

Changes:

  • Added gemma4 to the special model list and introduced a Gemma4-specific patch routine.
  • Hooked the patch into _handle_special_model when model.config.model_type == "gemma4".
  • Removed a couple of stray whitespace-only lines near ignore-layer registrations.

```python
    per_layer_input = per_layer_input[:, :hs_seq, :]
else:
    pad = per_layer_input[:, -1:, :].expand(-1, hs_seq - pl_seq, -1)
    per_layer_input = torch.cat([per_layer_input, pad], dim=1)
```

Copilot AI Apr 3, 2026


torch is used in _patch_gemma4_model (for torch.cat) but is not imported in that function (and may not be in module scope). This will raise NameError at runtime in the padding branch. Import torch in _patch_gemma4_model (or ensure it’s available in this module’s scope) before using it.

```python
import types as _types

try:
    from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
```

Copilot AI Apr 3, 2026


create_sliding_window_causal_mask is imported but not referenced anywhere in the added code, which can fail linting and adds noise. Either remove it or use it to detect/handle whether the cached mask is sliding-window (which the comments imply you intended to do).

Suggested change

```diff
-from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from transformers.masking_utils import create_causal_mask
```

Comment on lines +134 to +146
```python
# Recompute attention_mask for full-attention layers when a
# sliding-window mask was cached (it would be too restrictive)
if is_full_attn and attention_mask is not None and position_ids is not None:
    # Only rebuild if the mask was created for a shorter context
    # (sliding window masks have finite bandwidth)
    try:
        attention_mask = create_causal_mask(
            config=cfg,
            inputs_embeds=hidden_states,
            attention_mask=None,
            past_key_values=kwargs.get("past_key_values"),
            position_ids=position_ids,
        )
```

Copilot AI Apr 3, 2026


The inline comment says masking is rebuilt only if the cached mask was created for a shorter (sliding-window) context, but the code currently attempts to rebuild unconditionally for any full-attention layer whenever attention_mask is not None. Either implement an explicit check that the cached mask is sliding-window/too restrictive (and only rebuild then), or update the comment to reflect the actual behavior.


```python
model.forward = partial(_qwen3_omni_moe_forward, model)
if hasattr(model, "config") and model.config.model_type == "gemma4":
    _patch_gemma4_model(model)
```

Copilot AI Apr 3, 2026


_patch_gemma4_model returns model, but _handle_special_model ignores the return value. Since the patch currently mutates in place this works, but capturing the return makes the integration robust if _patch_gemma4_model later needs to wrap/replace the model (similar to other handlers). Consider model = _patch_gemma4_model(model) here.

Suggested change

```diff
-_patch_gemma4_model(model)
+model = _patch_gemma4_model(model)
```
