Conversation

@Jackmin801 (Member) commented Dec 30, 2025

Note

Reduces peak memory during RL training by chunking logits computation and refines FSDP sharding for models with tied embeddings.

  • Introduces logits_chunk_size in RLTrainerConfig to control per-sequence chunking for logits materialization
  • In train.py, replaces single-shot logits materialization with a chunked application of lm_head over hidden_states, computing loss/entropy and running backward per chunk before backpropagating through the backbone; handles the CP all-gather per chunk and configures lm_head FSDP resharding/blocking accordingly (see the sketch below)
  • In model.py, updates FSDP setup: if config.tie_word_embeddings, shard [model.model.embed_tokens, model.lm_head]; else shard model.lm_head; and shard model.model (not the entire model), consistently using config.reshard_after_forward

Written by Cursor Bugbot for commit 1313a42.
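
A minimal sketch of the chunked pattern the second bullet describes, assuming a plain linear lm_head and token-level cross-entropy standing in for the RL loss/entropy. Label shifting, the CP all-gather, and FSDP resharding are omitted, and all names here are illustrative rather than the PR's actual code.

```python
import torch
import torch.nn.functional as F

def chunked_lm_head_loss(hidden_states, labels, lm_head, num_chunks: int = 4):
    """Apply lm_head and the loss chunk-by-chunk along the sequence dimension,
    so the full [batch, seq, vocab] logits tensor is never materialized."""
    # Detach so each chunk's backward stops here and accumulates into
    # detached.grad instead of traversing the backbone once per chunk.
    detached = hidden_states.detach().requires_grad_(True)
    seq_len = hidden_states.shape[1]
    chunk_len = -(-seq_len // num_chunks)  # ceil division
    total_tokens = labels.numel()

    total_loss = 0.0
    for start in range(0, seq_len, chunk_len):
        hs_chunk = detached[:, start : start + chunk_len]
        label_chunk = labels[:, start : start + chunk_len]
        logits = lm_head(hs_chunk)  # [batch, chunk_len, vocab]
        loss = F.cross_entropy(
            logits.flatten(0, 1), label_chunk.flatten(), reduction="sum"
        ) / total_tokens
        loss.backward()  # frees this chunk's logits before the next chunk
        total_loss += loss.item()

    # One backward pass through the backbone with the accumulated gradient.
    hidden_states.backward(detached.grad)
    return total_loss
```

Because each chunk's backward frees that chunk's logits immediately, peak logits memory scales with seq_len / num_chunks rather than the full sequence length.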

if cp_enabled:
    left_pad_logit = get_padding_logit_from_prev_cp_rank(logits, cp_rank, cp_size, cp_group)
else:
    left_pad_logit = None

Missing left pad logit between chunks breaks shift

When logits_chunks > 1, the shift_logits function requires the last logit from the previous chunk to properly shift logits at chunk boundaries. However, the current implementation only considers CP rank boundaries, not chunk boundaries. For chunks after the first one: without CP, left_pad_logit is None (causing zeros to be used); with CP, get_padding_logit_from_prev_cp_rank returns logits from a different rank rather than the previous chunk. The logits tensor is also deleted at line 350 before it can be used for the next chunk. This causes incorrect trainer_logprobs calculations and corrupted loss values whenever chunking is enabled.
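A self-contained toy illustration of the boundary effect, using a stand-in shift function rather than the trainer's actual shift_logits: naive per-chunk shifting zeros out the boundary position, while carrying the previous chunk's last logit reproduces the full-sequence result (in the real code, that logit would also have to be saved before the chunk's logits are deleted).

```python
import torch

def shift_logits(logits, left_pad_logit=None):
    """Toy stand-in for the trainer's shift_logits: each position needs the
    logit of the position to its left, so the row is shifted right by one."""
    if left_pad_logit is None:
        left_pad_logit = torch.zeros_like(logits[:, :1])
    return torch.cat([left_pad_logit, logits[:, :-1]], dim=1)

logits = torch.randn(1, 8, 4)
reference = shift_logits(logits)  # full-sequence shift

# Naive per-chunk shifting: the second chunk's first position gets zeros
# instead of the last logit of the first chunk.
chunks = logits.chunk(2, dim=1)
naive = torch.cat([shift_logits(c) for c in chunks], dim=1)
assert not torch.allclose(naive, reference)

# Carrying the previous chunk's last logit as the left pad restores the
# full-sequence result at the chunk boundary.
carried, prev_last = [], None
for c in chunks:
    carried.append(shift_logits(c, left_pad_logit=prev_last))
    prev_last = c[:, -1:]
assert torch.allclose(torch.cat(carried, dim=1), reference)
```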

ge=1,
description="Number of chunks to split the sequence into for logits materialization. Higher values reduce memory usage but may increase computation time. Default is 1 (no chunking).",
),
] = 1

New config field missing CHANGELOG entry (Bugbot Rules)

A new config field logits_chunks has been added to src/prime_rl/trainer/rl/config.py, which matches the pattern src/prime_rl/*/config.py. According to the review rules, any PR that modifies configuration structures must update CHANGELOG.md, but no corresponding entry was added.

loss_mask=loss_mask_chunk.squeeze().split(response_lengths_chunk),
loss_config=config.loss,
loss_scale=loss_scale,
)

Chunking breaks sequence boundary detection for loss

When logits_chunks > 1, get_response_lengths(position_ids_chunk) is called on each chunk independently. This function detects sequence boundaries by looking for position_ids resetting to 0 followed by 1. When sequences span chunk boundaries, they are incorrectly identified as separate sequences in each chunk. This causes incorrect sequence-level loss normalization (when ratio_type == "sequence"), wrong sequence-level importance ratio calculations, and incorrect application of sequence_mask_low/sequence_mask_high thresholds. The loss values and gradient flow will be semantically incorrect for any packed batch where sequence boundaries don't align with chunk boundaries.
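A toy illustration of the miscount, using a stand-in boundary detector rather than the actual get_response_lengths: a packed sequence that straddles the chunk boundary is measured as two shorter sequences when each chunk is inspected independently.

```python
import torch

def response_lengths(position_ids):
    """Toy stand-in for get_response_lengths: a new packed sequence starts at
    index 0 and wherever position_ids reset to 0."""
    starts = [0] + [i for i in range(1, position_ids.shape[1]) if position_ids[0, i] == 0]
    ends = starts[1:] + [position_ids.shape[1]]
    return [end - start for start, end in zip(starts, ends)]

# One row packing two sequences of lengths 6 and 2.
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 0, 1]])
print(response_lengths(position_ids))        # [6, 2]

# Chunking in half splits the 6-token sequence across the boundary, and each
# chunk is measured on its own, so it is reported as lengths 4 and 2.
for chunk in position_ids.chunk(2, dim=1):
    print(response_lengths(chunk))           # [4], then [2, 2]
```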

position_ids_list = position_ids.chunk(num_logits_chunks, dim=1)
inference_logprobs_list = inference_logprobs.chunk(num_logits_chunks, dim=1)
advantages_list = advantages.chunk(num_logits_chunks, dim=1)
loss_mask_list = loss_mask.chunk(num_logits_chunks, dim=1)

CP mode mixes sharded and non-sharded tensors when chunking

When context parallelism is enabled, input_ids and hidden_states are sharded (have seq_len/cp_size tokens), while position_ids, inference_logprobs, advantages, and loss_mask remain non-sharded (have full seq_len tokens). The num_logits_chunks is computed using the full seq_len, but then applied to chunk both sharded and non-sharded tensors. This causes the chunked tensors to have mismatched sizes - for example, hidden_states_chunk[i] may have 2048 tokens while position_ids_chunk[i] has 4096 tokens - leading to incorrect loss computation or runtime errors in CP mode.
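Toy shapes (illustrative values, not the trainer's real config) showing how chunking sharded and non-sharded tensors with the same num_logits_chunks produces mismatched pairs:

```python
import torch

seq_len, cp_size, num_logits_chunks = 8192, 2, 4

# Under context parallelism, hidden_states holds only this rank's shard of the
# sequence, while position_ids / advantages / loss_mask hold the full sequence.
hidden_states = torch.randn(1, seq_len // cp_size, 16)  # 4096 tokens on this rank
position_ids = torch.arange(seq_len).unsqueeze(0)       # 8192 tokens

hs_chunks = hidden_states.chunk(num_logits_chunks, dim=1)
pos_chunks = position_ids.chunk(num_logits_chunks, dim=1)

# Pairing chunk i of each list misaligns the loss inputs: 1024 vs 2048 tokens.
print(hs_chunks[0].shape[1], pos_chunks[0].shape[1])  # 1024 2048
```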

tensors[key].append(loss_tensor)

# Now backward through the rest of the model with accumulated gradients
hidden_states.backward(detached_hidden_states.grad)

Debug log shows only last chunk's loss not total

After chunking, tensors['loss'], tensors['entropy'], and tensors['mismatch_kl'] are appended once per chunk rather than once per micro-batch. The debug log message uses tensors['loss'][-1] which previously retrieved the micro-batch's total loss but now only retrieves the last chunk's loss. This makes the debug output misleading and complicates training diagnostics.
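A hedged sketch of the logging pitfall (names and scaling assumptions are illustrative, not the trainer's actual code): with one entry appended per chunk, the last element is only the final chunk's contribution, while summing the chunk entries recovers a per-micro-batch value, assuming the per-chunk losses are already scaled so they add up to the micro-batch loss.

```python
chunk_losses = [0.9, 1.1, 1.0, 1.2]  # one entry appended per chunk
tensors = {"loss": []}
tensors["loss"].extend(chunk_losses)

last_chunk_loss = tensors["loss"][-1]                         # 1.2, misleading
micro_batch_loss = sum(tensors["loss"][-len(chunk_losses):])  # 4.2, the total
print(f"loss={micro_batch_loss:.2f} (last chunk: {last_chunk_loss:.2f})")
```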
