Conversation
seed_t = sync_tensor(seed_t, dim=0, group=None)
seed_t = seed_t.chunk(world_size, dim=0)[0]
seed = seed_t.item()
seed -= torch.iinfo(torch.int64).min
Bug: Incorrect seed calculation produces excessively large values
The seed calculation subtracts torch.iinfo(torch.int64).min (which equals -2^63) from the seed, effectively adding 2^63. Since torch.randint already produces non-negative values in [0, 2^63-1), this subtraction results in seed values in [2^63, 2^64-1), which are extremely large. This appears unintentional - the seed is already suitable for manual_seed() without this transformation. The unnecessary arithmetic could cause overflow issues or unexpected behavior with the random number generator.
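If that reading is correct, the fix is simply to drop the offset. A minimal sketch, assuming `sync_tensor` all-gathers one seed per rank along dim 0 as the diff suggests (`make_shared_seed` is an illustrative name, not code from the PR):

```python
import torch

def make_shared_seed(sync_tensor, world_size: int) -> int:
    # Each rank draws a candidate seed in [0, 2**63 - 1), which torch.manual_seed accepts as-is.
    seed_t = torch.randint(0, torch.iinfo(torch.int64).max, (1,), dtype=torch.int64)
    seed_t = sync_tensor(seed_t, dim=0, group=None)  # gather every rank's candidate seed
    seed_t = seed_t.chunk(world_size, dim=0)[0]      # all ranks keep rank 0's value
    return seed_t.item()                             # no `-= torch.iinfo(torch.int64).min` needed

# torch.manual_seed(make_shared_seed(sync_tensor, world_size))
```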
    torch.Tensor
        The gradient of the output tensor.
    """
    return ring_attention._scaled_dot_product_ring_flash_attention_backward(*args, **kwargs)
Bug: Incomplete backward pass missing saved tensors for gradient computation
The LocalFunc autograd function's backward method is incomplete. The forward method doesn't call ctx.save_for_backward() to save the tensors needed for gradient computation (mesh, query, key, value, output, lse). The backward method only receives gradient outputs via *args and passes them directly to _scaled_dot_product_ring_flash_attention_backward, but this function typically requires the original inputs and outputs to compute input gradients. This would cause training (backward pass) to fail with incorrect arguments or missing data.
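A hedged sketch of what a complete wrapper could look like; `ring_fwd` and `ring_bwd` are stand-ins for the experimental ring-attention forward/backward ops (their real signatures differ across torch versions), and the exact saved-tensor list is an assumption:

```python
import torch

class LocalFuncSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mesh, query, key, value):
        out, lse = ring_fwd(mesh, query, key, value)        # stand-in for the ring forward op
        ctx.mesh = mesh                                     # non-tensor state lives on ctx
        ctx.save_for_backward(query, key, value, out, lse)  # tensors the backward pass needs
        return out

    @staticmethod
    def backward(ctx, grad_out):
        query, key, value, out, lse = ctx.saved_tensors
        grad_q, grad_k, grad_v = ring_bwd(                  # stand-in for the ring backward op
            ctx.mesh, grad_out, query, key, value, out, lse
        )
        return None, grad_q, grad_k, grad_v                 # no gradient w.r.t. the mesh
```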
johannaSommer left a comment:
LGTM! There is one hook missing in smash.py that checks for this ring attention algorithm and spawns the distributed server, otherwise no notes 🌻
This PR has been inactive for 10 days and is now marked as stale.
Force-pushed from c15480e to 38ced18.
This PR has been inactive for 10 days and is now marked as stale.
Force-pushed from 38ced18 to 1ea806e.
@johannaSommer could you take a quick look? I've made the changes that you pointed out.
# perform any necessary setup steps before the smashing process begins
execute_algorithm_pre_smash_hooks(model, smash_config, algorithm_order)

# ring_attn needs a process group; if we're not already under torchrun/torch.distributed,
LGTM! Sorry for being naggy about it, but could we move this into a function somewhere in the distributed utils? I feel like the smash function is the entry point for a lot of people trying to understand the code, so I would love to keep it as lean as possible.
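One possible shape for that helper, as a hedged sketch (the function name, its location in the distributed utils, and the single-rank environment defaults are assumptions, not the PR's actual code):

```python
import os
import torch.distributed as dist

def ensure_default_process_group(backend: str = "gloo") -> None:
    """Initialize a default process group when running outside torchrun (illustrative helper)."""
    if dist.is_available() and not dist.is_initialized():
        # Fall back to a local single-rank setup when no launcher populated the
        # rendezvous environment variables.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        os.environ.setdefault("RANK", "0")
        os.environ.setdefault("WORLD_SIZE", "1")
        dist.init_process_group(backend=backend)
```

smash() could then call this helper only when ring_attn is among the selected algorithms, keeping the entry point as lean as requested.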
Description
Adding ring_attn algorithm
Type of Change
How Has This Been Tested?
I ran the tests
Checklist