Describe the bug
In sd_attention_torch.py, when q_tensor (same for k and v) is generated using torch.rand, the results of fused_self_attn_for_SD_small_head_size match the expected output from cpu_golden_attn. However, when generated using torch.randn, the computed results show a significant discrepancy.
Expected Behavior
q_tensor = torch.randn((4096, 64), dtype=torch.float32).to(device=device)
k_tensor = torch.randn((4096, 64), dtype=torch.float32).to(device=device)
v_tensor = torch.randn((4096, 64), dtype=torch.float32).to(device=device)
output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)
output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor)
allclose = torch.allclose(output_torch, output_nki, atol=1e-5, rtol=1e-3)
if allclose:
    print("NKI and Torch match")
else:
    print("NKI and Torch differ")
Expected output - "NKI and Torch match"
Current Behavior
NKI and Torch differ
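To make the "significant discrepancy" measurable, a small helper along these lines could be run on output_torch and output_nki. This is only a sketch: report_mismatch and its returned keys are hypothetical names, not part of sd_attention_torch.py, and the tolerance is computed the way torch.allclose(output_torch, output_nki, ...) computes it (relative to the second argument).

```python
import torch

def report_mismatch(expected, actual, atol=1e-5, rtol=1e-3):
    """Summarize how far `actual` is from `expected` (hypothetical helper).

    Uses the same criterion as torch.allclose(expected, actual):
    |expected - actual| <= atol + rtol * |actual|
    """
    expected = expected.detach().cpu().to(torch.float32)
    actual = actual.detach().cpu().to(torch.float32)

    abs_err = (expected - actual).abs()
    tolerance = atol + rtol * actual.abs()
    violations = abs_err > tolerance

    # Index of the first element with the largest absolute error.
    worst = torch.nonzero(abs_err == abs_err.max())[0]

    return {
        "max_abs_err": abs_err.max().item(),
        "mean_abs_err": abs_err.mean().item(),
        "num_violations": int(violations.sum().item()),
        "fraction_violations": violations.float().mean().item(),
        "worst_index": tuple(int(i) for i in worst),
    }

# Synthetic demonstration only; these are not results from the NKI kernel.
demo_expected = torch.zeros(4, 4)
demo_actual = demo_expected.clone()
demo_actual[1, 2] = 1.0
demo_report = report_mismatch(demo_expected, demo_actual)
```

In the script, calling report_mismatch(output_torch, output_nki) after both outputs are computed would show whether the mismatch is a few outliers or spread across the whole output.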
Reproduction Steps
In sd_attention_torch.py, replace q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device) with q_tensor = torch.randn((4096, 64), dtype=torch.float32).to(device=device). Do the same for k_tensor and v_tensor.
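The swap described above can also be parameterized so that both cases are generated from the same seed and compared side by side. The sketch below is an assumption-laden illustration: make_qkv is a hypothetical helper, the tensors are created on the CPU, and moving them with .to(device=device) and calling the kernel are left to the original script.

```python
import torch

def make_qkv(sampler, shape=(4096, 64), seed=0):
    """Return (q, k, v) drawn from `sampler` (torch.rand or torch.randn).

    Hypothetical helper; the original script creates each tensor inline.
    """
    torch.manual_seed(seed)  # fixed seed so a failing case can be reproduced
    return tuple(sampler(shape, dtype=torch.float32) for _ in range(3))

# Case reported as matching: uniform samples in [0, 1).
q_u, k_u, v_u = make_qkv(torch.rand)

# Case reported as differing: standard normal samples (negative values included).
q_n, k_n, v_n = make_qkv(torch.randn)

# Input ranges for the two cases, useful to attach to the report.
ranges = {
    "rand": (q_u.min().item(), q_u.max().item()),
    "randn": (q_n.min().item(), q_n.max().item()),
}
```

The only difference between the two cases is the input distribution, so recording the input ranges alongside the outputs may help narrow down where the kernel and the reference diverge.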
Regression Issue
Possible Solution
No response
Additional Information/Context
No response
neuronx-cc version used
aws_neuronx_venv_pytorch_2_5_nxd_inference
Framework(s) and their versions used (JAX, PyTorch, etc..)
No response