
[Qwen3.5] atom native support for qwen3.5 #517

Merged
valarLip merged 7 commits into main from ganyi/qwen3.5_atom_native_support on Apr 14, 2026

Conversation

ganyi1996ppo (Contributor) commented Apr 8, 2026

Motivation

# Qwen/Qwen3.5-35B-A3B-FP8
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match|↑  |0.8650|±  |0.0094|
|     |       |strict-match    |     3|exact_match|↑  |0.8491|±  |0.0099|
# Qwen/Qwen3.5-35B-A3B
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match|↑  |0.8711|±  |0.0092|
|     |       |strict-match    |     3|exact_match|↑  |0.8567|±  |0.0097|

Technical Details

Test Plan

Test Result

Submission Checklist

Copilot AI review requested due to automatic review settings April 8, 2026 07:20
@ganyi1996ppo ganyi1996ppo marked this pull request as draft April 8, 2026 07:28
Copilot AI (Contributor) left a comment


Pull request overview

Adds native ATOM runtime support for Qwen3.5 multimodal architectures by extracting the text sub-config, registering the appropriate model classes, and enabling Qwen3.5 to reuse the Qwen3-Next GDN (linear attention) KV-cache path.

Changes:

  • Add Qwen3.5 text_config extraction helper + “bare” ConditionalGeneration wrappers for non-plugin execution.
  • Register Qwen3.5 ConditionalGeneration architectures in the model runner and treat Qwen3.5 text configs as “qwen_next” for GDN state/KV-cache handling.
  • Extend config loading to recognize Qwen3.5 multimodal model types and instantiate custom Qwen3.5 text config classes when extracting text sub-config.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

| File | Description |
|------|-------------|
| atom/models/qwen3_5.py | Adds text-config fallback helper and introduces bare ConditionalGeneration wrappers to load multimodal checkpoints while skipping vision weights. |
| atom/model_engine/model_runner.py | Registers Qwen3.5 ConditionalGeneration architectures and extends the Qwen-next (GDN) path to include Qwen3.5 text model types; adjusts GDN cache sizing/allocation. |
| atom/config.py | Adds Qwen3.5 multimodal model_type extraction + custom text-config registry to build Qwen3.5 text configs from extracted dicts. |

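The text-config fallback described in the table above can be sketched roughly as follows. This is a minimal stand-alone sketch: the config classes are simplified stand-ins for the transformers `Qwen3_5TextConfig` / `Qwen3_5MoeTextConfig`, the registry keys are illustrative assumptions, and the real helper in atom/models/qwen3_5.py operates on an atom `Config` rather than a bare hf_config.

```python
# Stand-alone sketch of a text-config fallback helper: prefer an
# already-attached text_config object, else build one from the extracted
# dict using a model_type -> text-config-class registry (mirroring the
# custom registry added in atom/config.py).
class Qwen3_5TextConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

class Qwen3_5MoeTextConfig(Qwen3_5TextConfig):
    pass

_TEXT_CONFIG_CLASSES = {
    "qwen3_5_vl": Qwen3_5TextConfig,          # hypothetical key
    "qwen3_5_vl_moe": Qwen3_5MoeTextConfig,   # hypothetical key
}

def get_qwen3_5_text_config(hf_config):
    """Return the text sub-config, building it from an extracted dict
    when the checkpoint stores text_config as plain JSON."""
    text_config = getattr(hf_config, "text_config", None)
    if isinstance(text_config, dict):
        cls = _TEXT_CONFIG_CLASSES.get(hf_config.model_type, Qwen3_5TextConfig)
        return cls(**text_config)
    return text_config if text_config is not None else hf_config
```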


     def __init__(self, atom_config: Config, prefix: str = ""):
-        config: Qwen3_5MoeTextConfig = atom_config.hf_config.text_config
+        config: Qwen3_5MoeTextConfig = get_qwen3_5_text_config(atom_config)

Copilot AI Apr 8, 2026


In Qwen3_5ForCausalLMBase.__init__, config is annotated as Qwen3_5MoeTextConfig, but get_qwen3_5_text_config() can return either Qwen3_5TextConfig or Qwen3_5MoeTextConfig. This makes type checking/mypy misleading and can hide real config-shape bugs. Consider updating the annotation to a union (or PretrainedConfig) to reflect actual possible values.

Suggested change:
-        config: Qwen3_5MoeTextConfig = get_qwen3_5_text_config(atom_config)
+        config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = get_qwen3_5_text_config(
+            atom_config
+        )

super().__init__()
self.config = atom_config.hf_config
self.visual = PPMissingLayer()
self.language_model = Qwen3_5ForCausalLM(atom_config=atom_config, prefix="")

Copilot AI Apr 8, 2026


Qwen3_5ForConditionalGenerationBare.__init__ accepts a prefix argument but does not use it (and also hardcodes prefix="" when constructing Qwen3_5ForCausalLM). This is confusing for callers and makes it harder to add proper namespacing later. Either remove the unused parameter or thread prefix through consistently (including into language_model).

Suggested change:
-        self.language_model = Qwen3_5ForCausalLM(atom_config=atom_config, prefix="")
+        self.language_model = Qwen3_5ForCausalLM(
+            atom_config=atom_config, prefix=prefix
+        )

Comment on lines +1016 to +1018
* torch.tensor([], dtype=mamba_dtypes[0]).element_size()
+ math.prod(mamba_shape[1])
* torch.tensor([], dtype=mamba_dtypes[1]).element_size()

Copilot AI Apr 8, 2026


_compute_block_bytes computes element_size() by instantiating torch.tensor([], dtype=...). Since ModelRunner sets a CUDA default device, this can create tensors on GPU unnecessarily. Prefer a device-agnostic approach (e.g., use torch.empty((), dtype=..., device="cpu").element_size() or a small dtype->bytes lookup) so this estimation stays cheap and doesn't depend on global default device.

Suggested change:
-        * torch.tensor([], dtype=mamba_dtypes[0]).element_size()
-        + math.prod(mamba_shape[1])
-        * torch.tensor([], dtype=mamba_dtypes[1]).element_size()
+        * torch.empty((), dtype=mamba_dtypes[0], device="cpu").element_size()
+        + math.prod(mamba_shape[1])
+        * torch.empty((), dtype=mamba_dtypes[1], device="cpu").element_size()

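The dtype-to-bytes lookup that the review mentions as an alternative could look like the following sketch. It avoids torch entirely so block-size estimation never touches the global default device; the table keys, the helper name, and the argument layout are illustrative assumptions, not the actual `_compute_block_bytes` signature.

```python
import math

# Small dtype-name -> byte-size lookup, so cache-block sizing is a pure
# arithmetic computation with no tensor allocation. Extend as needed.
DTYPE_BYTES = {
    "torch.float32": 4,
    "torch.float16": 2,
    "torch.bfloat16": 2,
    "torch.float8_e4m3fn": 1,
    "torch.int8": 1,
}

def compute_block_bytes(mamba_shapes, mamba_dtypes):
    """Estimate bytes for one GDN cache block.

    mamba_shapes: per-tensor shapes, e.g. [(8, 128, 128), (8, 64)]
    mamba_dtypes: matching dtype-name strings, e.g. ["torch.bfloat16", ...]
    """
    return sum(
        math.prod(shape) * DTYPE_BYTES[dtype]
        for shape, dtype in zip(mamba_shapes, mamba_dtypes)
    )
```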
@ganyi1996ppo ganyi1996ppo marked this pull request as ready for review April 8, 2026 10:11
Copilot AI review requested due to automatic review settings April 8, 2026 10:11
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/qwen3.5_atom_native_support branch from f64b81c to 9f22fb3 Compare April 8, 2026 10:17
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (1)

atom/models/qwen3_5.py:400

  • get_qwen3_5_text_config() can return either Qwen3_5TextConfig or Qwen3_5MoeTextConfig, but this local annotation is Qwen3_5MoeTextConfig. Updating the annotation (or using a common base type) would better reflect the runtime behavior and avoid misleading type checks.
    def __init__(self, atom_config: Config, prefix: str = ""):
        config: Qwen3_5MoeTextConfig = get_qwen3_5_text_config(atom_config)
        self.atom_config = atom_config

        self.quant_config = atom_config.quant_config

        super().__init__()
        self.config = config


@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/qwen3.5_atom_native_support branch 2 times, most recently from f091704 to 325f40e Compare April 9, 2026 14:22
Copilot AI review requested due to automatic review settings April 10, 2026 01:49
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.



@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/qwen3.5_atom_native_support branch from 3318205 to 4c30785 Compare April 10, 2026 03:05
Copilot AI review requested due to automatic review settings April 10, 2026 13:04
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/qwen3.5_atom_native_support branch from 4c30785 to 1544337 Compare April 10, 2026 13:04
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.



Comment on lines +791 to +795
return self.load_fused_expert_weights(
original_name,
name,
params_dict,
loaded_weight,

Copilot AI Apr 10, 2026


Qwen3_5MoeForConditionalGeneration_::load_fused_expert_weights currently calls self.load_fused_expert_weights(...), which is the same method, leading to infinite recursion during weight loading. This should instead delegate to the module-level helper (e.g., load_fused_expert_weights(...)) or rename one of the symbols to avoid shadowing.

@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/qwen3.5_atom_native_support branch from 1544337 to bf24209 Compare April 12, 2026 15:17
Copilot AI review requested due to automatic review settings April 13, 2026 07:29
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/qwen3.5_atom_native_support branch from bf24209 to 7567ce1 Compare April 13, 2026 07:29
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.



Comment on lines 1410 to +1411
head_v_dim,
head_k_dim,

Copilot AI Apr 13, 2026


temporal_state_shape is used to allocate ssm_state that is passed as initial_state to fused_recurrent_gated_delta_rule, which expects the last two dims to be (K, V) (from k.shape[-1] and v.shape[-1]). Swapping to (head_v_dim, head_k_dim) allocates the state with transposed K/V and will break the GDN attention kernels. The shape here should remain (num_v_heads // tp_world_size, head_k_dim, head_v_dim) (i.e., (HV, K, V)).

Suggested change:
-    head_v_dim,
-    head_k_dim,
+    head_k_dim,
+    head_v_dim,

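The (HV, K, V) layout the reviewer insists on can be captured as a tiny helper: the recurrent GDN state must keep key and value dims in the order the kernel reads them (`k.shape[-1]` then `v.shape[-1]`), with value heads sharded across tensor-parallel ranks. The function name and argument set are illustrative, mirroring the identifiers in the snippet.

```python
def temporal_state_shape(num_v_heads, tp_world_size, head_k_dim, head_v_dim):
    """Per-rank GDN recurrent-state shape: (HV, K, V).

    HV = value heads on this TP rank; last two dims must stay
    (head_k_dim, head_v_dim), not the swapped order, because the
    kernel treats them as (K, V).
    """
    assert num_v_heads % tp_world_size == 0, "v heads must shard evenly"
    return (num_v_heads // tp_world_size, head_k_dim, head_v_dim)
```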
Comment on lines +27 to 29
from atom.utils import resolve_obj_by_qualname

logger = logging.getLogger("atom")

Copilot AI Apr 13, 2026


Qwen3_5TextConfig / Qwen3_5MoeTextConfig are imported but not referenced in this module. If they’re not needed for side-effects, remove the imports to avoid unused-import lint failures and keep config loading lightweight.

Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings April 13, 2026 13:11
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/qwen3.5_atom_native_support branch from d6afd19 to 1381767 Compare April 13, 2026 13:11
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated no new comments.



@valarLip valarLip merged commit 7ac3aea into main Apr 14, 2026
34 of 40 checks passed
@valarLip valarLip deleted the ganyi/qwen3.5_atom_native_support branch April 14, 2026 09:42

3 participants