Skip to content

[Tunix] Support scanned to unscanned weight transfer in transfer_state_directly#1008

Open
gagika wants to merge 1 commit intomainfrom
agagik-scan
Open

[Tunix] Support scanned to unscanned weight transfer in transfer_state_directly#1008
gagika wants to merge 1 commit intomainfrom
agagik-scan

Conversation

@gagika
Copy link
Collaborator

@gagika gagika commented Jan 28, 2026

This PR extends the transfer_state_directly utility to support weight synchronization from scanned MaxText model (where layers are stacked in a single tensor) to unscanned MaxText + vLLM models (where layers are separate parameters).

Previously, transfer_state_directly only supported 1-to-1 mapping (Unscanned -> Unscanned). This change adds logic to detect and unroll scanned layers during the transfer process.

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.



def _slice_scanned_param(
src_val: Any, tgt_val: Any, slice_idx: int, key_path: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use more detailed types instead of Any?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, added a few relevant types.

def _slice_scanned_param(
src_val: Any, tgt_val: Any, slice_idx: int, key_path: str
) -> Any:
"""Slices a scanned parameter dynamically detecting the scan axis."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you put a more detailed doc string? And maybe also include the input output descriptions?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

src: Any, tgt_spec: Any, path: str = ''
) -> Tuple[Any, Any]:
# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you fold it into the docstring?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

) -> Tuple[Any, Any]:
# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
def intersect_trees(src: Any, tgt_spec: Any) -> Tuple[Any, Any]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand it's there before your PR, but can you still add the detailed types?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more specific typing here.

# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
def intersect_trees(src: Any, tgt_spec: Any) -> Tuple[Any, Any]:
"""Optimized intersection using flat dictionary traversal."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

try:
return src_val[slice_idx]

except (IndexError, TypeError):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add more debugging information in case of this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

candidate_b.pop(match_index)
candidate_b = tuple(candidate_b)

if candidate_b in src_flat:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate code as candidate a, consider make it simpler?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified the code.

…e_directly

Dynamically detecting scan dim + path caching.

adding explicit cleanup.

improving docstrings and typing
copybara-service bot pushed a commit that referenced this pull request Feb 12, 2026
--
24224a5 by Gagik Amirkhanyan <agagik@google.com>:

[Tunix] Support scanned to unscanned weight transfer in transfer_state_directly

Dynamically detecting scan dim + path caching.

adding explicit cleanup.

improving docstrings and typing

COPYBARA_INTEGRATE_REVIEW=#1008 from google:agagik-scan 24224a5
PiperOrigin-RevId: 869308029
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants