Add native FSDP2 module + migration #46707
Move FSDP2 wrapping and plan verification to distributed/fsdp.py, keep integrations/fsdp.py as a backward-compatible re-export, and update core call sites to import from transformers.distributed.fsdp.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
Much better, still some obvious cleanups to do!
    def expand_fsdp_plan(model, fsdp_plan: dict[str, str]) -> list[tuple[str, nn.Module, str]]:
        """Expand plan keys into ``(module_name, module, sharding_strategy)`` shard targets."""
why? {module_name: (module, sharding_strategy)} looks more manageable
There was a problem hiding this comment.
This has been reworked in a much clearer way. It now returns reshard_targets and no_reshard_targets as well.
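The return shape discussed in this thread can be sketched as follows. This is a minimal illustration, not the PR's code: `module_lookup` is a plain dict standing in for `dict(model.named_modules())`, the helper names ending in `_sketch` and the strategy strings `"reshard"` / `"no_reshard"` are assumptions, and wildcard keys are matched by replacing digit-only path segments with `*`.

```python
def replace_layer_number_by_wildcard_sketch(name: str) -> str:
    """Hypothetical helper: 'model.layers.3' -> 'model.layers.*'."""
    return ".".join("*" if part.isdigit() else part for part in name.split("."))


def expand_fsdp_plan_sketch(module_lookup: dict, fsdp_plan: dict[str, str]):
    """Split matched modules into two name-keyed dicts instead of a list of triples."""
    reshard_targets: dict[str, object] = {}
    no_reshard_targets: dict[str, object] = {}
    for module_name, module in module_lookup.items():
        # Exact keys (e.g. "lm_head") win over wildcard keys (e.g. "model.layers.*").
        strategy = fsdp_plan.get(module_name)
        if strategy is None:
            strategy = fsdp_plan.get(replace_layer_number_by_wildcard_sketch(module_name))
        if strategy is None:
            continue  # module is not covered by the plan
        # "reshard" / "no_reshard" are assumed strategy names for this sketch.
        bucket = reshard_targets if strategy == "reshard" else no_reshard_targets
        bucket[module_name] = module
    return reshard_targets, no_reshard_targets


# Stand-in for dict(model.named_modules()); real values would be nn.Module objects.
module_lookup = {name: object() for name in ["model.layers.0", "model.layers.1", "model.norm", "lm_head"]}
plan = {"model.layers.*": "reshard", "model.norm": "no_reshard", "lm_head": "no_reshard"}
reshard_targets, no_reshard_targets = expand_fsdp_plan_sketch(module_lookup, plan)
```

Keying both results by module name gives the lookup convenience the reviewer asked for, while the two-way split means the caller does not have to re-inspect the strategy string.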
    tie_word_embeddings = getattr(model.config, "tie_word_embeddings", False)

    adapted_fsdp_plan = _resolve_tied_embed_lm_head_plan(fsdp_plan, tie_word_embeddings=tie_word_embeddings)
    shard_targets = expand_fsdp_plan(model, adapted_fsdp_plan)
You could just return the targets already split into reshard / no reshard.
    if tie_word_embeddings and hasattr(model, "tie_weights"):
        model.tie_weights()
Again, I would advise against this unless it is absolutely required. Do we need all the machinery behind this? Probably not, no?
There was a problem hiding this comment.
Do you mean the guard around model.tie_weights(), or the fact that we call model.tie_weights() at all? It is needed, but I will confirm once I have added the fsdp_mixin test in the next PR.
There was a problem hiding this comment.
This is my only friction point. Do you need to tie the weights, or just pass the hooks, or something else?
In general it's bad to call it twice; we would want to apply FSDP before it is called at the main call site, or not call it at all!
There was a problem hiding this comment.
I do need to tie the weights.
But I see that it is called later in _finalize_model_loading(). My initial thought was to call it in the FSDP code, since a lot of the embedding-tying logic is handled there. Maybe a comment saying "# tie_embedding will be done in _finalize_model_loading()" would be more appropriate?
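The ordering proposed at the end of this thread can be pictured with a toy sketch. Nothing here is the library's code: `ToyModel`, `apply_fsdp_sketch`, and `finalize_model_loading_sketch` are hypothetical stand-ins, and plain lists replace tensors. It only illustrates wrapping first and tying once in the finalize step; it does not settle whether FSDP itself needs the weights tied earlier, which is the open question above.

```python
class ToyModel:
    """Toy stand-in for a model with tied input/output embeddings."""

    def __init__(self):
        self.embed_weight = [0.1, 0.2]
        self.lm_head_weight = [0.1, 0.2]  # a separate object until tied
        self.tie_calls = 0

    def tie_weights(self):
        self.tie_calls += 1
        self.lm_head_weight = self.embed_weight  # both names now share one object


def apply_fsdp_sketch(model):
    # Hypothetical wrapping step. It deliberately does NOT call model.tie_weights():
    # tying is left to the finalize step below, so it happens exactly once.
    return model


def finalize_model_loading_sketch(model):
    # Stand-in for the later step that, per the thread, already ties the weights.
    model.tie_weights()
    return model


model = finalize_model_loading_sketch(apply_fsdp_sketch(ToyModel()))
```

With this ordering the tie happens a single time, after wrapping, which is the behaviour the reviewer asks for.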
naming Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
…face/transformers into split/a-pr-2-fsdp-module
ArthurZucker
left a comment
Mostly check the tie weight please!
    for plan_key, sharding_strategy in fsdp_plan.items():
        if plan_key in module_lookup:
            # model.norm, lm_head etc.
            targets = [(plan_key, module_lookup[plan_key])]
        else:
            # model.layers.*
            targets = [
                (module_name, module)
                for module_name, module in module_lookup.items()
                if replace_layer_number_by_wildcard(module_name) == plan_key
            ]
Since you have more keys than plans, wouldn't you be better off iterating over the keys rather than the plans?
And we have a single regex computation in TP / core model loading that makes it faster (though this is probably not super slow!).
There was a problem hiding this comment.
You join all the plans on "|" with numbered capture groups!
There was a problem hiding this comment.
Fine by me. Performance-wise it's negligible and both read well, so I can go with your version, but I don't think there is a need to join all the plans on "|".
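The single-regex idea raised in this thread could look roughly like the sketch below. It rests on assumptions: the helper names are hypothetical, `*` in a plan key is taken to mean a layer index, and the strategy strings are placeholders. The TP / core model loading code mentioned in the review may be structured differently.

```python
import re


def compile_plan_regex_sketch(fsdp_plan: dict[str, str]):
    """Join all plan keys on "|" with one numbered capture group per key."""
    keys = list(fsdp_plan)
    # re.escape turns "*" into r"\*"; swap that for r"\d+" so it matches a layer index.
    alternatives = ["(" + re.escape(key).replace(r"\*", r"\d+") + ")" for key in keys]
    return keys, re.compile("|".join(alternatives))


def match_modules_sketch(module_names, fsdp_plan: dict[str, str]) -> dict[str, str]:
    """Iterate over module names once and look up the strategy of the matching plan key."""
    keys, pattern = compile_plan_regex_sketch(fsdp_plan)
    matched = {}
    for name in module_names:
        m = pattern.fullmatch(name)
        if m is not None:
            # Groups are not nested, so lastindex is the number of the alternative that matched.
            matched[name] = fsdp_plan[keys[m.lastindex - 1]]
    return matched


# Placeholder module names and strategy strings, for illustration only.
names = ["model", "model.layers.0", "model.layers.0.mlp", "model.layers.10", "model.norm", "lm_head"]
plan = {"model.layers.*": "reshard", "model.norm": "no_reshard", "lm_head": "no_reshard"}
matched = match_modules_sketch(names, plan)
```

The pattern is compiled once and each module name is tested a single time, instead of scanning every module name for every plan key. As noted in the thread, the difference is likely negligible at this scale.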
    fsdp_policy_kwargs = _get_fsdp_policy_kwargs(distributed_config)
    tie_word_embeddings = getattr(model.config, "tie_word_embeddings", False)

    adapted_fsdp_plan = _resolve_tied_embed_lm_head_plan(fsdp_plan, model)

    if tie_word_embeddings and hasattr(model, "tie_weights"):
        model.tie_weights()
This is my only friction point. Do you need to tie the weights, or just pass the hooks, or something else?
In general it's bad to call it twice; we would want to apply FSDP before it is called at the main call site, or not call it at all!
…face/transformers into split/a-pr-2-fsdp-module
CI recap
Dashboard: View test results in Grafana
#46705
applying FSDP + test_fsdp_mixin.py will be added in PR #46990