
FSDP v2 with DCP checkpoint #240

Open
jd-nuva wants to merge 1 commit into ByteDance-Seed:main from jd-nuva:fsdp_v2

Conversation


@jd-nuva jd-nuva commented Aug 19, 2025

Summary

The existing Bagel code initializes the model on CPU first, which materializes all tensors before sharding or moving them to GPU. As a result, on my 8xH100 machine it takes 15-20 minutes just to initialize before anything can run, which significantly slows down iteration.

This PR uses FSDP v2, which is built natively on distributed DTensors and provides finer-grained control over initialization, sharding, gradient clipping, and checkpointing.

The benefit of the FSDP v2 approach is that we can use meta/empty initialization to shard the model locally without materializing all tensors on CPU, and ensure each GPU worker reads only the tensors it is responsible for rather than an entire copy.

Perf difference

On an 8xH100 dev machine, model init time drops from ~15 minutes to ~30 seconds, and the approach naturally extends to multi-node distributed training via replication.

Next Steps

The proposed scripts and functions are tested on my local 8xH100 host and are numerically correct with stable loss (~0.31 across the first ~20 steps, gradient norm <0.1, no explosions or NaNs).

However, there is no training-script integration yet, since I've made significant modifications to the existing Bagel codebase for this; that part will need more input and guidance from the Bagel team.

NOTE: The existing script has many sharp edges around post-init and weight copying, so I found the safest path is to initialize on CPU first using the existing CPU + FSDP v1 code, but shard with FSDP v2 and save as DCP; then, for all subsequent training runs, directly load the DCP checkpoint at step 0.

@jd-nuva jd-nuva marked this pull request as ready for review August 19, 2025 19:31
@Andy1621
Collaborator

Cool! I will check it in the next couple of days when I'm free!

@Andy1621
Collaborator

@jd-nuva Looks correct!

Could you please provide the relevant training file, such as pretrain_unified_navit_fsdpv2, that calls these functions?

@jd-nuva
Author

jd-nuva commented Aug 21, 2025

@jd-nuva Looks correct!

Could you please provide the relevant training file, such as pretrain_unified_navit_fsdpv2, that calls these functions?

This is the part that requires deeper integration, because it works best only when resuming from an FSDP v2 checkpoint at step 0 in the first place.

From numerous attempts, I found the existing Bagel code has a few weight-copy / post-init operations, so a training script that directly calls these utility functions leads to subtle but severe training-quality errors. So far, the best path is:

  1. Still use the existing script to CPU-init on a target node of X GPUs
  2. Shard the model with FSDP v2
  3. Save a DCP checkpoint to disk
  4. Modify the training script to use init_empty_weights() for meta init and load the checkpoint from 3)

Alternatively, I think we can merge the utility functions first; then I can put up a separate script like convert_dcp_checkpoint.py that only needs to be executed once, followed by a modified version of pretrain_unified_navit_fsdpv2 that starts from steps 3) and 4) above.

@Andy1621
Collaborator

Cool! Can you provide these files to help me debug?

@Andy1621
Collaborator

@jd-nuva I tried the current logic, but I’m running into another issue.

Since I need to train both T2I and I2T, I have to run convert_conv2d_to_linear. However, this transformation can’t be applied when using init_empty_weights. With the current setup, the model is materialized only after wrapping with FSDP, which prevents applying convert_conv2d_to_linear beforehand.

@jd-nuva
Author

jd-nuva commented Aug 23, 2025

@jd-nuva I tried the current logic, but I’m running into another issue.

Since I need to train both T2I and I2T, I have to run convert_conv2d_to_linear. However, this transformation can’t be applied when using init_empty_weights. With the current setup, the model is materialized only after wrapping with FSDP, which prevents applying convert_conv2d_to_linear beforehand.

Hmm, I haven't run convert_conv2d_to_linear myself before. Do you mean this function works like the other post-init operations that might copy certain weights over?

The FSDP v2 logic should strictly be a replacement for the existing FSDP v1 only, and init_empty_weights works only after a correctly FSDP v2-sharded checkpoint has been saved. Even for T2I, I never had success with init_empty_weights directly using FSDP v2 either.
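If it does copy weights, that would explain the failure under meta init. A minimal illustration (assuming convert_conv2d_to_linear reshapes a conv kernel into a linear weight; the modules below are toy stand-ins): meta tensors carry only shape and dtype, so any read of their data fails.

```python
import torch
import torch.nn as nn

# A conv constructed under meta init has no underlying storage.
with torch.device("meta"):
    conv = nn.Conv2d(3, 8, kernel_size=2)

# A target linear on a real device, as a conv2d->linear conversion
# would create: out_features=8, in_features=3*2*2.
linear = nn.Linear(3 * 2 * 2, 8)

try:
    # Reshaping metadata is fine, but copying needs real data,
    # so this raises ("Cannot copy out of meta tensor").
    linear.weight.data.copy_(conv.weight.view(8, -1))
except Exception as e:
    print(f"conversion failed under meta init: {type(e).__name__}")
```

This is why the transform has to run during the one-time CPU-init bootstrap, before the DCP checkpoint is saved, rather than in the meta-init fast path.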
