Different batch sizes result in large numerical differences in transformer models #1481

@gilkeren1

Description

Describe the bug:
I was debugging an issue where inference results diverge completely while training looks good. It appears to stem from numerical differences caused by the smaller batch size used during inference.

Describe how to reproduce:
I do not have a fully reproducible example, but I see this on my end:

> embeded_maps_top.shape
torch.Size([15, 56, 4096])

> self.llama_decoder_top(
    embeded_maps_top[:1],
    padding_mask=None,
)[0][0, :4]
tensor([[-0.0197,  0.1034,  0.2976,  ...,  0.1380, -0.0656, -0.1645],
        [-0.2249, -0.4262, -0.9900,  ..., -0.2404,  0.7360,  0.1237],
        [ 1.2366,  0.4611, -1.8283,  ..., -0.4055,  0.3350,  1.5457],
        [-0.0480, -0.7974, -0.4915,  ..., -0.1564, -0.2716,  0.2660]],
       device='cuda:0', grad_fn=<SliceBackward0>)

> self.llama_decoder_top(
     embeded_maps_top,
     padding_mask=None,
)[0][0, :4]
tensor([[-0.0190,  0.1034,  0.2991,  ...,  0.1373, -0.0656, -0.1645],
        [-0.2310, -0.4302, -0.9899,  ..., -0.2414,  0.7400,  0.1247],
        [ 1.2365,  0.4610, -1.8141,  ..., -0.4020,  0.3368,  1.5527],
        [-0.0459, -0.7936, -0.4916,  ..., -0.1583, -0.2678,  0.2602]],
       device='cuda:0', grad_fn=<SliceBackward0>)

Note that the only difference between the two calls is the batch size, yet the results differ by up to 1e-2 in some places. Compounded over many calls during autoregressive inference, this leads to completely divergent results. A standalone sketch of the comparison is below.
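
The following is a minimal, self-contained sketch (not the original model) that measures the same kind of batch-size-dependent discrepancy using a stock PyTorch encoder layer in half precision. The layer, shapes, and dtype are assumptions chosen to mirror the tensors above, not taken from the actual llama_decoder_top.

import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Stand-in layer for illustration; the real model is llama_decoder_top.
layer = torch.nn.TransformerEncoderLayer(
    d_model=4096, nhead=32, batch_first=True, device=device, dtype=dtype
).eval()

x = torch.randn(15, 56, 4096, device=device, dtype=dtype)

with torch.no_grad():
    out_single = layer(x[:1])    # batch size 1
    out_batched = layer(x)[:1]   # batch size 15, same first sample

# Ideally the difference is near machine epsilon, but with half-precision
# GPU kernels it can reach ~1e-2 for some elements, as in the outputs above.
print((out_single - out_batched).abs().max().item())
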

Describe the expected behavior:
The output for a given sample should not depend meaningfully on the batch size of the surrounding call (see the two calls above).

Environment:
fs2 branch based off v0.4.6 (the mms one)
torch 2.4.0+cu121
single A100 GPU on a devserver

Additional Context:
