Fix optimizer states dim(ckpt) of rowwise adagrad #305

Open
jiashuy wants to merge 1 commit into NVIDIA:main from jiashuy:fix_opt_dim_in_ckpt

@jiashuy jiashuy commented Feb 11, 2026

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

jiashuy commented Feb 11, 2026

CI

greptile-apps bot commented Feb 11, 2026

Greptile Overview

Greptile Summary

This PR fixes a checkpoint dimension mismatch for the RowWise AdaGrad optimizer. At runtime, RowWise AdaGrad stores optimizer states in a reduced dimension (calculated as 16 // DTYPE_NUM_BYTES, based on the embedding dtype), but a checkpoint needs only 1 element per row. The fix introduces get_ckpt_state_dim() to distinguish the runtime dimension from the checkpoint dimension, then updates the serialization logic to:

  • Truncate optimizer states to checkpoint dimension when saving
  • Pad loaded states back to runtime dimension using initial_accumulator_value

This ensures checkpoint compatibility and prevents dimension mismatches during save/load operations.
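
The truncate-on-save and pad-on-load steps described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: plain Python lists stand in for tensors, and the helper names `truncate_for_ckpt` and `pad_from_ckpt` are hypothetical.

```python
from typing import List

Row = List[float]


def truncate_for_ckpt(states: List[Row], ckpt_dim: int) -> List[Row]:
    """Keep only the first ckpt_dim elements of each row (save path)."""
    return [row[:ckpt_dim] for row in states]


def pad_from_ckpt(loaded: List[Row], runtime_dim: int, fill_value: float) -> List[Row]:
    """Pad each loaded row back to runtime_dim with fill_value (load path).

    fill_value plays the role of initial_accumulator_value in the summary.
    """
    padded = []
    for row in loaded:
        if len(row) > runtime_dim:
            raise ValueError("checkpoint row is wider than the runtime state dim")
        padded.append(list(row) + [fill_value] * (runtime_dim - len(row)))
    return padded


# Hypothetical runtime states: 2 rows, runtime dim 4, checkpoint dim 1.
runtime_states = [[0.5, 0.0, 0.0, 0.0], [1.25, 0.0, 0.0, 0.0]]
saved = truncate_for_ckpt(runtime_states, ckpt_dim=1)
restored = pad_from_ckpt(saved, runtime_dim=4, fill_value=0.0)
```

In this sketch the round trip reproduces the original rows only because the padded slots already held the fill value; the meaningful accumulator is the first element of each row.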

Confidence Score: 5/5

  • This PR is safe to merge with no issues identified
  • The implementation is clean and follows a clear design pattern: introducing an abstraction (get_ckpt_state_dim) with a sensible default, then overriding it only where needed. The truncation/padding logic is symmetric and uses the correct initial value for padding. The changes are localized and don't affect other optimizers.
  • No files require special attention
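
The design pattern mentioned above (a default `get_ckpt_state_dim` that is overridden only where needed) might look something like the sketch below. The class names, the `DTYPE_NUM_BYTES` table, and the element-wise optimizer are assumptions made for illustration; only the method names `get_state_dim` and `get_ckpt_state_dim`, the `_optim_state_dim` attribute, and the `16 // DTYPE_NUM_BYTES` formula come from the summary and diagram.

```python
# Bytes per element for a few embedding dtypes (illustrative subset, assumed).
DTYPE_NUM_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2}


class SketchOptimizer:
    """Hypothetical base class, not the real dynamicemb optimizer."""

    def get_state_dim(self, emb_dim: int) -> int:
        raise NotImplementedError

    def get_ckpt_state_dim(self, emb_dim: int) -> int:
        # Default: the checkpoint stores exactly what the runtime holds.
        return self.get_state_dim(emb_dim)


class SketchElementwiseOptimizer(SketchOptimizer):
    """Hypothetical optimizer with one state element per embedding element."""

    def get_state_dim(self, emb_dim: int) -> int:
        return emb_dim


class SketchRowWiseAdaGrad(SketchOptimizer):
    """Hypothetical stand-in for the row-wise AdaGrad optimizer."""

    def __init__(self, emb_dtype: str) -> None:
        # Runtime state width per row, per the formula in the summary.
        self._optim_state_dim = 16 // DTYPE_NUM_BYTES[emb_dtype]

    def get_state_dim(self, emb_dim: int) -> int:
        return self._optim_state_dim

    def get_ckpt_state_dim(self, emb_dim: int) -> int:
        # Only one accumulator value per row needs to be persisted.
        return 1


fp32_opt = SketchRowWiseAdaGrad("float32")
fp16_opt = SketchRowWiseAdaGrad("float16")
```

Because the base class falls back to `get_state_dim`, optimizers that do not override `get_ckpt_state_dim` keep the same dimension at runtime and in the checkpoint in this sketch.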

Important Files Changed

Filename | Overview
corelib/dynamicemb/dynamicemb/optimizer.py | Added get_ckpt_state_dim() method to base optimizer class and overrode it in RowWiseAdaGradDynamicEmbeddingOptimizerV2 to return 1 instead of the runtime state dimension
corelib/dynamicemb/dynamicemb/key_value_table.py | Updated dump/load methods to use get_ckpt_state_dim() for checkpoint dimensions, with logic to truncate on save and pad on load when checkpoint dim differs from runtime dim
corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py | Minor formatting change to improve code style consistency (multi-line assignment)

Sequence Diagram

sequenceDiagram
    participant Caller
    participant KeyValueTable
    participant Optimizer
    participant Checkpoint

    Note over Caller,Checkpoint: Checkpoint Save Flow
    Caller->>KeyValueTable: dump()
    KeyValueTable->>Optimizer: get_ckpt_state_dim(emb_dim)
    Optimizer-->>KeyValueTable: 1 (for RowWise AdaGrad)
    KeyValueTable->>Optimizer: get_state_dim(emb_dim)
    Optimizer-->>KeyValueTable: _optim_state_dim (e.g., 4 or 8)
    KeyValueTable->>KeyValueTable: truncate opt_states[:, :ckpt_state_dim]
    KeyValueTable->>Checkpoint: write truncated states

    Note over Caller,Checkpoint: Checkpoint Load Flow
    Caller->>KeyValueTable: load()
    KeyValueTable->>Optimizer: get_ckpt_state_dim(emb_dim)
    Optimizer-->>KeyValueTable: 1 (for RowWise AdaGrad)
    KeyValueTable->>Checkpoint: read ckpt_state_dim elements per row
    KeyValueTable->>KeyValueTable: create padded tensor with initial_accumulator_value
    KeyValueTable->>KeyValueTable: padded[:, :ckpt_state_dim] = loaded_states
    KeyValueTable-->>Caller: return padded states

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments
