I'm trying to use your memory-efficient attention implementation (and the pure PyTorch FlashAttention implementation as well) as a drop-in replacement for Qwen3's eager attention, as in this function:
https://github.com/huggingface/transformers/blob/0e1c2817455602d182bd8ebf5fba212e14fb187e/src/transformers/models/qwen3/modeling_qwen3.py#L135
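For context, this is roughly how I'm wiring the replacement in (simplified). The signature and return convention are paraphrased from the linked file, and I've put `F.scaled_dot_product_attention` in as a stand-in at the spot where your kernel is actually called, so this snippet runs on its own:

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def replacement_attention_forward(
    module: nn.Module,
    query: torch.Tensor,   # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,     # (batch, num_heads, kv_len, head_dim), already repeated for GQA
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Stand-in for the memory-efficient / chunked kernel under test.
    # (`scale=` needs PyTorch >= 2.1.)
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    # Match the eager path's return convention: (batch, q_len, num_heads, head_dim)
    # plus the (here unused) attention weights.
    return attn_output.transpose(1, 2).contiguous(), None
```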
But I noticed that the output of your implementation differs from the vanilla eager implementation: for a single layer, the difference can be up to 1e-2.
The error accumulates across layers and eventually leads to very strange outputs. Do you have any insight into where the discrepancy comes from? When I debug with a very small input (everything fits in one chunk) and check the output at each step, the differences first appear at the softmax computation (see the sketch below).
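To isolate the softmax step, I'm comparing it against a plain `torch.softmax` reference in a standalone script. This is only my sketch of what I assume the chunked path does (a FlashAttention-style online softmax with a running max and running denominator); the chunk size, shapes, and dtype here are arbitrary, not taken from your code:

```python
import torch

torch.manual_seed(0)


def chunked_attention(q, k, v, chunk_size=8):
    """FlashAttention-style attention: process keys in chunks with a running
    max and running denominator instead of materializing the full softmax."""
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    running_max = q.new_full(q.shape[:-1] + (1,), float("-inf"))
    running_sum = q.new_zeros(q.shape[:-1] + (1,))
    for start in range(0, k.shape[-2], chunk_size):
        k_chunk = k[..., start:start + chunk_size, :]
        v_chunk = v[..., start:start + chunk_size, :]
        scores = torch.matmul(q, k_chunk.transpose(-2, -1)) * scale
        new_max = torch.maximum(running_max, scores.amax(dim=-1, keepdim=True))
        # Rescale what has been accumulated so far onto the new running max.
        correction = torch.exp(running_max - new_max)
        probs = torch.exp(scores - new_max)
        running_sum = running_sum * correction + probs.sum(dim=-1, keepdim=True)
        out = out * correction + torch.matmul(probs, v_chunk)
        running_max = new_max
    return out / running_sum


def eager_attention(q, k, v):
    """Reference: full score matrix + torch.softmax, like the eager path."""
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)


# float32 here; switching to float16/bfloat16 (what the model actually runs in)
# makes the gap noticeably larger.
q, k, v = (torch.randn(1, 8, 32, 128) for _ in range(3))
diff = (eager_attention(q, k, v) - chunked_attention(q, k, v)).abs().max()
print(f"max abs diff: {diff.item():.3e}")
```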