Describe the bug:
I was debugging an issue where inference results diverge completely while training looks fine. The problem appears to stem from numerical differences caused by the smaller batch size used during inference.
Describe how to reproduce:
I do not have a fully reproducible example, but I see this on my end:
> embeded_maps_top.shape
torch.Size([15, 56, 4096])
> self.llama_decoder_top(
embeded_maps_top[:1],
padding_mask=None,
)[0][0, :4]
tensor([[-0.0197, 0.1034, 0.2976, ..., 0.1380, -0.0656, -0.1645],
[-0.2249, -0.4262, -0.9900, ..., -0.2404, 0.7360, 0.1237],
[ 1.2366, 0.4611, -1.8283, ..., -0.4055, 0.3350, 1.5457],
[-0.0480, -0.7974, -0.4915, ..., -0.1564, -0.2716, 0.2660]],
device='cuda:0', grad_fn=<SliceBackward0>)
> self.llama_decoder_top(
embeded_maps_top,
padding_mask=None,
)[0][0, :4]
tensor([[-0.0190, 0.1034, 0.2991, ..., 0.1373, -0.0656, -0.1645],
[-0.2310, -0.4302, -0.9899, ..., -0.2414, 0.7400, 0.1247],
[ 1.2365, 0.4610, -1.8141, ..., -0.4020, 0.3368, 1.5527],
[-0.0459, -0.7936, -0.4916, ..., -0.1583, -0.2678, 0.2602]],
device='cuda:0', grad_fn=<SliceBackward0>)
Note that the only difference between the two calls is the batch size, yet the results differ by up to 1e-2 in some places. Compounded over multiple calls during autoregressive inference, this leads to completely divergent results.
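For context, batch-size-dependent numerical differences can appear even with a plain linear layer, because different batch sizes may dispatch different matmul kernels with different reduction orders (especially on GPU). The sketch below is a minimal, hypothetical stand-in for `llama_decoder_top` (a single `nn.Linear` rather than the actual decoder) that measures the discrepancy between a batch-of-1 and a full-batch forward pass over the same rows:

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for llama_decoder_top: a single linear layer with the
# same hidden size (4096) as in the report.
layer = torch.nn.Linear(4096, 4096)
x = torch.randn(15, 56, 4096)  # same shape as embeded_maps_top

with torch.no_grad():
    out_small = layer(x[:1])   # forward with batch size 1
    out_full = layer(x)[:1]    # forward with batch size 15, same first row

# Mathematically identical, but floating-point results can differ slightly
# when the backend picks a different kernel / reduction order per batch size.
max_diff = (out_small - out_full).abs().max().item()
print(f"max abs difference: {max_diff:.3e}")
```

On CPU the difference is typically zero or a few ULPs; on an A100 with large matmuls (and possibly TF32 enabled by default) it can be large enough to compound over autoregressive steps, as observed above.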
Describe the expected behavior:
The same input rows should produce identical (or near-identical) outputs regardless of the batch size they are processed with, so that autoregressive inference does not diverge from training behavior.
Environment:
fs2 branch that is off v0.4.6 (the mms one)
torch 2.4.0+cu121
single A100 gpu on a devserver.
Additional Context: