Refactor model loading and add FSDP2 support#247
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors distributed state management and model initialization, moving shared logic for loading transformers, text encoders, and VAEs into the base Pipeline class with added support for FSDP sharding and rank-0 weight broadcasting. Review feedback identified an incorrect attribute path for sharding text encoder layers and recommended using local VAE implementations across QwenImage pipelines to ensure a necessary distributed all_reduce fix is applied.
| for layer in model.model.language_model.layers: | ||
| fully_shard(layer) |
There was a problem hiding this comment.
The attribute path model.model.language_model.layers appears to be incorrect for Qwen2_5_VLForConditionalGeneration (which is the model class used in the QwenImage pipelines). In the standard transformers implementation for Qwen2.5-VL, the layers are located at model.model.layers. Using the current path will result in an AttributeError when FSDP is enabled.
| for layer in model.model.language_model.layers: | |
| fully_shard(layer) | |
| if use_fsdp: | |
| for layer in model.model.layers: | |
| fully_shard(layer) | |
| fully_shard(model) |
| import torch | ||
| from accelerate import init_empty_weights | ||
| from diffusers.image_processor import VaeImageProcessor | ||
| from diffusers.models import AutoencoderKLQwenImage |
There was a problem hiding this comment.
AutoencoderKLQwenImage should be imported from the local models directory (diffsynth_engine.models.qwen_image) instead of diffusers.models. The local implementation contains a critical fix for all_reduce (switching from dist.all_reduce to GroupCoordinator.all_reduce) which is necessary for correct behavior in distributed environments. Importing from diffusers will bypass this fix.
| from diffusers.models import AutoencoderKLQwenImage | |
| from diffsynth_engine.models.qwen_image.autoencoder_kl_qwenimage import AutoencoderKLQwenImage |
| import torch | ||
| from accelerate import init_empty_weights | ||
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | ||
| from diffusers.models import AutoencoderKLQwenImage |
There was a problem hiding this comment.
AutoencoderKLQwenImage should be imported from diffsynth_engine.models.qwen_image.autoencoder_kl_qwenimage to ensure the local all_reduce fix is applied, consistent with the layered pipeline implementation.
| from diffusers.models import AutoencoderKLQwenImage | |
| from diffsynth_engine.models.qwen_image.autoencoder_kl_qwenimage import AutoencoderKLQwenImage |
| import torch | ||
| from accelerate import init_empty_weights | ||
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | ||
| from diffusers.models import AutoencoderKLQwenImage |
There was a problem hiding this comment.
init_transformer/init_text_encoder/init_vaefrom 4 QwenImage pipeline subclasses intoPipelinebase class, acceptingmodel_clsas parameter.fully_shard) support in base model loading: transformer blocks and text encoder layers are sharded, weights are broadcast from rank0 viaset_model_state_dict.is_*_group_initialized()helpers inparallel_state.pyto replacedist.is_initialized()checks, enabling finer-grained distributed state detection.load_model_weightsto only read files on rank0 and broadcast viatorch.distributed, reducing redundant I/O on multi-node setups.all_reduceto useGroupCoordinator.all_reduce()instead ofdist.all_reduce(tensor, group=group).DiffusionModel.dtypeproperty; clean up pipeline reference inEngine.close().