Fix xLSTM train mode crash when return_last_states=False #47026
Open
lcheng321 wants to merge 2 commits into
[For maintainers] Suggested jobs to run (before merge) run-slow: xlstm
@vasqu since you commented on #47013 with the output recorder suggestion, would you have a moment to review this when you can? This PR keeps the fix scoped to xLSTMBackend.forward, which now always returns a (h, last_states) tuple, and I confirmed it resolves the crash for both batch_size=2 and batch_size=3. No rush at all, thanks for taking a look.
Fixes #47013
Root cause
In xLSTMBackend.forward, the train branch forwarded the kernel's return value as is. With return_last_states=False, mlstm_chunkwise_native_autograd returns a bare tensor, but every caller assumes a (h, state) tuple.
xLSTMLayer.forward does h, state = self.mlstm_backend(...).
xLSTMModel.forward later indexes state when copying into the cache.
This crashed for any batch size other than 2 with ValueError: too many values to unpack. For batch size 2 it silently unpacked the tensor along the batch dimension and hit the shape assertion instead.
There is also a second, latent crash on the same path. Since use_cache defaults to True, xLSTMModel.forward creates a cache and unconditionally indexes the returned state, so returning None there raises TypeError: 'NoneType' object is not subscriptable.
Fix
Normalize the train branch of xLSTMBackend.forward to always return a (h, last_states) tuple, consistent with the inference branch. The last states are a free byproduct of the chunkwise recurrence, mlstm_chunkwise_recurrent_fw_C already computes them, so they are always returned instead of discarded. return_last_states no longer changes the return type, and the cache copy path in xLSTMModel.forward always gets a valid state tuple.
The train_with_padding branch keeps its existing guard and returns (h, None), since padding pollutes the last state and it is not meaningful there.
No kernel functions or callers were changed.
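The unpacking failure and the return-type normalization described above can be illustrated without PyTorch. This is a minimal sketch, not the actual patch: nested lists stand in for tensors (both unpack along their first dimension), and toy_kernel, backend_before, backend_after, and caller are hypothetical names invented for this illustration.

```python
def toy_kernel(batch_size, return_last_states):
    """Hypothetical stand-in for the chunkwise kernel.

    Nested lists play the role of tensors: unpacking one walks its
    first (batch) dimension, just as unpacking a tensor would.
    """
    h = [[0.0, 0.0, 0.0] for _ in range(batch_size)]  # "shape" (batch_size, 3)
    last_states = ("c", "n", "m")  # placeholder for the recurrent state tuple
    return (h, last_states) if return_last_states else h


def backend_before(batch_size, return_last_states=False):
    # Old behaviour (assumed shape of the bug): forward whatever the kernel returned.
    return toy_kernel(batch_size, return_last_states)


def backend_after(batch_size, return_last_states=False):
    # Normalized behaviour: always request the states, always return a 2-tuple.
    h, last_states = toy_kernel(batch_size, return_last_states=True)
    return h, last_states


def caller(backend, batch_size):
    # Mirrors the `h, state = ...` unpacking pattern used by the callers.
    h, state = backend(batch_size)
    return h, state


# Batch size 3: the bare "tensor" has three rows, so unpacking into two names fails.
try:
    caller(backend_before, 3)
    unpack_error = None
except ValueError as exc:
    unpack_error = str(exc)

# Batch size 2: unpacking "succeeds", but h and state are just the two batch rows.
h_wrong, state_wrong = caller(backend_before, 2)

# With the normalized backend, both batch sizes yield a full h and a state tuple.
h3, state3 = caller(backend_after, 3)
h2, state2 = caller(backend_after, 2)
```

In this toy version the batch size 2 case is the more dangerous one, since nothing fails at the unpacking site and the error only surfaces later, which matches the shape assertion reported above.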
Testing
Repro script, from the issue, extended to batch 2 and batch 3:
Before, on current main:
After the fix:
Full terminal logs below for reference.
Full log, before fix
Full log, after fix
Screenshots
For extra confirmation that the logs above are unedited, raw terminal output, here are screenshots of the actual session.
Before the fix, batch_size=2 hits the shape assertion and batch_size=3 hits the unpack error.

After the fix, both batch sizes complete and return the expected hidden state shape.

Note
This keeps the fix localized to xLSTMBackend.forward. The larger output recorder refactor mentioned in the issue, routing return_last_states through the output recorder like output_hidden_states, could be a follow-up.
cc @vasqu @ArthurZucker