chunk the logits materialization for rl loss function to decrease peak memory usage #1511
base: main
Conversation
```python
if cp_enabled:
    left_pad_logit = get_padding_logit_from_prev_cp_rank(logits, cp_rank, cp_size, cp_group)
else:
    left_pad_logit = None
```
Missing left pad logit between chunks breaks shift
When logits_chunks > 1, the shift_logits function requires the last logit from the previous chunk to properly shift logits at chunk boundaries. However, the current implementation only considers CP rank boundaries, not chunk boundaries. For chunks after the first one: without CP, left_pad_logit is None (causing zeros to be used); with CP, get_padding_logit_from_prev_cp_rank returns logits from a different rank rather than the previous chunk. The logits tensor is also deleted at line 350 before it can be used for the next chunk. This causes incorrect trainer_logprobs calculations and corrupted loss values whenever chunking is enabled.
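A minimal, self-contained sketch of the fix this comment points toward: carry the last logit of each chunk forward as the left pad for the next chunk. `shift_with_left_pad` below is an illustrative stand-in for the PR's `shift_logits`, not the actual implementation, and the tensor shapes are toy values.

```python
import torch

def shift_with_left_pad(logits: torch.Tensor, left_pad: torch.Tensor | None) -> torch.Tensor:
    # Stand-in for a shift_logits-style helper: position t receives the logit
    # from position t-1; position 0 receives `left_pad` (zeros if None).
    pad = torch.zeros_like(logits[:, :1]) if left_pad is None else left_pad
    return torch.cat([pad, logits[:, :-1]], dim=1)

torch.manual_seed(0)
full_logits = torch.randn(1, 8, 4)                    # unchunked logits for one packed row
reference = shift_with_left_pad(full_logits, None)

# Chunked path: the last logit of the previous chunk becomes the left pad of
# the next chunk, so chunk boundaries shift exactly like the full tensor.
prev_chunk_last_logit = None
shifted_chunks = []
for chunk in full_logits.chunk(2, dim=1):
    shifted_chunks.append(shift_with_left_pad(chunk, prev_chunk_last_logit))
    prev_chunk_last_logit = chunk[:, -1:].detach()

assert torch.allclose(torch.cat(shifted_chunks, dim=1), reference)
```

In a real fix one would also have to decide whether the carried boundary logit needs gradient: detaching it (as above) keeps the previous chunk's graph from being retained, but drops the gradient contribution through that one position.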
src/prime_rl/trainer/rl/config.py (outdated)

```python
        ge=1,
        description="Number of chunks to split the sequence into for logits materialization. Higher values reduce memory usage but may increase computation time. Default is 1 (no chunking).",
    ),
] = 1
```
New config field missing CHANGELOG entry (Bugbot Rules)
A new config field logits_chunks has been added to src/prime_rl/trainer/rl/config.py, which matches the pattern src/prime_rl/*/config.py. According to the review rules, any PR that modifies configuration structures must update CHANGELOG.md, but no corresponding entry was added.
```python
    loss_mask=loss_mask_chunk.squeeze().split(response_lengths_chunk),
    loss_config=config.loss,
    loss_scale=loss_scale,
)
```
Chunking breaks sequence boundary detection for loss
When logits_chunks > 1, get_response_lengths(position_ids_chunk) is called on each chunk independently. This function detects sequence boundaries by looking for position_ids resetting to 0 followed by 1. When sequences span chunk boundaries, they are incorrectly identified as separate sequences in each chunk. This causes incorrect sequence-level loss normalization (when ratio_type == "sequence"), wrong sequence-level importance ratio calculations, and incorrect application of sequence_mask_low/sequence_mask_high thresholds. The loss values and gradient flow will be semantically incorrect for any packed batch where sequence boundaries don't align with chunk boundaries.
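A small, self-contained illustration of the failure mode. The helper below is a simplified stand-in for `get_response_lengths`, treating every reset of `position_ids` to 0 as a new sequence (the real helper also checks for the following 1):

```python
import torch

def response_lengths(position_ids: torch.Tensor) -> list[int]:
    # Simplified boundary detection: a new packed sequence starts wherever
    # position_ids resets to 0.
    starts = (position_ids == 0).nonzero(as_tuple=True)[0].tolist()
    if not starts or starts[0] != 0:
        starts = [0] + starts
    starts.append(position_ids.numel())
    return [b - a for a, b in zip(starts[:-1], starts[1:])]

# One packed row holding two sequences of lengths 5 and 3.
position_ids = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])
print(response_lengths(position_ids))        # [5, 3]

# Chunking through the middle of the first sequence yields [4] and [1, 3]:
# the split sequence is counted as three sequences, so sequence-level
# normalization and masking operate on the wrong boundaries.
for chunk in position_ids.chunk(2):
    print(response_lengths(chunk))
```

One way to avoid this is to compute response lengths once on the full `position_ids` and then map them onto chunks, rather than re-deriving them per chunk.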
```python
position_ids_list = position_ids.chunk(num_logits_chunks, dim=1)
inference_logprobs_list = inference_logprobs.chunk(num_logits_chunks, dim=1)
advantages_list = advantages.chunk(num_logits_chunks, dim=1)
loss_mask_list = loss_mask.chunk(num_logits_chunks, dim=1)
```
CP mode mixes sharded and non-sharded tensors when chunking
When context parallelism is enabled, input_ids and hidden_states are sharded (have seq_len/cp_size tokens), while position_ids, inference_logprobs, advantages, and loss_mask remain non-sharded (have full seq_len tokens). The num_logits_chunks is computed using the full seq_len, but then applied to chunk both sharded and non-sharded tensors. This causes the chunked tensors to have mismatched sizes - for example, hidden_states_chunk[i] may have 2048 tokens while position_ids_chunk[i] has 4096 tokens - leading to incorrect loss computation or runtime errors in CP mode.
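The mismatch is easy to reproduce in isolation. The shapes below mirror the example in this comment (seq_len 8192, cp_size 2, two chunks); the names are illustrative, not the prime-rl variables:

```python
import torch

seq_len, cp_size, num_logits_chunks = 8192, 2, 2

# Sharded activation on one CP rank vs. non-sharded per-token metadata.
hidden_states = torch.randn(1, seq_len // cp_size, 16)   # 4096 tokens on this rank
position_ids = torch.arange(seq_len).unsqueeze(0)        # full 8192 tokens

hidden_states_chunks = hidden_states.chunk(num_logits_chunks, dim=1)
position_ids_chunks = position_ids.chunk(num_logits_chunks, dim=1)

# 2048 vs. 4096 tokens per chunk: each logits chunk is paired with twice as
# many positions/logprobs/advantages as it actually covers.
print(hidden_states_chunks[0].shape[1], position_ids_chunks[0].shape[1])
```

Either the non-sharded tensors need to be sliced down to this rank's shard before chunking, or the chunk boundaries need to be derived from the sharded length.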
```python
tensors[key].append(loss_tensor)

# Now backward through the rest of the model with accumulated gradients
hidden_states.backward(detached_hidden_states.grad)
```
Debug log shows only last chunk's loss not total
After chunking, tensors['loss'], tensors['entropy'], and tensors['mismatch_kl'] are appended once per chunk rather than once per micro-batch. The debug log message uses tensors['loss'][-1] which previously retrieved the micro-batch's total loss but now only retrieves the last chunk's loss. This makes the debug output misleading and complicates training diagnostics.
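A minimal sketch of one way to keep the debug output meaningful. Names are illustrative, and the sum assumes each per-chunk loss is already scaled so that the chunk losses add up to the micro-batch loss:

```python
import torch

tensors: dict[str, list[torch.Tensor]] = {"loss": []}

# Stand-ins for the detached per-chunk losses of one micro-batch.
chunk_losses = [torch.tensor(0.7), torch.tensor(0.4), torch.tensor(0.2)]

# Append one entry per micro-batch instead of one per chunk, so
# tensors["loss"][-1] still refers to the whole micro-batch.
tensors["loss"].append(torch.stack(chunk_losses).sum())

print(f"micro-batch loss: {tensors['loss'][-1].item():.4f}")
```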
Note
Reduces peak memory during RL training by chunking logits computation and refines FSDP sharding for models with tied embeddings.
- Adds `logits_chunk_size` in `RLTrainerConfig` to control per-sequence chunking for logits materialization
- In `train.py`, replaces single-shot logits with chunked application of `lm_head` over `hidden_states`, computing loss/entropy and backward per chunk, then backpropagating through the backbone; handles CP all-gather per chunk and configures `lm_head` FSDP resharding/blocking accordingly
- In `model.py`, updates FSDP setup: if `config.tie_word_embeddings`, shard `[model.model.embed_tokens, model.lm_head]`; else shard `model.lm_head`; and shard `model.model` (not the entire `model`), consistently using `config.reshard_after_forward`

Written by Cursor Bugbot for commit 1313a42.
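As a rough illustration of the chunked logits/backward pattern summarized above, here is a self-contained toy version. It uses a tiny stand-in backbone and `lm_head` (not the prime-rl modules) and a placeholder loss; the point is the detach/per-chunk-backward/backbone-backward structure, not the exact training code.

```python
import torch
import torch.nn as nn

backbone = nn.Linear(32, 32)                 # toy stand-in for the transformer body
lm_head = nn.Linear(32, 50_000, bias=False)  # a large vocab is what makes full logits expensive

tokens = torch.randn(1, 1024, 32)
hidden_states = backbone(tokens)

# Cut the graph so each logits chunk can be backpropagated independently while
# its gradient accumulates into a single hidden-state tensor.
detached_hidden_states = hidden_states.detach().requires_grad_(True)

num_logits_chunks = 4
for hs_chunk in detached_hidden_states.chunk(num_logits_chunks, dim=1):
    logits_chunk = lm_head(hs_chunk)         # only seq_len / num_chunks logits alive at a time
    loss_chunk = logits_chunk.logsumexp(dim=-1).mean() / num_logits_chunks  # placeholder loss
    loss_chunk.backward()                    # frees this chunk's logits graph immediately

# One backward through the backbone with the accumulated gradient.
hidden_states.backward(detached_hidden_states.grad)
```

The detach/backward split is what saves memory: logits for the full sequence never exist at once, while the backbone still receives a single gradient pass, matching the `hidden_states.backward(detached_hidden_states.grad)` line in the diff.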