Conversation
Signed-off-by: n1ck-guo <heng.guo@intel.com>
Pull request overview
Adds runtime handling for the gemma4 model type by patching Gemma4 decoder layers to avoid shape mismatches during auto-round block-wise quantization.
Changes:
- Added `gemma4` to the special model list and introduced a Gemma4-specific patch routine.
- Hooked the patch into `_handle_special_model` when `model.config.model_type == "gemma4"`.
- Removed a couple of stray whitespace-only lines near ignore-layer registrations.
```python
    per_layer_input = per_layer_input[:, :hs_seq, :]
else:
    pad = per_layer_input[:, -1:, :].expand(-1, hs_seq - pl_seq, -1)
    per_layer_input = torch.cat([per_layer_input, pad], dim=1)
```
`torch` is used in `_patch_gemma4_model` (for `torch.cat`) but is not imported in that function (and may not be in module scope). This will raise a `NameError` at runtime in the padding branch. Import `torch` in `_patch_gemma4_model` (or ensure it is available in this module's scope) before using it.
```python
import types as _types

try:
    from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
```
`create_sliding_window_causal_mask` is imported but not referenced anywhere in the added code, which can fail linting and adds noise. Either remove it or use it to detect/handle whether the cached mask is sliding-window (which the comments imply you intended to do).
```diff
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+ from transformers.masking_utils import create_causal_mask
```
```python
# Recompute attention_mask for full-attention layers when a
# sliding-window mask was cached (it would be too restrictive)
if is_full_attn and attention_mask is not None and position_ids is not None:
    # Only rebuild if the mask was created for a shorter context
    # (sliding window masks have finite bandwidth)
    try:
        attention_mask = create_causal_mask(
            config=cfg,
            inputs_embeds=hidden_states,
            attention_mask=None,
            past_key_values=kwargs.get("past_key_values"),
            position_ids=position_ids,
        )
```
The inline comment says masking is rebuilt only if the cached mask was created for a shorter (sliding-window) context, but the code currently attempts to rebuild unconditionally for any full-attention layer whenever `attention_mask` is not `None`. Either implement an explicit check that the cached mask is sliding-window/too restrictive (and only rebuild then), or update the comment to reflect the actual behavior.
```python
        model.forward = partial(_qwen3_omni_moe_forward, model)
    if hasattr(model, "config") and model.config.model_type == "gemma4":
        _patch_gemma4_model(model)
```
`_patch_gemma4_model` returns `model`, but `_handle_special_model` ignores the return value. Since the patch currently mutates in place this works, but capturing the return makes the integration robust if `_patch_gemma4_model` later needs to wrap or replace the model (similar to other handlers). Consider `model = _patch_gemma4_model(model)` here.
```diff
-         _patch_gemma4_model(model)
+         model = _patch_gemma4_model(model)
```