Skip to content

Add native FSDP2 module + migration#46707

Open
3outeille wants to merge 47 commits into
mainfrom
split/a-pr-2-fsdp-module
Open

Add native FSDP2 module + migration#46707
3outeille wants to merge 47 commits into
mainfrom
split/a-pr-2-fsdp-module

Conversation

@3outeille

@3outeille 3outeille commented Jun 17, 2026

Copy link
Copy Markdown
Member

CI

#46705

applying FSDP + test_fsdp_mixin.py will be added in PR #46990

Move FSDP2 wrapping and plan verification to distributed/fsdp.py, keep
integrations/fsdp.py as a backward-compatible re-export, and update core
call sites to import from transformers.distributed.fsdp.
@3outeille 3outeille changed the base branch from main to split/a-pr-1-distributed-config June 17, 2026 04:12
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille 3outeille mentioned this pull request Jun 17, 2026
5 tasks
@3outeille 3outeille requested a review from ArthurZucker June 23, 2026 07:42

@ArthurZucker ArthurZucker left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better, still some obvious cleanups to do!

Comment thread src/transformers/distributed/fsdp.py Outdated
Comment thread src/transformers/distributed/fsdp.py Outdated
Comment thread src/transformers/distributed/fsdp.py Outdated


def expand_fsdp_plan(model, fsdp_plan: dict[str, str]) -> list[tuple[str, nn.Module, str]]:
"""Expand plan keys into ``(module_name, module, sharding_strategy)`` shard targets."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why? {module_name: (module, sharding_strategy)} looks more manageable

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this has been rework in a much clearer way now. It now returns reshard_targets, no_reshard_targets as well

Comment thread src/transformers/distributed/fsdp.py Outdated
Comment thread src/transformers/distributed/fsdp.py Outdated
tie_word_embeddings = getattr(model.config, "tie_word_embeddings", False)

adapted_fsdp_plan = _resolve_tied_embed_lm_head_plan(fsdp_plan, tie_word_embeddings=tie_word_embeddings)
shard_targets = expand_fsdp_plan(model, adapted_fsdp_plan)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could just return the alraready reshard / no reshard

Comment thread src/transformers/distributed/fsdp.py Outdated
Comment on lines +235 to +236
if tie_word_embeddings and hasattr(model, "tie_weights"):
model.tie_weights()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again I would advise against this, unless absolutely required but do we need all the machinery behind this? prob not no?

@3outeille 3outeille Jun 24, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean the guarding around model.tie_weights() or the fact that we call model.tie_weights() ? This is needed but will confirm once I added the fsdp_mixin test on the next PR

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is my only friction point. Do you need to tie the weights, or just pass the hooks or what?
In general its bad to call it twice and we would want to FSDP before its call in the main call site, or not call!

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to tie the weights indeed.

But I see that it is called later in _finalize_model_loading(). My initial thought was to call in fsdp as a lot of stuff regarding embedding tying were performed. Maybe a comment saying #tie_embedding will be done in _finalize_model_loading() will be more appropriate ?

@3outeille 3outeille requested a review from ArthurZucker June 24, 2026 10:20
Base automatically changed from split/a-pr-1-distributed-config to main June 24, 2026 14:47

@ArthurZucker ArthurZucker left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly check the tie weight please!

Comment thread src/transformers/distributed/fsdp.py Outdated
Comment on lines +161 to +171
for plan_key, sharding_strategy in fsdp_plan.items():
if plan_key in module_lookup:
# model.norm, lm_head etc.
targets = [(plan_key, module_lookup[plan_key])]
else:
# model.layers.*
targets = [
(module_name, module)
for module_name, module in module_lookup.items()
if replace_layer_number_by_wildcard(module_name) == plan_key
]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since you have more keys than plans, you would be better off iterating the keys rather than the plans no?
And we have a single regexx omputation in tp / core model loading that makes it faster (tho this is probably not super slow!)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you join all the plans on "|" with numbered group capture !

@3outeille 3outeille Jul 3, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine by me, performance-wise it's negligeable and both read well, so I can go with your version but I dont think there is a need to join all the plans on "|"

fsdp_policy_kwargs = _get_fsdp_policy_kwargs(distributed_config)
tie_word_embeddings = getattr(model.config, "tie_word_embeddings", False)

adapted_fsdp_plan = _resolve_tied_embed_lm_head_plan(fsdp_plan, model)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much better

Comment thread src/transformers/distributed/fsdp.py Outdated
Comment on lines +235 to +236
if tie_word_embeddings and hasattr(model, "tie_weights"):
model.tie_weights()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is my only friction point. Do you need to tie the weights, or just pass the hooks or what?
In general its bad to call it twice and we would want to FSDP before its call in the main call site, or not call!

@github-actions

github-actions Bot commented Jul 3, 2026

Copy link
Copy Markdown
Contributor

CI recap

Dashboard: View test results in Grafana
Latest run: 28639742471:1
Result: success | Jobs: 14 | Tests: 47,277 | Failures: 0 | Duration: 18h 48m

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants