
Adding dtype casting for direct state transfer.#1097

Open
NicoGrande wants to merge 1 commit into main from nicogrande/add-dtype-cast

@NicoGrande (Collaborator) commented Feb 13, 2026

This PR adds dtype casting to the direct_state_transfer method: when the source and target pytrees hold different dtypes under the same key, the source value is cast to the target's dtype. This is useful when a model is trained at one precision and then served at a lower one.
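As an illustration of the behavior described above, here is a minimal JAX sketch. It is not the PR's implementation: the helper name `cast_to_target_dtypes` and the example key names are hypothetical, and it assumes the source and target pytrees have identical structure and matching leaf shapes.

```python
import jax
import jax.numpy as jnp


def cast_to_target_dtypes(src_tree, tgt_tree):
    """Hypothetical helper: cast each source leaf to the matching target leaf's dtype."""

    def _cast(src_leaf, tgt_leaf):
        # Only cast when shapes agree; a mismatch is treated as an error here.
        if src_leaf.shape != tgt_leaf.shape:
            raise ValueError(
                f"Shape mismatch: {src_leaf.shape} vs {tgt_leaf.shape}"
            )
        return src_leaf.astype(tgt_leaf.dtype)

    # tree_map walks both pytrees in lockstep (they must share a structure).
    return jax.tree_util.tree_map(_cast, src_tree, tgt_tree)


# Trained weights in float32, serving model expects bfloat16.
src = {"decoder": {"kernel": jnp.ones((2, 3), dtype=jnp.float32)}}
tgt = {"decoder": {"kernel": jnp.zeros((2, 3), dtype=jnp.bfloat16)}}

out = cast_to_target_dtypes(src, tgt)
print(out["decoder"]["kernel"].dtype)
```

The values come from the source tree and only the dtype is taken from the target, which matches the train-in-high-precision, serve-in-low-precision use case.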

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@gemini-code-assist

Summary of Changes

Hello @NicoGrande, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the state transfer utility by introducing automatic data type casting. This improvement allows models with different precision requirements (e.g., training in float32 and deploying in bfloat16) to seamlessly transfer weights, significantly improving flexibility and reducing manual conversion steps. The changes ensure that parameters are correctly cast to the target data type during the transfer process, maintaining numerical accuracy where possible.
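To make "numerical accuracy where possible" concrete, the following sketch measures the rounding error of a float32 to bfloat16 cast. It is an illustration only, with arbitrarily chosen sample values: bfloat16 keeps float32's exponent range but only 8 bits of significand precision, so a round-to-nearest cast of a normal value has a relative error of at most 2^-8.

```python
import jax.numpy as jnp

# Arbitrary sample values (an assumption for illustration, not from the PR).
x32 = jnp.array([1.0, 3.14159, 1e-3, 1000.5], dtype=jnp.float32)

# Cast down, then back up so the difference can be computed in float32.
x16 = x32.astype(jnp.bfloat16)
rel_err = jnp.abs(x16.astype(jnp.float32) - x32) / jnp.abs(x32)

max_rel_err = float(jnp.max(rel_err))
print(max_rel_err <= 2.0 ** -8)
```

A test that compares cast values against the float32 originals would therefore need a tolerance of roughly this size rather than exact equality, unless the values are exactly representable in bfloat16.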

Highlights

  • Data Type Casting in State Transfer: Implemented automatic data type casting during direct state transfer operations, allowing source parameters (e.g., float32) to be converted to the target's data type (e.g., bfloat16) when their shapes align.
  • Support for Scanned Layers: Extended the data type casting functionality to correctly handle parameters within scanned layers, ensuring proper slicing and casting when transferring state from scanned to unrolled layer structures.
  • New Unit Tests: Added comprehensive unit tests to validate the new data type casting behavior for both direct parameter mapping and scanned layer scenarios, confirming correct conversion and value preservation.
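The scanned-layer case in the highlights above can be sketched as follows. This is an assumption-laden illustration rather than the code in tunix/generate/utils.py: the helper `unroll_and_cast` is hypothetical, the scan (layer) axis is assumed to be axis 0, and the unrolled targets are passed as a plain list.

```python
import jax.numpy as jnp


def unroll_and_cast(stacked, per_layer_targets, scan_axis=0):
    """Hypothetical helper: slice one layer at a time from a stacked
    (scanned) parameter and cast it to the matching target's dtype."""
    out = []
    for i, tgt in enumerate(per_layer_targets):
        # Taking a scalar index along scan_axis drops that axis.
        layer = jnp.take(stacked, i, axis=scan_axis)
        if layer.shape != tgt.shape:
            raise ValueError(
                f"Layer {i}: shape {layer.shape} does not match {tgt.shape}"
            )
        out.append(layer.astype(tgt.dtype))
    return out


# Three scanned layers stored as one float32 array of shape (3, 2, 2).
stacked = jnp.arange(12, dtype=jnp.float32).reshape(3, 2, 2)
# Unrolled targets: one bfloat16 array of shape (2, 2) per layer.
targets = [jnp.zeros((2, 2), dtype=jnp.bfloat16) for _ in range(3)]

layers = unroll_and_cast(stacked, targets)
print([layer.dtype for layer in layers])
```

Casting after slicing keeps each intermediate at the size of a single layer, which is one plausible reason to cast per slice rather than casting the whole stacked array first.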

Changelog
  • tests/generate/utils_test.py
    • Added test_transfer_state_directly_with_dtype_casting to verify direct and scanned layer state transfer with dtype casting from float32 to bfloat16.
    • Added test_transfer_state_directly_scanned_layers_casting to specifically test dtype casting during transfer from scanned layers to unrolled layers.
  • tunix/generate/utils.py
    • Modified intersect_trees to apply _apply_dtype_cast to source values when a direct key match is found, ensuring the source data type matches the target.
    • Modified intersect_trees to apply _apply_dtype_cast to sliced source values when transferring state from scanned layers, ensuring correct data type conversion for unrolled layers.

Activity
  • The pull request author, NicoGrande, has provided a checklist indicating that unit tests have been added and verified, existing code is not broken, documentation is appropriate, the PR is based on the latest main branch, the Contributor License Agreement has been signed, and contribution guidelines have been followed. However, the checkboxes are currently unchecked.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces dtype casting during direct state transfer, which is a valuable feature for handling weights of different precisions. The implementation correctly covers both direct parameter matches and mappings from scanned layers. The new tests are comprehensive and effectively validate this functionality. I've included two suggestions to enhance the robustness of the code by adding checks to prevent potential crashes when encountering non-tensor values within state dictionaries.
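The two suggestions themselves are not reproduced in this conversation, so the following is only a guess at the kind of guard the review describes. The helper name `cast_if_array` is hypothetical, and "tensor" is approximated here as any object exposing both `dtype` and `astype`.

```python
import jax.numpy as jnp


def cast_if_array(src_value, tgt_value):
    """Hypothetical guard: cast only when both values look like arrays."""
    src_is_array = hasattr(src_value, "astype") and hasattr(src_value, "dtype")
    tgt_is_array = hasattr(tgt_value, "dtype")
    if not (src_is_array and tgt_is_array):
        # Non-tensor entries (ints, strings, None) pass through untouched.
        return src_value
    if src_value.dtype == tgt_value.dtype:
        # Nothing to do; avoid creating a needless copy.
        return src_value
    return src_value.astype(tgt_value.dtype)


target = jnp.zeros((2,), dtype=jnp.bfloat16)

casted = cast_if_array(jnp.ones((2,), dtype=jnp.float32), target)
passthrough_int = cast_if_array(3, target)
passthrough_none = cast_if_array(None, target)
print(casted.dtype, passthrough_int, passthrough_none)
```

Without such a check, calling `.astype` on a plain Python value stored in a state dictionary would raise an `AttributeError`.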

@NicoGrande force-pushed the nicogrande/add-dtype-cast branch from 224caa0 to 27823a3 on February 13, 2026 21:34
@NicoGrande marked this pull request as ready for review on February 13, 2026 21:37
copybara-service bot pushed a commit that referenced this pull request Feb 13, 2026
--
224caa0 by Nicolas Grande <nicogrande@google.com>:

adding dtype casting for direct sync.

COPYBARA_INTEGRATE_REVIEW=#1097 from google:nicogrande/add-dtype-cast 224caa0
PiperOrigin-RevId: 869872501