
Fix xLSTM train mode crash when return_last_states=False #47026

Open
lcheng321 wants to merge 2 commits into huggingface:main from lcheng321:fix-xlstm-train-return-last-states

lcheng321 commented Jul 2, 2026

Fixes #47013

Root cause

In xLSTMBackend.forward, the train branch forwarded the kernel's return value as is. With return_last_states=False, mlstm_chunkwise_native_autograd returns a bare tensor, but every caller assumes a (h, state) tuple.

xLSTMLayer.forward does h, state = self.mlstm_backend(...). xLSTMModel.forward later indexes state when copying into the cache.

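This crashed for any batch size greater than 2 with ValueError: too many values to unpack (batch size 1 fails the same unpacking with "not enough values to unpack"). For batch size 2 it silently unpacked the tensor along the batch dimension, so h became a single batch element and the shape check raised instead.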
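
The batch-size dependence follows from plain Python unpacking: `h, state = x` iterates `x` along its first axis, which for a bare (batch, heads, seq, dim) tensor is the batch axis. A minimal sketch, using nested lists as a stand-in for the tensor; the helper names `fake_bare_output` and `unpack_like_layer` are illustrative and not from the codebase:

```python
def fake_bare_output(batch_size, num_heads=4):
    """Stand-in for a bare (batch, heads, ...) tensor: iterating yields one item per batch element."""
    return [[0.0] * num_heads for _ in range(batch_size)]


def unpack_like_layer(backend_output):
    """Mimic the `h, state = self.mlstm_backend(...)` unpacking described above."""
    try:
        h, state = backend_output
    except ValueError as err:
        return f"ValueError: {err}"
    # Only reached when the first axis has exactly two entries.
    return f"unpacked, but h lost its batch axis (len(h) == {len(h)})"


results = {bs: unpack_like_layer(fake_bare_output(bs)) for bs in (1, 2, 3)}
for bs, outcome in results.items():
    print(bs, outcome)
```

Under these assumptions only batch size 2 gets past the unpacking, and `h` is then one batch element rather than the full batch, which matches the shape mismatch reported for that case.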

There is also a second, latent crash on the same path. Since use_cache defaults to True, xLSTMModel.forward creates a cache and unconditionally indexes the returned state, so returning None there raises TypeError: 'NoneType' object is not subscriptable.
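
This second failure mode does not depend on batch size. A minimal sketch, with a hypothetical `copy_into_cache` helper standing in for the cache copy in xLSTMModel.forward (the real code is not shown in this description, and the three placeholder state entries are an assumption):

```python
def copy_into_cache(cache, layer_idx, state):
    """Hypothetical stand-in: copy each state entry into the cache slots of one layer."""
    for slot in range(len(cache[layer_idx])):
        # Unconditional indexing: raises TypeError if `state` is None.
        cache[layer_idx][slot] = state[slot]
    return cache


cache = {0: [None, None, None]}

# A valid state tuple is copied without issue.
copy_into_cache(cache, 0, ("c", "n", "m"))

# A missing state reproduces the TypeError quoted above.
try:
    copy_into_cache(cache, 0, None)
    error_message = ""
except TypeError as err:
    error_message = str(err)
print(error_message)
```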

Fix

Normalize the train branch of xLSTMBackend.forward to always return a (h, last_states) tuple, consistent with the inference branch. The last states are a free byproduct of the chunkwise recurrence (mlstm_chunkwise_recurrent_fw_C already computes them), so they are now always returned instead of being discarded. return_last_states no longer changes the return type, and the cache copy path in xLSTMModel.forward always gets a valid state tuple.

The train_with_padding branch keeps its existing guard and returns (h, None), since padding pollutes the last state and it is not meaningful there.

No kernel functions or callers were changed.
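
The shape of the change can be sketched as follows. This is an illustration under stated assumptions, not the PR's diff: `toy_kernel` and `backend_forward` are hypothetical stand-ins, and the mode strings simply follow the branch names used in this description.

```python
def toy_kernel(x, return_last_states):
    """Hypothetical kernel: returns a bare value or an (h, state) tuple depending on the flag."""
    h = [v * 2 for v in x]
    last_state = sum(x)  # stand-in for the state the recurrence computes anyway
    return (h, last_state) if return_last_states else h


def backend_forward(x, mode, return_last_states=False):
    """Always return an (h, last_states) tuple; `return_last_states` no longer alters the type."""
    if mode == "train":
        # Ask the kernel for the states unconditionally so the return type is stable.
        h, last_states = toy_kernel(x, return_last_states=True)
        return h, last_states
    if mode == "train_with_padding":
        # Padding pollutes the last state, so it is reported as None.
        h = toy_kernel(x, return_last_states=False)
        return h, None
    raise ValueError(f"unsupported mode in this sketch: {mode}")


h, state = backend_forward([1, 2, 3], mode="train", return_last_states=False)
h_pad, state_pad = backend_forward([1, 2, 3], mode="train_with_padding")
```

Callers can then unpack `h, state = ...` in every branch, and only the padded branch needs a `None` check on the state.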

Testing

Repro script, from the issue, extended to batch 2 and batch 3:

import torch
from transformers import xLSTMConfig
from transformers.models.xlstm.modeling_xlstm import xLSTMModel

config = xLSTMConfig(
    hidden_size=128, num_hidden_layers=2, num_heads=4,
    mode="train", return_last_states=False,
    chunkwise_kernel="chunkwise--native_autograd",
    sequence_kernel="native_sequence__native", step_kernel="native",
)
model = xLSTMModel(config)
for bs in (2, 3):
    out = model(torch.randint(0, config.vocab_size, (bs, 64)))
    print(bs, out.last_hidden_state.shape)

Before, on current main:

=== batch_size=2 ===
FAILED: ValueError Got torch.Size([4, 64, 32]), expected (2, 4, 64, 32)

=== batch_size=3 ===
FAILED: ValueError too many values to unpack (expected 2)

After the fix:

=== batch_size=2 ===
SUCCESS, shape = torch.Size([2, 64, 128])

=== batch_size=3 ===
SUCCESS, shape = torch.Size([3, 64, 128])
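
The `=== batch_size=N ===`, FAILED and SUCCESS lines are not printed by the script as listed, so it was presumably run inside a small wrapper. A minimal sketch of such a wrapper, assuming a hypothetical `run_case` helper and a fake forward call in place of xLSTMModel:

```python
def run_case(batch_size, forward):
    """Run one batch size and format the outcome like the log lines above."""
    lines = [f"=== batch_size={batch_size} ==="]
    try:
        shape = forward(batch_size)
    except Exception as err:  # report any failure instead of aborting the loop
        lines.append(f"FAILED: {type(err).__name__} {err}")
    else:
        lines.append(f"SUCCESS, shape = {shape}")
    return lines


def fake_forward(batch_size):
    """Placeholder for the model call; fails for odd batch sizes only to exercise both branches."""
    if batch_size % 2:
        raise ValueError("too many values to unpack (expected 2)")
    return (batch_size, 64, 128)


report = [line for bs in (2, 3) for line in run_case(bs, fake_forward)]
print("\n".join(report))
```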

Full terminal logs below for reference.

Full log, before fix
=== batch_size=2 ===
FAILED: ValueError Got torch.Size([4, 64, 32]), expected (2, 4, 64, 32)
Traceback (most recent call last):
  File "modeling_xlstm.py", line 1481, in forward
    hidden_states, rnn_state = xlstm_block(
  File "modeling_xlstm.py", line 1198, in forward
    x_mlstm, state = self.mlstm_layer(x_mlstm, state)
  File "modeling_xlstm.py", line 1164, in forward
    raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")
ValueError: Got torch.Size([4, 64, 32]), expected (2, 4, 64, 32)

=== batch_size=3 ===
FAILED: ValueError too many values to unpack (expected 2)
Traceback (most recent call last):
  File "modeling_xlstm.py", line 1481, in forward
    hidden_states, rnn_state = xlstm_block(
  File "modeling_xlstm.py", line 1198, in forward
    x_mlstm, state = self.mlstm_layer(x_mlstm, state)
  File "modeling_xlstm.py", line 1147, in forward
    h, state = self.mlstm_backend(
ValueError: too many values to unpack (expected 2)

Full log, after fix
=== batch_size=2 ===
SUCCESS, shape = torch.Size([2, 64, 128])

=== batch_size=3 ===
SUCCESS, shape = torch.Size([3, 64, 128])

Screenshots

For extra confirmation that the logs above are unedited, raw terminal output, here are screenshots of the actual session.

Before the fix, batch_size=2 hits the shape assertion and batch_size=3 hits the unpack error.
[Screenshot: before]

After the fix, both batch sizes complete and return the expected hidden state shape.
[Screenshot: after]

Note

This keeps the fix localized to xLSTMBackend.forward. The larger output recorder refactor mentioned in the issue, which would route return_last_states through the output recorder like output_hidden_states, could be a follow-up.

cc @vasqu @ArthurZucker

github-actions bot commented Jul 2, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: xlstm

github-actions bot commented Jul 2, 2026

CI recap

Dashboard: View test results in Grafana
Latest run: 28596841571:2
Result: success | Jobs: 5 | Tests: 312 | Failures: 0 | Duration: 2m 12s

lcheng321 (Author) commented

@vasqu since you commented on #47013 with the output recorder suggestion, would you have a moment to review this when you can? This PR keeps the fix scoped to xLSTMBackend.forward, which now always returns a (h, last_states) tuple, and I confirmed it resolves the crash for both batch_size=2 and batch_size=3. No rush at all, thanks for taking a look.


Development

Successfully merging this pull request may close these issues.

[xLSTM] Crash in train mode with return_last_states=False: backend returns bare tensor but xLSTMLayer always unpacks (h, state)
