From 3250a32898acc8608a21ac91f7f68f8a1ef542d2 Mon Sep 17 00:00:00 2001 From: Koratahiu Date: Sat, 27 Dec 2025 21:37:52 +0300 Subject: [PATCH 01/11] Update adv_optm version to 1.4.1 (#1229) --- requirements-global.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-global.txt b/requirements-global.txt index afa429bf2..97c2b9add 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -41,7 +41,7 @@ prodigyopt==1.1.2 # prodigy optimizer schedulefree==1.4.1 # schedule-free optimizers pytorch_optimizer==3.6.0 # pytorch optimizers prodigy-plus-schedule-free==2.0.1 # Prodigy plus optimizer -adv_optm==1.4.0 # advanced optimizers +adv_optm==1.4.1 # advanced optimizers -e git+https://github.com/KellerJordan/Muon.git@f90a42b#egg=muon-optimizer # Profiling From e5a55b23ec157305cfc732e75d76cadd84cbbbea Mon Sep 17 00:00:00 2001 From: dxqb Date: Tue, 30 Dec 2025 17:55:28 +0100 Subject: [PATCH 02/11] Flux --- modules/dataLoader/BaseDataLoader.py | 39 +- modules/dataLoader/ChromaBaseDataLoader.py | 176 ++----- modules/dataLoader/Flux2BaseDataLoader.py | 147 ++++++ modules/dataLoader/FluxBaseDataLoader.py | 196 ++------ modules/dataLoader/HiDreamBaseDataLoader.py | 236 +++------ .../dataLoader/HunyuanVideoBaseDataLoader.py | 194 ++------ .../dataLoader/PixArtAlphaBaseDataLoader.py | 171 ++----- modules/dataLoader/QwenBaseDataLoader.py | 178 ++----- modules/dataLoader/SanaBaseDataLoader.py | 170 ++----- .../StableDiffusion3BaseDataLoader.py | 213 ++------- .../StableDiffusionBaseDataLoader.py | 176 ++----- .../StableDiffusionFineTuneVaeDataLoader.py | 66 +-- .../StableDiffusionXLBaseDataLoader.py | 193 ++------ .../dataLoader/WuerstchenBaseDataLoader.py | 150 ++---- modules/dataLoader/ZImageBaseDataLoader.py | 166 ++----- .../mixin/DataLoaderText2ImageMixin.py | 139 +++++- modules/model/ChromaModel.py | 5 + modules/model/Flux2Model.py | 298 ++++++++++++ modules/model/FluxModel.py | 2 +- modules/model/WuerstchenModel.py | 3 + modules/modelLoader/Flux2ModelLoader.py | 227 +++++++++ .../GenericEmbeddingModelLoader.py | 5 + .../modelLoader/GenericFineTuneModelLoader.py | 9 + modules/modelLoader/GenericLoRAModelLoader.py | 4 + .../StableDiffusionFineTuneModelLoader.py | 2 + modules/modelSampler/ChromaSampler.py | 3 + modules/modelSampler/Flux2Sampler.py | 190 ++++++++ modules/modelSampler/FluxSampler.py | 7 +- modules/modelSampler/HiDreamSampler.py | 3 + modules/modelSampler/HunyuanVideoSampler.py | 3 + modules/modelSampler/PixArtAlphaSampler.py | 5 +- modules/modelSampler/QwenSampler.py | 3 + modules/modelSampler/SanaSampler.py | 3 + .../modelSampler/StableDiffusion3Sampler.py | 4 + .../modelSampler/StableDiffusionSampler.py | 11 +- .../modelSampler/StableDiffusionVaeSampler.py | 11 + .../modelSampler/StableDiffusionXLSampler.py | 5 +- modules/modelSampler/WuerstchenSampler.py | 4 + modules/modelSampler/ZImageSampler.py | 3 + .../modelSaver/ChromaEmbeddingModelSaver.py | 34 +- .../modelSaver/ChromaFineTuneModelSaver.py | 36 +- modules/modelSaver/ChromaLoRAModelSaver.py | 37 +- modules/modelSaver/Flux2FineTuneModelSaver.py | 11 + modules/modelSaver/Flux2LoRAModelSaver.py | 11 + modules/modelSaver/FluxEmbeddingModelSaver.py | 34 +- modules/modelSaver/FluxFineTuneModelSaver.py | 36 +- modules/modelSaver/FluxLoRAModelSaver.py | 37 +- .../modelSaver/GenericEmbeddingModelSaver.py | 46 ++ .../modelSaver/GenericFineTuneModelSaver.py | 49 ++ modules/modelSaver/GenericLoRAModelSaver.py | 50 ++ .../modelSaver/HiDreamEmbeddingModelSaver.py | 34 +- 
modules/modelSaver/HiDreamLoRAModelSaver.py | 37 +- .../HunyuanVideoEmbeddingModelSaver.py | 34 +- .../HunyuanVideoFineTuneModelSaver.py | 36 +- .../modelSaver/HunyuanVideoLoRAModelSaver.py | 38 +- .../PixArtAlphaEmbeddingModelSaver.py | 34 +- .../PixArtAlphaFineTuneModelSaver.py | 36 +- .../modelSaver/PixArtAlphaLoRAModelSaver.py | 37 +- modules/modelSaver/QwenFineTuneModelSaver.py | 34 +- modules/modelSaver/QwenLoRAModelSaver.py | 33 +- modules/modelSaver/SanaEmbeddingModelSaver.py | 34 +- modules/modelSaver/SanaFineTuneModelSaver.py | 36 +- modules/modelSaver/SanaLoRAModelSaver.py | 37 +- .../StableDiffusion3EmbeddingModelSaver.py | 34 +- .../StableDiffusion3FineTuneModelSaver.py | 36 +- .../StableDiffusion3LoRAModelSaver.py | 37 +- .../StableDiffusionEmbeddingModelSaver.py | 35 +- .../StableDiffusionFineTuneModelSaver.py | 39 +- .../StableDiffusionLoRAModelSaver.py | 38 +- .../StableDiffusionXLEmbeddingModelSaver.py | 34 +- .../StableDiffusionXLFineTuneModelSaver.py | 36 +- .../StableDiffusionXLLoRAModelSaver.py | 37 +- .../WuerstchenEmbeddingModelSaver.py | 34 +- .../WuerstchenFineTuneModelSaver.py | 36 +- .../modelSaver/WuerstchenLoRAModelSaver.py | 37 +- .../modelSaver/ZImageFineTuneModelSaver.py | 34 +- modules/modelSaver/ZImageLoRAModelSaver.py | 33 +- modules/modelSaver/flux2/Flux2LoRASaver.py | 52 ++ modules/modelSaver/flux2/Flux2ModelSaver.py | 85 ++++ modules/modelSetup/BaseChromaSetup.py | 89 +--- modules/modelSetup/BaseFlux2Setup.py | 198 ++++++++ modules/modelSetup/BaseFluxSetup.py | 88 +--- modules/modelSetup/BaseHiDreamSetup.py | 95 ++-- modules/modelSetup/BaseHunyuanVideoSetup.py | 88 +--- modules/modelSetup/BasePixArtAlphaSetup.py | 22 +- modules/modelSetup/BaseQwenSetup.py | 88 +--- modules/modelSetup/BaseSanaSetup.py | 86 +--- .../modelSetup/BaseStableDiffusion3Setup.py | 89 ++-- .../modelSetup/BaseStableDiffusionSetup.py | 20 +- .../modelSetup/BaseStableDiffusionXLSetup.py | 23 +- modules/modelSetup/BaseWuerstchenSetup.py | 30 +- modules/modelSetup/BaseZImageSetup.py | 83 +--- modules/modelSetup/ChromaEmbeddingSetup.py | 6 + modules/modelSetup/ChromaFineTuneSetup.py | 6 + modules/modelSetup/ChromaLoRASetup.py | 6 + modules/modelSetup/Flux2FineTuneSetup.py | 88 ++++ modules/modelSetup/Flux2LoRASetup.py | 101 ++++ modules/modelSetup/FluxEmbeddingSetup.py | 7 + modules/modelSetup/FluxFineTuneSetup.py | 7 + modules/modelSetup/FluxLoRASetup.py | 7 + modules/modelSetup/HiDreamEmbeddingSetup.py | 7 +- modules/modelSetup/HiDreamFineTuneSetup.py | 6 + modules/modelSetup/HiDreamLoRASetup.py | 6 + .../modelSetup/HunyuanVideoEmbeddingSetup.py | 7 +- .../modelSetup/HunyuanVideoFineTuneSetup.py | 7 +- modules/modelSetup/HunyuanVideoLoRASetup.py | 6 + .../modelSetup/PixArtAlphaEmbeddingSetup.py | 7 + .../modelSetup/PixArtAlphaFineTuneSetup.py | 7 + modules/modelSetup/PixArtAlphaLoRASetup.py | 7 + modules/modelSetup/QwenFineTuneSetup.py | 6 + modules/modelSetup/QwenLoRASetup.py | 6 + modules/modelSetup/SanaEmbeddingSetup.py | 6 + modules/modelSetup/SanaFineTuneSetup.py | 6 + modules/modelSetup/SanaLoRASetup.py | 6 + .../StableDiffusion3EmbeddingSetup.py | 7 + .../StableDiffusion3FineTuneSetup.py | 7 + .../modelSetup/StableDiffusion3LoRASetup.py | 7 + .../StableDiffusionEmbeddingSetup.py | 13 + .../StableDiffusionFineTuneSetup.py | 13 + .../StableDiffusionFineTuneVaeSetup.py | 13 + .../modelSetup/StableDiffusionLoRASetup.py | 13 + .../StableDiffusionXLEmbeddingSetup.py | 7 + .../StableDiffusionXLFineTuneSetup.py | 7 + .../modelSetup/StableDiffusionXLLoRASetup.py | 7 + 
.../modelSetup/WuerstchenEmbeddingSetup.py | 7 + modules/modelSetup/WuerstchenFineTuneSetup.py | 7 + modules/modelSetup/WuerstchenLoRASetup.py | 7 + modules/modelSetup/ZImageFineTuneSetup.py | 6 + modules/modelSetup/ZImageLoRASetup.py | 6 + .../modelSetup/mixin/ModelSetupDebugMixin.py | 11 + .../mixin/ModelSetupText2ImageMixin.py | 23 + modules/module/quantized/LinearW8A8.py | 4 +- modules/trainer/BaseTrainer.py | 3 +- modules/trainer/GenericTrainer.py | 2 +- modules/ui/ModelTab.py | 70 ++- modules/ui/TopBar.py | 3 +- modules/ui/TrainingTab.py | 68 +-- modules/util/checkpointing_util.py | 22 +- modules/util/config/SampleConfig.py | 3 + modules/util/config/TrainConfig.py | 11 +- modules/util/convert_util.py | 294 ++++++++++++ modules/util/create.py | 449 ++---------------- modules/util/enum/ModelFormat.py | 3 + modules/util/enum/ModelType.py | 11 +- modules/util/factory.py | 25 + requirements-global.txt | 4 +- .../sd_model_spec/flux_dev_2.0-lora.json | 6 + resources/sd_model_spec/flux_dev_2.0.json | 6 + 148 files changed, 3578 insertions(+), 3782 deletions(-) create mode 100644 modules/dataLoader/Flux2BaseDataLoader.py create mode 100644 modules/model/Flux2Model.py create mode 100644 modules/modelLoader/Flux2ModelLoader.py create mode 100644 modules/modelSampler/Flux2Sampler.py create mode 100644 modules/modelSaver/Flux2FineTuneModelSaver.py create mode 100644 modules/modelSaver/Flux2LoRAModelSaver.py create mode 100644 modules/modelSaver/GenericEmbeddingModelSaver.py create mode 100644 modules/modelSaver/GenericFineTuneModelSaver.py create mode 100644 modules/modelSaver/GenericLoRAModelSaver.py create mode 100644 modules/modelSaver/flux2/Flux2LoRASaver.py create mode 100644 modules/modelSaver/flux2/Flux2ModelSaver.py create mode 100644 modules/modelSetup/BaseFlux2Setup.py create mode 100644 modules/modelSetup/Flux2FineTuneSetup.py create mode 100644 modules/modelSetup/Flux2LoRASetup.py create mode 100644 modules/modelSetup/mixin/ModelSetupText2ImageMixin.py create mode 100644 modules/util/convert_util.py create mode 100644 modules/util/factory.py create mode 100644 resources/sd_model_spec/flux_dev_2.0-lora.json create mode 100644 resources/sd_model_spec/flux_dev_2.0.json diff --git a/modules/dataLoader/BaseDataLoader.py b/modules/dataLoader/BaseDataLoader.py index 0de7175db..c1b325b0d 100644 --- a/modules/dataLoader/BaseDataLoader.py +++ b/modules/dataLoader/BaseDataLoader.py @@ -1,6 +1,11 @@ +import copy from abc import ABCMeta, abstractmethod from modules.dataLoader.mixin.DataLoaderMgdsMixin import DataLoaderMgdsMixin +from modules.model.BaseModel import BaseModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util.config.TrainConfig import TrainConfig +from modules.util.TrainProgress import TrainProgress from mgds.MGDS import MGDS, TrainDataLoader @@ -16,16 +21,44 @@ def __init__( self, train_device: torch.device, temp_device: torch.device, + config: TrainConfig, + model: BaseModel, + model_setup: BaseModelSetup, + train_progress: TrainProgress, + is_validation: bool = False, ): super().__init__() self.train_device = train_device self.temp_device = temp_device - @abstractmethod + if is_validation: + config = copy.copy(config) + config.batch_size = 1 + config.multi_gpu = False + + self.__ds = self._create_dataset( + config=config, + model=model, + model_setup=model_setup, + train_progress=train_progress, + is_validation=is_validation, + ) + self.__dl = TrainDataLoader(self.__ds, config.batch_size) + def get_data_set(self) -> MGDS: - pass + return self.__ds - 
@abstractmethod def get_data_loader(self) -> TrainDataLoader: + return self.__dl + + @abstractmethod + def _create_dataset( + self, + config: TrainConfig, + model: BaseModel, + model_setup: BaseModelSetup, + train_progress: TrainProgress, + is_validation, + ): pass diff --git a/modules/dataLoader/ChromaBaseDataLoader.py b/modules/dataLoader/ChromaBaseDataLoader.py index 8af4cfbbb..580152ce5 100644 --- a/modules/dataLoader/ChromaBaseDataLoader.py +++ b/modules/dataLoader/ChromaBaseDataLoader.py @@ -1,17 +1,18 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.ChromaModel import ChromaModel +from modules.modelSetup.BaseChromaSetup import BaseChromaSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeT5Text import EncodeT5Text from mgds.pipelineModules.EncodeVAE import EncodeVAE from mgds.pipelineModules.MapData import MapData @@ -21,49 +22,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting -import torch - -#TODO share more code with Flux class ChromaBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: ChromaModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: ChromaModel): rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) @@ -71,22 +35,21 @@ def _preparation_modules(self, config: TrainConfig, model: ChromaModel): downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125) add_embeddings_to_prompt = MapData(in_name='prompt', out_name='prompt', map_fn=model.add_text_encoder_embeddings_to_prompt) tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=model.tokenizer.model_max_length, expand_mask=1) - encode_prompt = EncodeT5Text(tokens_in_name='tokens', 
tokens_attention_mask_in_name="tokens_mask", hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_autocast_context], dtype=model.text_encoder_train_dtype.torch_dtype()) + encode_prompt = EncodeT5Text(tokens_in_name='tokens', tokens_attention_mask_in_name="tokens_mask", hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, + text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_autocast_context], + dtype=model.text_encoder_train_dtype.torch_dtype()) modules = [rescale_image, encode_image, image_sample] - - modules.append(add_embeddings_to_prompt) - modules.append(tokenize_prompt) - if config.masked_training or config.model_type.has_mask_input(): modules.append(downscale_mask) + modules += [add_embeddings_to_prompt, tokenize_prompt] if not config.train_text_encoder_or_embedding(): modules.append(encode_prompt) return modules - def _cache_modules(self, config: TrainConfig, model: ChromaModel): + def _cache_modules(self, config: TrainConfig, model: ChromaModel, model_setup: BaseChromaSetup): image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] if config.masked_training or config.model_type.has_mask_input(): @@ -102,53 +65,19 @@ def _cache_modules(self, config: TrainConfig, model: ChromaModel): ] if not config.train_text_encoder_or_embedding(): - text_split_names.append('tokens') - text_split_names.append('tokens_mask') - text_split_names.append('text_encoder_hidden_state') - - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - #TODO share more code with other models - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - - if not config.train_text_encoder_or_embedding(): - model.text_encoder_to(self.train_device) - - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - 
if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + text_split_names += ['tokens', 'tokens_mask', 'text_encoder_hidden_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching = not config.train_text_encoder_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: ChromaModel): + def _output_modules(self, config: TrainConfig, model: ChromaModel, model_setup: BaseChromaSetup): output_names = [ 'image_path', 'latent_image', 'prompt', @@ -163,16 +92,10 @@ def _output_modules(self, config: TrainConfig, model: ChromaModel): if not config.train_text_encoder_or_embedding(): output_names.append('text_encoder_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=False, vae=model.vae, autocast_context=[model.autocast_context], @@ -197,57 +120,26 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: ChromaModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 64) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, 
model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, ) + +factory.register(BaseDataLoader, ChromaBaseDataLoader, ModelType.CHROMA_1) diff --git a/modules/dataLoader/Flux2BaseDataLoader.py b/modules/dataLoader/Flux2BaseDataLoader.py new file mode 100644 index 000000000..d44c7ce58 --- /dev/null +++ b/modules/dataLoader/Flux2BaseDataLoader.py @@ -0,0 +1,147 @@ +import os + +from modules.dataLoader.BaseDataLoader import BaseDataLoader +from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.Flux2Model import HIDDEN_STATES_LAYERS, SYSTEM_MESSAGE, Flux2Model +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.TrainProgress import TrainProgress + +from mgds.pipelineModules.DecodeTokens import DecodeTokens +from mgds.pipelineModules.DecodeVAE import DecodeVAE +from mgds.pipelineModules.EncodeMistralText import EncodeMistralText +from mgds.pipelineModules.EncodeVAE import EncodeVAE +from mgds.pipelineModules.RescaleImageChannels import RescaleImageChannels +from mgds.pipelineModules.SampleVAEDistribution import SampleVAEDistribution +from mgds.pipelineModules.SaveImage import SaveImage +from mgds.pipelineModules.SaveText import SaveText +from mgds.pipelineModules.ScaleImage import ScaleImage +from mgds.pipelineModules.Tokenize import Tokenize + +from diffusers.pipelines.flux2.pipeline_flux2 import format_input + + +class Flux2BaseDataLoader( #TODO share code + BaseDataLoader, + DataLoaderText2ImageMixin, +): + def _preparation_modules(self, config: TrainConfig, model: Flux2Model): + rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) + encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean') + downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125) + tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=config.text_encoder_sequence_length, + apply_chat_template = lambda caption: format_input([caption], SYSTEM_MESSAGE), apply_chat_template_kwargs = {'add_generation_prompt': False}, + ) + encode_prompt = EncodeMistralText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask', + text_encoder=model.text_encoder, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(), + hidden_state_output_index=HIDDEN_STATES_LAYERS, + ) + + modules = [rescale_image, encode_image, image_sample] + if config.masked_training or config.model_type.has_mask_input(): + modules.append(downscale_mask) + + modules += [tokenize_prompt, encode_prompt] + return modules + + def _cache_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup): + image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] + + if config.masked_training or config.model_type.has_mask_input(): + image_split_names.append('latent_mask') + + image_aggregate_names = ['crop_resolution', 'image_path'] 
+ + text_split_names = [] + + sort_names = image_aggregate_names + image_split_names + [ + 'prompt', 'tokens', 'tokens_mask', 'text_encoder_hidden_state', + 'concept' + ] + + text_split_names += ['tokens', 'tokens_mask', 'text_encoder_hidden_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=True, + ) + + def _output_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup): + output_names = [ + 'image_path', 'latent_image', + 'prompt', + 'tokens', + 'tokens_mask', + 'original_resolution', 'crop_resolution', 'crop_offset', + ] + + if config.masked_training or config.model_type.has_mask_input(): + output_names.append('latent_mask') + + output_names.append('text_encoder_hidden_state') + + return self._output_modules_from_out_names( + model, model_setup, + output_names=output_names, + config=config, + use_conditioning_image=False, + vae=model.vae, + autocast_context=[model.autocast_context], + train_dtype=model.train_dtype, + ) + + def _debug_modules(self, config: TrainConfig, model: Flux2Model): + debug_dir = os.path.join(config.debug_dir, "dataloader") + + def before_save_fun(): + model.vae_to(self.train_device) + + decode_image = DecodeVAE(in_name='latent_image', out_name='decoded_image', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + upscale_mask = ScaleImage(in_name='latent_mask', out_name='decoded_mask', factor=8) + decode_prompt = DecodeTokens(in_name='tokens', out_name='decoded_prompt', tokenizer=model.tokenizer) + save_image = SaveImage(image_in_name='decoded_image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1, before_save_fun=before_save_fun) + # SaveImage(image_in_name='latent_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1, before_save_fun=before_save_fun) + save_mask = SaveImage(image_in_name='decoded_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1, before_save_fun=before_save_fun) + save_prompt = SaveText(text_in_name='decoded_prompt', original_path_in_name='image_path', path=debug_dir, before_save_fun=before_save_fun) + + # These modules don't really work, since they are inserted after a sorting operation that does not include this data + # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), + # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), + + modules = [] + + modules.append(decode_image) + modules.append(save_image) + + if config.masked_training or config.model_type.has_mask_input(): + modules.append(upscale_mask) + modules.append(save_mask) + + modules.append(decode_prompt) + modules.append(save_prompt) + + return modules + + def _create_dataset( + self, + config: TrainConfig, + model: Flux2Model, + model_setup: BaseFlux2Setup, + train_progress: TrainProgress, + is_validation: bool = False, + ): + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, + ) + + +factory.register(BaseDataLoader, Flux2BaseDataLoader, ModelType.FLUX_DEV_2) diff --git a/modules/dataLoader/FluxBaseDataLoader.py b/modules/dataLoader/FluxBaseDataLoader.py index 
b68430478..858b94b1b 100644 --- a/modules/dataLoader/FluxBaseDataLoader.py +++ b/modules/dataLoader/FluxBaseDataLoader.py @@ -1,18 +1,17 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.flux.ShuffleFluxFillMaskChannels import ShuffleFluxFillMaskChannels from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin from modules.model.FluxModel import FluxModel +from modules.modelSetup.BaseFluxSetup import BaseFluxSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeClipText import EncodeClipText from mgds.pipelineModules.EncodeT5Text import EncodeT5Text from mgds.pipelineModules.EncodeVAE import EncodeVAE @@ -23,48 +22,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class FluxBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: FluxModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: FluxModel): rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) rescale_conditioning_image = RescaleImageChannels(image_in_name='conditioning_image', image_out_name='conditioning_image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) @@ -78,37 +41,35 @@ def _preparation_modules(self, config: TrainConfig, model: FluxModel): conditioning_image_sample = SampleVAEDistribution(in_name='latent_conditioning_image_distribution', out_name='latent_conditioning_image', mode='mean') tokenize_prompt_1 = Tokenize(in_name='prompt_1', tokens_out_name='tokens_1', mask_out_name='tokens_mask_1', tokenizer=model.tokenizer_1, max_token_length=model.tokenizer_1.model_max_length) tokenize_prompt_2 = Tokenize(in_name='prompt_2', tokens_out_name='tokens_2', mask_out_name='tokens_mask_2', tokenizer=model.tokenizer_2, max_token_length=config.text_encoder_2_sequence_length) - encode_prompt_1 = EncodeClipText(in_name='tokens_1', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_1_hidden_state', pooled_out_name='text_encoder_1_pooled_state', add_layer_norm=False, text_encoder=model.text_encoder_1, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], 
dtype=model.train_dtype.torch_dtype()) - encode_prompt_2 = EncodeT5Text(tokens_in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.text_encoder_2, hidden_state_output_index=-(1 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_2_autocast_context], dtype=model.text_encoder_2_train_dtype.torch_dtype()) + encode_prompt_1 = EncodeClipText(in_name='tokens_1', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_1_hidden_state', pooled_out_name='text_encoder_1_pooled_state', + add_layer_norm=False, text_encoder=model.text_encoder_1, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], + dtype=model.train_dtype.torch_dtype()) + encode_prompt_2 = EncodeT5Text(tokens_in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_state', pooled_out_name=None, add_layer_norm=True, + text_encoder=model.text_encoder_2, hidden_state_output_index=-(1 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_2_autocast_context], + dtype=model.text_encoder_2_train_dtype.torch_dtype()) modules = [rescale_image, encode_image, image_sample] - - if model.tokenizer_1: - modules.append(add_embeddings_to_prompt_1) - modules.append(tokenize_prompt_1) - if model.tokenizer_2: - modules.append(add_embeddings_to_prompt_2) - modules.append(tokenize_prompt_2) - if config.model_type.has_mask_input(): modules.append(shuffle_mask_channels) elif config.masked_training: modules.append(downscale_mask) if config.model_type.has_conditioning_image_input(): - modules.append(rescale_conditioning_image) - modules.append(encode_conditioning_image) - modules.append(conditioning_image_sample) + modules += [rescale_conditioning_image, encode_conditioning_image, conditioning_image_sample] + if model.tokenizer_1: + modules += [add_embeddings_to_prompt_1, tokenize_prompt_1] if not config.train_text_encoder_or_embedding() and model.text_encoder_1: modules.append(encode_prompt_1) + if model.tokenizer_2: + modules += [add_embeddings_to_prompt_2, tokenize_prompt_2] if not config.train_text_encoder_2_or_embedding() and model.text_encoder_2: modules.append(encode_prompt_2) return modules - def _cache_modules(self, config: TrainConfig, model: FluxModel): + def _cache_modules(self, config: TrainConfig, model: FluxModel, model_setup: BaseFluxSetup): image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] if config.masked_training or config.model_type.has_mask_input(): @@ -128,60 +89,22 @@ def _cache_modules(self, config: TrainConfig, model: FluxModel): ] if not config.train_text_encoder_or_embedding(): - text_split_names.append('tokens_1') - text_split_names.append('tokens_mask_1') - text_split_names.append('text_encoder_1_pooled_state') + text_split_names += ['tokens_1', 'tokens_mask_1', 'text_encoder_1_pooled_state'] if not config.train_text_encoder_2_or_embedding(): - text_split_names.append('tokens_2') - text_split_names.append('tokens_mask_2') - text_split_names.append('text_encoder_2_hidden_state') - - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - - if 
not config.train_text_encoder_or_embedding(): - model.text_encoder_1_to(self.train_device) - - if not config.train_text_encoder_2_or_embedding(): - model.text_encoder_2_to(self.train_device) - - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + text_split_names += ['tokens_2', 'tokens_mask_2', 'text_encoder_2_hidden_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: FluxModel): + def _output_modules(self, config: TrainConfig, model: FluxModel, model_setup: BaseFluxSetup): output_names = [ 'image_path', 'latent_image', 'prompt_1', 'prompt_2', @@ -202,16 +125,10 @@ def _output_modules(self, config: TrainConfig, model: FluxModel): if not config.train_text_encoder_2_or_embedding(): output_names.append('text_encoder_2_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=True, vae=model.vae, autocast_context=[model.autocast_context], @@ -238,61 +155,30 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) 
+ modules = [decode_image, save_image] if config.model_type.has_conditioning_image_input(): - modules.append(decode_conditioning_image) - modules.append(save_conditioning_image) + modules += [decode_conditioning_image, save_conditioning_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, model: FluxModel, + model_setup: BaseFluxSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 64) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, ) + +factory.register(BaseDataLoader, FluxBaseDataLoader, ModelType.FLUX_DEV_1) +factory.register(BaseDataLoader, FluxBaseDataLoader, ModelType.FLUX_FILL_DEV_1) diff --git a/modules/dataLoader/HiDreamBaseDataLoader.py b/modules/dataLoader/HiDreamBaseDataLoader.py index 226d73e89..dedea17be 100644 --- a/modules/dataLoader/HiDreamBaseDataLoader.py +++ b/modules/dataLoader/HiDreamBaseDataLoader.py @@ -1,18 +1,19 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.flux.ShuffleFluxFillMaskChannels import ShuffleFluxFillMaskChannels from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.HiDreamModel import HiDreamModel +from modules.modelSetup.BaseHiDreamSetup import BaseHiDreamSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeClipText import EncodeClipText from mgds.pipelineModules.EncodeLlamaText import EncodeLlamaText from mgds.pipelineModules.EncodeT5Text import EncodeT5Text @@ -24,47 +25,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage 
import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class HiDreamBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: HiDreamModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: HiDreamModel): rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) rescale_conditioning_image = RescaleImageChannels(image_in_name='conditioning_image', image_out_name='conditioning_image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) @@ -82,35 +48,35 @@ def _preparation_modules(self, config: TrainConfig, model: HiDreamModel): tokenize_prompt_2 = Tokenize(in_name='prompt_2', tokens_out_name='tokens_2', mask_out_name='tokens_mask_2', tokenizer=model.tokenizer_2, max_token_length=128) tokenize_prompt_3 = Tokenize(in_name='prompt_3', tokens_out_name='tokens_3', mask_out_name='tokens_mask_3', tokenizer=model.tokenizer_3, max_token_length=128) tokenize_prompt_4 = Tokenize(in_name='prompt_4', tokens_out_name='tokens_4', mask_out_name='tokens_mask_4', tokenizer=model.tokenizer_4, max_token_length=128) - encode_prompt_1 = EncodeClipText(in_name='tokens_1', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_1_hidden_state', pooled_out_name='text_encoder_1_pooled_state', add_layer_norm=False, text_encoder=model.text_encoder_1, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) - encode_prompt_2 = EncodeClipText(in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_state', pooled_out_name='text_encoder_2_pooled_state', add_layer_norm=False, text_encoder=model.text_encoder_2, hidden_state_output_index=-(2 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) - encode_prompt_3 = EncodeT5Text(tokens_in_name='tokens_3', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_3_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.text_encoder_3, hidden_state_output_index=-(1 + config.text_encoder_3_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_3_autocast_context], dtype=model.text_encoder_3_train_dtype.torch_dtype()) - encode_prompt_4 = EncodeLlamaText(tokens_name='tokens_4', tokens_attention_mask_in_name='tokens_mask_4', hidden_state_out_name='text_encoder_4_hidden_state', tokens_attention_mask_out_name='tokens_mask_4', text_encoder=model.text_encoder_4, output_all_hidden_states=True, all_hidden_state_output_indices=model.transformer.config.llama_layers, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt_1 = 
EncodeClipText(in_name='tokens_1', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_1_hidden_state', pooled_out_name='text_encoder_1_pooled_state', + add_layer_norm=False, text_encoder=model.text_encoder_1, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], + dtype=model.train_dtype.torch_dtype()) + encode_prompt_2 = EncodeClipText(in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_state', pooled_out_name='text_encoder_2_pooled_state', + add_layer_norm=False, text_encoder=model.text_encoder_2, hidden_state_output_index=-(2 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context], + dtype=model.train_dtype.torch_dtype()) + encode_prompt_3 = EncodeT5Text(tokens_in_name='tokens_3', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_3_hidden_state', pooled_out_name=None, add_layer_norm=True, + text_encoder=model.text_encoder_3, hidden_state_output_index=-(1 + config.text_encoder_3_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_3_autocast_context], + dtype=model.text_encoder_3_train_dtype.torch_dtype()) + encode_prompt_4 = EncodeLlamaText(tokens_name='tokens_4', tokens_attention_mask_in_name='tokens_mask_4', hidden_state_out_name='text_encoder_4_hidden_state', tokens_attention_mask_out_name='tokens_mask_4', text_encoder=model.text_encoder_4, + output_all_hidden_states=True, all_hidden_state_output_indices=model.transformer.config.llama_layers, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) modules = [rescale_image, encode_image, image_sample] - - if model.tokenizer_1: - modules.append(add_embeddings_to_prompt_1) - modules.append(tokenize_prompt_1) - if model.tokenizer_2: - modules.append(add_embeddings_to_prompt_2) - modules.append(tokenize_prompt_2) - if model.tokenizer_3: - modules.append(add_embeddings_to_prompt_3) - modules.append(tokenize_prompt_3) - if model.tokenizer_4: - modules.append(add_embeddings_to_prompt_4) - modules.append(tokenize_prompt_4) - if config.model_type.has_mask_input(): modules.append(shuffle_mask_channels) elif config.masked_training: modules.append(downscale_mask) if config.model_type.has_conditioning_image_input(): - modules.append(rescale_conditioning_image) - modules.append(encode_conditioning_image) - modules.append(conditioning_image_sample) + modules += [rescale_conditioning_image, encode_conditioning_image, conditioning_image_sample] + + if model.tokenizer_1: + modules += [add_embeddings_to_prompt_1, tokenize_prompt_1] + if model.tokenizer_2: + modules += [add_embeddings_to_prompt_2, tokenize_prompt_2] + if model.tokenizer_3: + modules += [add_embeddings_to_prompt_3, tokenize_prompt_3] + if model.tokenizer_4: + modules += [add_embeddings_to_prompt_4, tokenize_prompt_4] if not config.train_text_encoder_or_embedding() and model.text_encoder_1: modules.append(encode_prompt_1) @@ -126,7 +92,7 @@ def _preparation_modules(self, config: TrainConfig, model: HiDreamModel): return modules - def _cache_modules(self, config: TrainConfig, model: HiDreamModel): + def _cache_modules(self, config: TrainConfig, model: HiDreamModel, model_setup: BaseHiDreamSetup): image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] if config.masked_training or config.model_type.has_mask_input(): @@ -148,79 +114,31 @@ def _cache_modules(self, config: TrainConfig, model: HiDreamModel): ] if not config.train_text_encoder_or_embedding(): - 
text_split_names.append('tokens_1') - text_split_names.append('tokens_mask_1') - text_split_names.append('text_encoder_1_pooled_state') + text_split_names += ['tokens_1', 'tokens_mask_1', 'text_encoder_1_pooled_state'] if not config.train_text_encoder_2_or_embedding(): - text_split_names.append('tokens_2') - text_split_names.append('tokens_mask_2') - text_split_names.append('text_encoder_2_pooled_state') + text_split_names += ['tokens_2', 'tokens_mask_2', 'text_encoder_2_pooled_state'] if not config.train_text_encoder_3_or_embedding(): - text_split_names.append('tokens_3') - text_split_names.append('tokens_mask_3') - text_split_names.append('text_encoder_3_hidden_state') + text_split_names += ['tokens_3', 'tokens_mask_3', 'text_encoder_3_hidden_state'] if not config.train_text_encoder_4_or_embedding(): - text_split_names.append('tokens_4') - text_split_names.append('tokens_mask_4') - text_split_names.append('text_encoder_4_hidden_state') - - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - - if not config.train_text_encoder_or_embedding(): - model.text_encoder_1_to(self.train_device) - - if not config.train_text_encoder_2_or_embedding(): - model.text_encoder_2_to(self.train_device) - - if not config.train_text_encoder_3_or_embedding(): - model.text_encoder_3_to(self.train_device) - - if not config.train_text_encoder_4_or_embedding(): - model.text_encoder_4_to(self.train_device) - - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding() \ - or not config.train_text_encoder_2_or_embedding() \ - or not config.train_text_encoder_3_or_embedding() \ - or not config.train_text_encoder_4_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + 
text_split_names += ['tokens_4', 'tokens_mask_4', 'text_encoder_4_hidden_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding() \ + or not config.train_text_encoder_2_or_embedding() \ + or not config.train_text_encoder_3_or_embedding() \ + or not config.train_text_encoder_4_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: HiDreamModel): + def _output_modules(self, config: TrainConfig, model: HiDreamModel, model_setup: BaseHiDreamSetup): output_names = [ 'image_path', 'latent_image', 'prompt_1', 'prompt_2', 'prompt_3', 'prompt_4', @@ -247,16 +165,10 @@ def _output_modules(self, config: TrainConfig, model: HiDreamModel): if not config.train_text_encoder_4_or_embedding(): output_names.append('text_encoder_4_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=True, vae=model.vae, autocast_context=[model.autocast_context], @@ -283,61 +195,29 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.model_type.has_conditioning_image_input(): - modules.append(decode_conditioning_image) - modules.append(save_conditioning_image) + modules += [decode_conditioning_image, save_conditioning_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: HiDreamModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 64) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation + return 
DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, ) + +factory.register(BaseDataLoader, HiDreamBaseDataLoader, ModelType.HI_DREAM_FULL) diff --git a/modules/dataLoader/HunyuanVideoBaseDataLoader.py b/modules/dataLoader/HunyuanVideoBaseDataLoader.py index bfb93ffa8..78ab93620 100644 --- a/modules/dataLoader/HunyuanVideoBaseDataLoader.py +++ b/modules/dataLoader/HunyuanVideoBaseDataLoader.py @@ -1,21 +1,22 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.HunyuanVideoModel import ( DEFAULT_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE_CROP_START, HunyuanVideoModel, ) +from modules.modelSetup.BaseHunyuanVideoSetup import BaseHunyuanVideoSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeClipText import EncodeClipText from mgds.pipelineModules.EncodeLlamaText import EncodeLlamaText from mgds.pipelineModules.EncodeVAE import EncodeVAE @@ -26,48 +27,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class HunyuanVideoBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: HunyuanVideoModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: HunyuanVideoModel): rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) @@ -75,22 +40,22 @@ def _preparation_modules(self, config: TrainConfig, model: HunyuanVideoModel): downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125) add_embeddings_to_prompt_1 = MapData(in_name='prompt', out_name='prompt_1', map_fn=model.add_text_encoder_1_embeddings_to_prompt) add_embeddings_to_prompt_2 = MapData(in_name='prompt', out_name='prompt_2', map_fn=model.add_text_encoder_2_embeddings_to_prompt) - tokenize_prompt_1 = Tokenize(in_name='prompt_1', 
tokens_out_name='tokens_1', mask_out_name='tokens_mask_1', tokenizer=model.tokenizer_1, max_token_length=77, format_text=DEFAULT_PROMPT_TEMPLATE, additional_format_text_tokens=DEFAULT_PROMPT_TEMPLATE_CROP_START) + tokenize_prompt_1 = Tokenize(in_name='prompt_1', tokens_out_name='tokens_1', mask_out_name='tokens_mask_1', tokenizer=model.tokenizer_1, max_token_length=77, + format_text=DEFAULT_PROMPT_TEMPLATE, additional_format_text_tokens=DEFAULT_PROMPT_TEMPLATE_CROP_START) tokenize_prompt_2 = Tokenize(in_name='prompt_2', tokens_out_name='tokens_2', mask_out_name='tokens_mask_2', tokenizer=model.tokenizer_2, max_token_length=77) - encode_prompt_1 = EncodeLlamaText(tokens_name='tokens_1', tokens_attention_mask_in_name='tokens_mask_1', hidden_state_out_name='text_encoder_1_hidden_state', tokens_attention_mask_out_name='tokens_mask_1', text_encoder=model.text_encoder_1, hidden_state_output_index=-(1 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(), crop_start=DEFAULT_PROMPT_TEMPLATE_CROP_START) - encode_prompt_2 = EncodeClipText(in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_states', pooled_out_name='text_encoder_2_pooled_state', add_layer_norm=False, text_encoder=model.text_encoder_2, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt_1 = EncodeLlamaText(tokens_name='tokens_1', tokens_attention_mask_in_name='tokens_mask_1', hidden_state_out_name='text_encoder_1_hidden_state', tokens_attention_mask_out_name='tokens_mask_1', text_encoder=model.text_encoder_1, + hidden_state_output_index=-(1 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(), crop_start=DEFAULT_PROMPT_TEMPLATE_CROP_START) + encode_prompt_2 = EncodeClipText(in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_states', pooled_out_name='text_encoder_2_pooled_state', add_layer_norm=False, + text_encoder=model.text_encoder_2, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) modules = [rescale_image, encode_image, image_sample] + if config.masked_training: + modules.append(downscale_mask) if model.tokenizer_1: - modules.append(add_embeddings_to_prompt_1) - modules.append(tokenize_prompt_1) + modules += [add_embeddings_to_prompt_1, tokenize_prompt_1] if model.tokenizer_2: - modules.append(add_embeddings_to_prompt_2) - modules.append(tokenize_prompt_2) - - if config.masked_training: - modules.append(downscale_mask) + modules += [add_embeddings_to_prompt_2, tokenize_prompt_2] if not config.train_text_encoder_or_embedding() and model.text_encoder_1: modules.append(encode_prompt_1) @@ -100,7 +65,7 @@ def _preparation_modules(self, config: TrainConfig, model: HunyuanVideoModel): return modules - def _cache_modules(self, config: TrainConfig, model: HunyuanVideoModel): + def _cache_modules(self, config: TrainConfig, model: HunyuanVideoModel, model_setup: BaseHunyuanVideoSetup): image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] if config.masked_training or config.model_type.has_mask_input(): @@ -120,60 +85,22 @@ def _cache_modules(self, config: TrainConfig, model: HunyuanVideoModel): ] if not config.train_text_encoder_or_embedding(): - text_split_names.append('tokens_1') - 
text_split_names.append('tokens_mask_1') - text_split_names.append('text_encoder_1_hidden_state') + text_split_names += ['tokens_1', 'tokens_mask_1', 'text_encoder_1_hidden_state'] if not config.train_text_encoder_2_or_embedding(): - text_split_names.append('tokens_2') - text_split_names.append('tokens_mask_2') - text_split_names.append('text_encoder_2_pooled_state') - - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - - if not config.train_text_encoder_or_embedding(): - model.text_encoder_1_to(self.train_device) - - if not config.train_text_encoder_2_or_embedding(): - model.text_encoder_2_to(self.train_device) - - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + text_split_names += ['tokens_2', 'tokens_mask_2', 'text_encoder_2_pooled_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: HunyuanVideoModel): + def _output_modules(self, config: TrainConfig, model: HunyuanVideoModel, model_setup: BaseHunyuanVideoSetup): output_names = [ 'image_path', 'latent_image', 'prompt_1', 'prompt_2', @@ -194,16 +121,10 @@ def _output_modules(self, config: TrainConfig, model: HunyuanVideoModel): if not config.train_text_encoder_2_or_embedding(): output_names.append('text_encoder_2_pooled_state') - def 
before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=True, vae=model.vae, autocast_context=[model.autocast_context], @@ -232,9 +153,7 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) + modules = [decode_image] #FIXME https://github.com/Nerogar/OneTrainer/issues/1015 #modules.append(save_image) @@ -243,51 +162,26 @@ def before_save_fun(): # modules.append(save_conditioning_image) if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: HunyuanVideoModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config, allow_videos=True) - load_input = self._load_input_modules(config, model.train_dtype, allow_video=True) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 64, True) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, + frame_dim_enabled=True, + allow_video_files=True, + allow_video=True, ) + +factory.register(BaseDataLoader, HunyuanVideoBaseDataLoader, ModelType.HUNYUAN_VIDEO) diff --git a/modules/dataLoader/PixArtAlphaBaseDataLoader.py b/modules/dataLoader/PixArtAlphaBaseDataLoader.py index ba4636d33..f2dc37857 100644 --- a/modules/dataLoader/PixArtAlphaBaseDataLoader.py +++ b/modules/dataLoader/PixArtAlphaBaseDataLoader.py @@ -1,17 +1,18 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.PixArtAlphaModel import PixArtAlphaModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.BasePixArtAlphaSetup import BasePixArtAlphaSetup +from modules.util import factory from modules.util.config.TrainConfig import 
TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeT5Text import EncodeT5Text from mgds.pipelineModules.EncodeVAE import EncodeVAE from mgds.pipelineModules.MapData import MapData @@ -21,48 +22,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class PixArtAlphaBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: PixArtAlphaModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: PixArtAlphaModel): max_token_length = 120 # deactivated for performance reasons. most people don't need 300 tokens @@ -78,24 +43,25 @@ def _preparation_modules(self, config: TrainConfig, model: PixArtAlphaModel): encode_conditioning_image = EncodeVAE(in_name='conditioning_image', out_name='latent_conditioning_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) conditioning_image_sample = SampleVAEDistribution(in_name='latent_conditioning_image_distribution', out_name='latent_conditioning_image', mode='mean') tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=max_token_length) - encode_prompt = EncodeT5Text(tokens_in_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_autocast_context], dtype=model.text_encoder_train_dtype.torch_dtype()) + encode_prompt = EncodeT5Text(tokens_in_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, + text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_autocast_context], + dtype=model.text_encoder_train_dtype.torch_dtype()) - modules = [rescale_image, encode_image, image_sample, add_embeddings_to_prompt, tokenize_prompt] + modules = [rescale_image, encode_image, image_sample] if config.masked_training or config.model_type.has_mask_input(): modules.append(downscale_mask) if config.model_type.has_conditioning_image_input(): - 
modules.append(rescale_conditioning_image) - modules.append(encode_conditioning_image) - modules.append(conditioning_image_sample) + modules += [rescale_conditioning_image, encode_conditioning_image, conditioning_image_sample] + modules += [add_embeddings_to_prompt, tokenize_prompt] if not config.train_text_encoder_or_embedding(): modules.append(encode_prompt) return modules - def _cache_modules(self, config: TrainConfig, model: PixArtAlphaModel): + def _cache_modules(self, config: TrainConfig, model: PixArtAlphaModel, model_setup: BasePixArtAlphaSetup): image_split_names = ['latent_image'] if config.masked_training or config.model_type.has_mask_input(): @@ -112,45 +78,17 @@ def _cache_modules(self, config: TrainConfig, model: PixArtAlphaModel): 'prompt', 'concept' ] - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - model.text_encoder_to(self.train_device) - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: PixArtAlphaModel): + def _output_modules(self, config: TrainConfig, model: PixArtAlphaModel, model_setup: BasePixArtAlphaSetup): output_names = [ 'image_path', 'latent_image', 'prompt', @@ -167,16 +105,10 @@ def _output_modules(self, config: TrainConfig, model: PixArtAlphaModel): if not config.train_text_encoder_or_embedding(): 
output_names.append('text_encoder_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=True, vae=model.vae, autocast_context=[model.autocast_context], @@ -203,61 +135,30 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.model_type.has_conditioning_image_input(): - modules.append(decode_conditioning_image) - modules.append(save_conditioning_image) + modules += [decode_conditioning_image, save_conditioning_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: PixArtAlphaModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 16) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation, + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=16, ) + +factory.register(BaseDataLoader, PixArtAlphaBaseDataLoader, ModelType.PIXART_ALPHA) +factory.register(BaseDataLoader, PixArtAlphaBaseDataLoader, ModelType.PIXART_SIGMA) diff --git a/modules/dataLoader/QwenBaseDataLoader.py b/modules/dataLoader/QwenBaseDataLoader.py index 37335e5ad..5ae3b7159 100644 --- a/modules/dataLoader/QwenBaseDataLoader.py +++ b/modules/dataLoader/QwenBaseDataLoader.py @@ -1,22 +1,23 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.QwenModel import ( DEFAULT_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE_CROP_START, PROMPT_MAX_LENGTH, QwenModel, ) +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from 
modules.modelSetup.BaseQwenSetup import BaseQwenSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeQwenText import EncodeQwenText from mgds.pipelineModules.EncodeVAE import EncodeVAE from mgds.pipelineModules.RescaleImageChannels import RescaleImageChannels @@ -25,9 +26,6 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch #TODO share more code with other models @@ -35,60 +33,28 @@ class QwenBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: QwenModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: QwenModel): rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean') downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125) - tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=PROMPT_MAX_LENGTH, format_text=DEFAULT_PROMPT_TEMPLATE, additional_format_text_tokens=DEFAULT_PROMPT_TEMPLATE_CROP_START) - encode_prompt = EncodeQwenText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask', text_encoder=model.text_encoder, hidden_state_output_index=-1, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(), crop_start=DEFAULT_PROMPT_TEMPLATE_CROP_START) + tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=PROMPT_MAX_LENGTH, + format_text=DEFAULT_PROMPT_TEMPLATE, additional_format_text_tokens=DEFAULT_PROMPT_TEMPLATE_CROP_START) + encode_prompt = EncodeQwenText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask', + text_encoder=model.text_encoder, hidden_state_output_index=-1, 
autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(), crop_start=DEFAULT_PROMPT_TEMPLATE_CROP_START) modules = [rescale_image, encode_image, image_sample] - - modules.append(tokenize_prompt) - if config.masked_training or config.model_type.has_mask_input(): modules.append(downscale_mask) + modules.append(tokenize_prompt) + if not config.train_text_encoder_or_embedding(): modules.append(encode_prompt) return modules - def _cache_modules(self, config: TrainConfig, model: QwenModel): + def _cache_modules(self, config: TrainConfig, model: QwenModel, model_setup: BaseQwenSetup): image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] if config.masked_training or config.model_type.has_mask_input(): @@ -104,53 +70,19 @@ def _cache_modules(self, config: TrainConfig, model: QwenModel): ] if not config.train_text_encoder_or_embedding(): - text_split_names.append('tokens') - text_split_names.append('tokens_mask') - text_split_names.append('text_encoder_hidden_state') - - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - #TODO share more code with other models - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - - if not config.train_text_encoder_or_embedding(): - model.text_encoder_to(self.train_device) - - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + text_split_names += ['tokens', 'tokens_mask', 'text_encoder_hidden_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not 
config.train_text_encoder_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: QwenModel): + def _output_modules(self, config: TrainConfig, model: QwenModel, model_setup: BaseQwenSetup): output_names = [ 'image_path', 'latent_image', 'prompt', @@ -165,16 +97,10 @@ def _output_modules(self, config: TrainConfig, model: QwenModel): if not config.train_text_encoder_or_embedding(): output_names.append('text_encoder_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=False, vae=model.vae, autocast_context=[model.autocast_context], @@ -202,59 +128,31 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) + modules = [decode_image] #FIXME https://github.com/Nerogar/OneTrainer/issues/1015 #modules.append(save_image) if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: QwenModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config, allow_videos=False) #don't allow video files, but... - load_input = self._load_input_modules(config, model.train_dtype, allow_video=True) #...Qwen has a video-capable VAE: convert images to video dimensions - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 64) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, + allow_video_files=False, #don't allow video files, but... + allow_video=True, #...Qwen has a video-capable VAE: convert images to video dimensions #TODO the same as frame_dim_enabled? 
) + +factory.register(BaseDataLoader, QwenBaseDataLoader, ModelType.QWEN) diff --git a/modules/dataLoader/SanaBaseDataLoader.py b/modules/dataLoader/SanaBaseDataLoader.py index 4c59a8fc0..38d5c31b0 100644 --- a/modules/dataLoader/SanaBaseDataLoader.py +++ b/modules/dataLoader/SanaBaseDataLoader.py @@ -1,17 +1,18 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.SanaModel import SanaModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.BaseSanaSetup import BaseSanaSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeGemmaText import EncodeGemmaText from mgds.pipelineModules.EncodeVAE import EncodeVAE from mgds.pipelineModules.MapData import MapData @@ -20,48 +21,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class SanaBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: SanaModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: SanaModel): max_token_length = 300 @@ -72,23 +37,24 @@ def _preparation_modules(self, config: TrainConfig, model: SanaModel): add_embeddings_to_prompt = MapData(in_name='prompt', out_name='prompt', map_fn=model.add_text_encoder_embeddings_to_prompt) encode_conditioning_image = EncodeVAE(in_name='conditioning_image', out_name='latent_conditioning_image', vae=model.vae, autocast_contexts=[model.autocast_context, model.vae_autocast_context], dtype=model.train_dtype.torch_dtype()) tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=max_token_length) - encode_prompt = EncodeGemmaText(tokens_in_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', add_layer_norm=True, text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_autocast_context], dtype=model.text_encoder_train_dtype.torch_dtype()) + encode_prompt = EncodeGemmaText(tokens_in_name='tokens', tokens_attention_mask_in_name='tokens_mask', 
hidden_state_out_name='text_encoder_hidden_state', add_layer_norm=True, text_encoder=model.text_encoder, + hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_autocast_context], dtype=model.text_encoder_train_dtype.torch_dtype()) - modules = [rescale_image, encode_image, add_embeddings_to_prompt, tokenize_prompt] + modules = [rescale_image, encode_image] if config.masked_training or config.model_type.has_mask_input(): modules.append(downscale_mask) if config.model_type.has_conditioning_image_input(): - modules.append(rescale_conditioning_image) - modules.append(encode_conditioning_image) + modules += [rescale_conditioning_image, encode_conditioning_image] + modules += [add_embeddings_to_prompt, tokenize_prompt] if not config.train_text_encoder_or_embedding(): modules.append(encode_prompt) return modules - def _cache_modules(self, config: TrainConfig, model: SanaModel): + def _cache_modules(self, config: TrainConfig, model: SanaModel, model_setup: BaseSanaSetup): image_split_names = ['latent_image'] if config.masked_training or config.model_type.has_mask_input(): @@ -101,49 +67,21 @@ def _cache_modules(self, config: TrainConfig, model: SanaModel): text_split_names = ['tokens', 'tokens_mask', 'text_encoder_hidden_state'] - sort_names = text_split_names + image_aggregate_names + image_split_names +[ + sort_names = text_split_names + image_aggregate_names + image_split_names + [ 'prompt', 'concept' ] - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - model.text_encoder_to(self.train_device) - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - 
- return modules + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: SanaModel): + def _output_modules(self, config: TrainConfig, model: SanaModel, model_setup: BaseSanaSetup): output_names = [ 'image_path', 'latent_image', 'prompt', @@ -160,16 +98,10 @@ def _output_modules(self, config: TrainConfig, model: SanaModel): if not config.train_text_encoder_or_embedding(): output_names.append('text_encoder_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=True, vae=model.vae, autocast_context=[model.autocast_context], @@ -196,61 +128,29 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.model_type.has_conditioning_image_input(): - modules.append(decode_conditioning_image) - modules.append(save_conditioning_image) + modules += [decode_conditioning_image, save_conditioning_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: SanaModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 32) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation, + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=32, ) + +factory.register(BaseDataLoader, SanaBaseDataLoader, ModelType.SANA) diff --git a/modules/dataLoader/StableDiffusion3BaseDataLoader.py 
b/modules/dataLoader/StableDiffusion3BaseDataLoader.py index af8bb5350..c497a20c6 100644 --- a/modules/dataLoader/StableDiffusion3BaseDataLoader.py +++ b/modules/dataLoader/StableDiffusion3BaseDataLoader.py @@ -1,17 +1,18 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.StableDiffusion3Model import StableDiffusion3Model +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.BaseStableDiffusion3Setup import BaseStableDiffusion3Setup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeClipText import EncodeClipText from mgds.pipelineModules.EncodeT5Text import EncodeT5Text from mgds.pipelineModules.EncodeVAE import EncodeVAE @@ -22,48 +23,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class StableDiffusion3BaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: StableDiffusion3Model, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: StableDiffusion3Model): max_tokens = model.tokenizer_1.model_max_length if model.tokenizer_1 is not None else 77 @@ -80,11 +45,20 @@ def _preparation_modules(self, config: TrainConfig, model: StableDiffusion3Model tokenize_prompt_1 = Tokenize(in_name='prompt_1', tokens_out_name='tokens_1', mask_out_name='tokens_mask_1', tokenizer=model.tokenizer_1, max_token_length=max_tokens) tokenize_prompt_2 = Tokenize(in_name='prompt_2', tokens_out_name='tokens_2', mask_out_name='tokens_mask_2', tokenizer=model.tokenizer_2, max_token_length=max_tokens) tokenize_prompt_3 = Tokenize(in_name='prompt_3', tokens_out_name='tokens_3', mask_out_name='tokens_mask_3', tokenizer=model.tokenizer_3, max_token_length=max_tokens) - encode_prompt_1 = EncodeClipText(in_name='tokens_1', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_1_hidden_state', pooled_out_name='text_encoder_1_pooled_state', add_layer_norm=False, text_encoder=model.text_encoder_1, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) - encode_prompt_2 = EncodeClipText(in_name='tokens_2', 
tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_state', pooled_out_name='text_encoder_2_pooled_state', add_layer_norm=False, text_encoder=model.text_encoder_2, hidden_state_output_index=-(2 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) - encode_prompt_3 = EncodeT5Text(tokens_in_name='tokens_3', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_3_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.text_encoder_3, hidden_state_output_index=-(1 + config.text_encoder_3_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_3_autocast_context], dtype=model.text_encoder_3_train_dtype.torch_dtype()) + encode_prompt_1 = EncodeClipText(in_name='tokens_1', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_1_hidden_state', pooled_out_name='text_encoder_1_pooled_state', add_layer_norm=False, + text_encoder=model.text_encoder_1, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt_2 = EncodeClipText(in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_state', pooled_out_name='text_encoder_2_pooled_state', add_layer_norm=False, + text_encoder=model.text_encoder_2, hidden_state_output_index=-(2 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt_3 = EncodeT5Text(tokens_in_name='tokens_3', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_3_hidden_state', pooled_out_name=None, add_layer_norm=True, + text_encoder=model.text_encoder_3, hidden_state_output_index=-(1 + config.text_encoder_3_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_3_autocast_context], + dtype=model.text_encoder_3_train_dtype.torch_dtype()) modules = [rescale_image, encode_image, image_sample] + if config.masked_training or config.model_type.has_mask_input(): + modules.append(downscale_mask) + + if config.model_type.has_conditioning_image_input(): + modules += [rescale_conditioning_image, encode_conditioning_image, conditioning_image_sample] if model.tokenizer_1: modules.append(add_embeddings_to_prompt_1) @@ -96,14 +70,6 @@ def _preparation_modules(self, config: TrainConfig, model: StableDiffusion3Model modules.append(add_embeddings_to_prompt_3) modules.append(tokenize_prompt_3) - if config.masked_training or config.model_type.has_mask_input(): - modules.append(downscale_mask) - - if config.model_type.has_conditioning_image_input(): - modules.append(rescale_conditioning_image) - modules.append(encode_conditioning_image) - modules.append(conditioning_image_sample) - if not config.train_text_encoder_or_embedding() and model.text_encoder_1: modules.append(encode_prompt_1) @@ -115,7 +81,7 @@ def _preparation_modules(self, config: TrainConfig, model: StableDiffusion3Model return modules - def _cache_modules(self, config: TrainConfig, model: StableDiffusion3Model): + def _cache_modules(self, config: TrainConfig, model: StableDiffusion3Model, model_setup: BaseStableDiffusion3Setup): image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] if config.masked_training or config.model_type.has_mask_input(): @@ -136,70 +102,25 @@ def _cache_modules(self, config: TrainConfig, model: StableDiffusion3Model): ] if not 
config.train_text_encoder_or_embedding(): - text_split_names.append('tokens_1') - text_split_names.append('tokens_mask_1') - text_split_names.append('text_encoder_1_hidden_state') - text_split_names.append('text_encoder_1_pooled_state') + text_split_names += ['tokens_1', 'tokens_mask_1', 'text_encoder_1_hidden_state', 'text_encoder_1_pooled_state'] if not config.train_text_encoder_2_or_embedding(): - text_split_names.append('tokens_2') - text_split_names.append('tokens_mask_2') - text_split_names.append('text_encoder_2_hidden_state') - text_split_names.append('text_encoder_2_pooled_state') + text_split_names += ['tokens_2', 'tokens_mask_2', 'text_encoder_2_hidden_state', 'text_encoder_2_pooled_state'] if not config.train_text_encoder_3_or_embedding(): - text_split_names.append('tokens_3') - text_split_names.append('tokens_mask_3') - text_split_names.append('text_encoder_3_hidden_state') - - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - - if not config.train_text_encoder_or_embedding(): - model.text_encoder_1_to(self.train_device) - - if not config.train_text_encoder_2_or_embedding(): - model.text_encoder_2_to(self.train_device) - - if not config.train_text_encoder_3_or_embedding(): - model.text_encoder_3_to(self.train_device) - - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding() or not config.train_text_encoder_3_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + text_split_names += ['tokens_3', 'tokens_mask_3', 'text_encoder_3_hidden_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + 
image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding() or not config.train_text_encoder_3_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: StableDiffusion3Model): + def _output_modules(self, config: TrainConfig, model: StableDiffusion3Model, model_setup: BaseStableDiffusion3Setup): output_names = [ 'image_path', 'latent_image', 'prompt_1', 'prompt_2', 'prompt_3', @@ -215,26 +136,18 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusion3Model): output_names.append('latent_conditioning_image') if not config.train_text_encoder_or_embedding(): - output_names.append('text_encoder_1_hidden_state') - output_names.append('text_encoder_1_pooled_state') + output_names += ['text_encoder_1_hidden_state', 'text_encoder_1_pooled_state'] if not config.train_text_encoder_2_or_embedding(): - output_names.append('text_encoder_2_hidden_state') - output_names.append('text_encoder_2_pooled_state') + output_names += ['text_encoder_2_hidden_state', 'text_encoder_2_pooled_state'] if not config.train_text_encoder_3_or_embedding(): output_names.append('text_encoder_3_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=True, vae=model.vae, autocast_context=[model.autocast_context], @@ -261,61 +174,29 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.model_type.has_conditioning_image_input(): - modules.append(decode_conditioning_image) - modules.append(save_conditioning_image) + modules += [decode_conditioning_image, save_conditioning_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: StableDiffusion3Model, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 64) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - 
augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation, + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, ) + +factory.register(BaseDataLoader, StableDiffusion3BaseDataLoader, ModelType.STABLE_DIFFUSION_35) diff --git a/modules/dataLoader/StableDiffusionBaseDataLoader.py b/modules/dataLoader/StableDiffusionBaseDataLoader.py index 1f83586a3..781f769a6 100644 --- a/modules/dataLoader/StableDiffusionBaseDataLoader.py +++ b/modules/dataLoader/StableDiffusionBaseDataLoader.py @@ -1,17 +1,18 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.StableDiffusionModel import StableDiffusionModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.BaseStableDiffusionSetup import BaseStableDiffusionSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeClipText import EncodeClipText from mgds.pipelineModules.EncodeVAE import EncodeVAE from mgds.pipelineModules.MapData import MapData @@ -21,48 +22,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class StableDiffusionBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: StableDiffusionModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: StableDiffusionModel): rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) rescale_conditioning_image = RescaleImageChannels(image_in_name='conditioning_image', image_out_name='conditioning_image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) @@ -74,27 +39,27 @@ def _preparation_modules(self, config: TrainConfig, model: StableDiffusionModel) conditioning_image_sample = SampleVAEDistribution(in_name='latent_conditioning_image_distribution', out_name='latent_conditioning_image', 
mode='mean') downscale_depth = ScaleImage(in_name='depth', out_name='latent_depth', factor=0.125) tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=model.tokenizer.model_max_length) - encode_prompt = EncodeClipText(in_name='tokens', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt = EncodeClipText(in_name='tokens', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, + text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) - modules = [rescale_image, encode_image, image_sample, add_embeddings_to_prompt, tokenize_prompt] + modules = [rescale_image, encode_image, image_sample] if config.masked_training or config.model_type.has_mask_input(): modules.append(downscale_mask) if config.model_type.has_conditioning_image_input(): - modules.append(rescale_conditioning_image) - modules.append(encode_conditioning_image) - modules.append(conditioning_image_sample) + modules += [rescale_conditioning_image, encode_conditioning_image, conditioning_image_sample] if config.model_type.has_depth_input(): modules.append(downscale_depth) + modules += [add_embeddings_to_prompt, tokenize_prompt] if not config.train_text_encoder_or_embedding(): modules.append(encode_prompt) return modules - def _cache_modules(self, config: TrainConfig, model: StableDiffusionModel): + def _cache_modules(self, config: TrainConfig, model: StableDiffusionModel, model_setup: BaseStableDiffusionSetup): image_split_names = ['latent_image'] if config.masked_training or config.model_type.has_mask_input(): @@ -114,45 +79,17 @@ def _cache_modules(self, config: TrainConfig, model: StableDiffusionModel): 'prompt', 'concept' ] - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - model.text_encoder_to(self.train_device) - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if 
config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: StableDiffusionModel): + def _output_modules(self, config: TrainConfig, model: StableDiffusionModel, model_setup: BaseStableDiffusionSetup): output_names = [ 'image_path', 'latent_image', 'prompt', @@ -171,16 +108,10 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusionModel): if not config.train_text_encoder_or_embedding(): output_names.append('text_encoder_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=True, vae=model.vae, autocast_context=[model.autocast_context], @@ -209,61 +140,36 @@ def before_save_fun(): # SaveImage(image_in_name='depth', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.model_type.has_conditioning_image_input(): - modules.append(decode_conditioning_image) - modules.append(save_conditioning_image) + modules += [decode_conditioning_image, save_conditioning_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: StableDiffusionModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 8) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, 
model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation, + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=8, ) + +factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_15) +factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_15_INPAINTING) +factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_20) +factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_20_BASE) +factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_20_INPAINTING) +factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_20_DEPTH) +factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_21) +factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_21_BASE) diff --git a/modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py b/modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py index 11df2bb48..ed5dd32b8 100644 --- a/modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py +++ b/modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py @@ -1,15 +1,16 @@ -import copy import os import re from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.model.StableDiffusionModel import StableDiffusionModel -from modules.util import path_util +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory, path_util from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.OutputPipelineModule import OutputPipelineModule from mgds.pipelineModules.AspectBatchSorting import AspectBatchSorting from mgds.pipelineModules.AspectBucketing import AspectBucketing @@ -39,39 +40,6 @@ class StableDiffusionFineTuneVaeDataLoader(BaseDataLoader): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: StableDiffusionModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _setup_cache_device( self, model: StableDiffusionModel, @@ -222,8 +190,10 @@ def __cache_modules(self, config: TrainConfig, model: StableDiffusionModel): def before_cache_fun(): self._setup_cache_device(model, self.train_device, self.temp_device, config) - disk_cache = 
DiskCache(cache_dir=config.cache_dir, split_names=split_names, aggregate_names=aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_fun) - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') + disk_cache = DiskCache(cache_dir=config.cache_dir, split_names=split_names, aggregate_names=aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', + variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_fun) + variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], + group_enabled_in_name='concept.enabled') modules = [] @@ -278,21 +248,18 @@ def before_save_fun(): save_image = SaveImage(image_in_name='decoded_image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1, before_save_fun=before_save_fun) save_mask = SaveImage(image_in_name='latent_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1, before_save_fun=before_save_fun) - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, model: StableDiffusionModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): @@ -326,3 +293,12 @@ def create_dataset( train_progress, is_validation, ) + +factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_15, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_15_INPAINTING, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_20, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_20_BASE, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_20_INPAINTING, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_20_DEPTH, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_21, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_21_BASE, TrainingMethod.FINE_TUNE_VAE) diff --git 
a/modules/dataLoader/StableDiffusionXLBaseDataLoader.py b/modules/dataLoader/StableDiffusionXLBaseDataLoader.py index e6adade8b..ed1ad491e 100644 --- a/modules/dataLoader/StableDiffusionXLBaseDataLoader.py +++ b/modules/dataLoader/StableDiffusionXLBaseDataLoader.py @@ -1,17 +1,18 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeClipText import EncodeClipText from mgds.pipelineModules.EncodeVAE import EncodeVAE from mgds.pipelineModules.MapData import MapData @@ -21,48 +22,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class StableDiffusionXLBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: StableDiffusionXLModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: StableDiffusionXLModel): rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) rescale_conditioning_image = RescaleImageChannels(image_in_name='conditioning_image', image_out_name='conditioning_image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) @@ -75,14 +40,12 @@ def _preparation_modules(self, config: TrainConfig, model: StableDiffusionXLMode conditioning_image_sample = SampleVAEDistribution(in_name='latent_conditioning_image_distribution', out_name='latent_conditioning_image', mode='mean') tokenize_prompt_1 = Tokenize(in_name='prompt_1', tokens_out_name='tokens_1', mask_out_name='tokens_mask_1', tokenizer=model.tokenizer_1, max_token_length=model.tokenizer_1.model_max_length) tokenize_prompt_2 = Tokenize(in_name='prompt_2', tokens_out_name='tokens_2', mask_out_name='tokens_mask_2', tokenizer=model.tokenizer_2, max_token_length=model.tokenizer_2.model_max_length) - encode_prompt_1 = EncodeClipText(in_name='tokens_1', tokens_attention_mask_in_name=None, 
hidden_state_out_name='text_encoder_1_hidden_state', pooled_out_name=None, add_layer_norm=False, text_encoder=model.text_encoder_1, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) - encode_prompt_2 = EncodeClipText(in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_state', pooled_out_name='text_encoder_2_pooled_state', add_layer_norm=False, text_encoder=model.text_encoder_2, hidden_state_output_index=-(2 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt_1 = EncodeClipText(in_name='tokens_1', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_1_hidden_state', pooled_out_name=None, add_layer_norm=False, + text_encoder=model.text_encoder_1, hidden_state_output_index=-(2 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt_2 = EncodeClipText(in_name='tokens_2', tokens_attention_mask_in_name=None, hidden_state_out_name='text_encoder_2_hidden_state', pooled_out_name='text_encoder_2_pooled_state', add_layer_norm=False, + text_encoder=model.text_encoder_2, hidden_state_output_index=-(2 + config.text_encoder_2_layer_skip), autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) - modules = [ - rescale_image, encode_image, image_sample, - add_embeddings_to_prompt_1, tokenize_prompt_1, - add_embeddings_to_prompt_2, tokenize_prompt_2, - ] + modules = [rescale_image, encode_image, image_sample] if config.masked_training or config.model_type.has_mask_input(): modules.append(downscale_mask) @@ -92,6 +55,10 @@ def _preparation_modules(self, config: TrainConfig, model: StableDiffusionXLMode modules.append(encode_conditioning_image) modules.append(conditioning_image_sample) + modules += [ + add_embeddings_to_prompt_1, tokenize_prompt_1, + add_embeddings_to_prompt_2, tokenize_prompt_2, + ] if not config.train_text_encoder_or_embedding(): modules.append(encode_prompt_1) @@ -100,7 +67,7 @@ def _preparation_modules(self, config: TrainConfig, model: StableDiffusionXLMode return modules - def _cache_modules(self, config: TrainConfig, model: StableDiffusionXLModel): + def _cache_modules(self, config: TrainConfig, model: StableDiffusionXLModel, model_setup: BaseStableDiffusionXLSetup): image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] if config.masked_training or config.model_type.has_mask_input(): @@ -120,59 +87,22 @@ def _cache_modules(self, config: TrainConfig, model: StableDiffusionXLModel): ] if not config.train_text_encoder_or_embedding(): - text_split_names.append('tokens_1') - text_split_names.append('text_encoder_1_hidden_state') + text_split_names += ['tokens_1', 'text_encoder_1_hidden_state'] if not config.train_text_encoder_2_or_embedding(): - text_split_names.append('tokens_2') - text_split_names.append('text_encoder_2_hidden_state') - text_split_names.append('text_encoder_2_pooled_state') - - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - - if not config.train_text_encoder_or_embedding(): - model.text_encoder_1_to(self.train_device) - - if not 
config.train_text_encoder_2_or_embedding(): - model.text_encoder_2_to(self.train_device) - - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + text_split_names += ['tokens_2', 'text_encoder_2_hidden_state', 'text_encoder_2_pooled_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(), + ) - def _output_modules(self, config: TrainConfig, model: StableDiffusionXLModel): + def _output_modules(self, config: TrainConfig, model: StableDiffusionXLModel, model_setup: BaseStableDiffusionXLSetup): output_names = [ 'image_path', 'latent_image', 'prompt_1', 'prompt_2', @@ -190,19 +120,12 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusionXLModel): output_names.append('text_encoder_1_hidden_state') if not config.train_text_encoder_2_or_embedding(): - output_names.append('text_encoder_2_hidden_state') - output_names.append('text_encoder_2_pooled_state') - - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() + output_names += ['text_encoder_2_hidden_state', 'text_encoder_2_pooled_state'] return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=True, vae=model.vae, autocast_context=[model.autocast_context, model.vae_autocast_context], @@ -229,61 +152,29 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # 
SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.model_type.has_conditioning_image_input(): - modules.append(decode_conditioning_image) - modules.append(save_conditioning_image) + modules += [decode_conditioning_image, save_conditioning_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: StableDiffusionXLModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.vae_train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 64) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation, + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, ) +factory.register(BaseDataLoader, StableDiffusionXLBaseDataLoader, ModelType.STABLE_DIFFUSION_XL_10_BASE) +factory.register(BaseDataLoader, StableDiffusionXLBaseDataLoader, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING) diff --git a/modules/dataLoader/WuerstchenBaseDataLoader.py b/modules/dataLoader/WuerstchenBaseDataLoader.py index bc7763b5a..f5cf2f41a 100644 --- a/modules/dataLoader/WuerstchenBaseDataLoader.py +++ b/modules/dataLoader/WuerstchenBaseDataLoader.py @@ -1,17 +1,19 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin from modules.dataLoader.wuerstchen.EncodeWuerstchenEffnet import EncodeWuerstchenEffnet +from modules.model.BaseModel import BaseModel from modules.model.WuerstchenModel import WuerstchenModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.BaseWuerstchenSetup import BaseWuerstchenSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens -from mgds.pipelineModules.DiskCache import DiskCache from 
mgds.pipelineModules.EncodeClipText import EncodeClipText from mgds.pipelineModules.MapData import MapData from mgds.pipelineModules.NormalizeImageChannels import NormalizeImageChannels @@ -19,48 +21,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting - -import torch class WuerstchenBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: WuerstchenModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: WuerstchenModel): downscale_image = ScaleImage(in_name='image', out_name='image', factor=0.75) normalize_image = NormalizeImageChannels(image_in_name='image', image_out_name='image', mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) @@ -69,24 +35,24 @@ def _preparation_modules(self, config: TrainConfig, model: WuerstchenModel): add_embeddings_to_prompt = MapData(in_name='prompt', out_name='prompt', map_fn=model.add_prior_text_encoder_embeddings_to_prompt) tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.prior_tokenizer, max_token_length=model.prior_tokenizer.model_max_length) if model.model_type.is_wuerstchen_v2(): - encode_prompt = EncodeClipText(in_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.prior_text_encoder, hidden_state_output_index=-1, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt = EncodeClipText(in_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, + text_encoder=model.prior_text_encoder, hidden_state_output_index=-1, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) elif model.model_type.is_stable_cascade(): - encode_prompt = EncodeClipText(in_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', pooled_out_name='pooled_text_encoder_output', add_layer_norm=False, text_encoder=model.prior_text_encoder, hidden_state_output_index=-1, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + encode_prompt = EncodeClipText(in_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', pooled_out_name='pooled_text_encoder_output', add_layer_norm=False, + text_encoder=model.prior_text_encoder, hidden_state_output_index=-1, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) - modules = [ - downscale_image, normalize_image, encode_image, - add_embeddings_to_prompt, tokenize_prompt, - ] + modules = 
[downscale_image, normalize_image, encode_image] if config.masked_training or config.model_type.has_mask_input(): modules.append(downscale_mask) + modules += [add_embeddings_to_prompt, tokenize_prompt] if not config.train_text_encoder_or_embedding(): modules.append(encode_prompt) return modules - def _cache_modules(self, config: TrainConfig, model: WuerstchenModel): + def _cache_modules(self, config: TrainConfig, model: WuerstchenModel, model_setup: BaseWuerstchenSetup): image_split_names = [ 'latent_image', 'original_resolution', 'crop_offset', @@ -105,45 +71,24 @@ def _cache_modules(self, config: TrainConfig, model: WuerstchenModel): 'prompt', 'concept' ] - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") - def before_cache_image_fun(): model.to(self.temp_device) model.effnet_encoder_to(self.train_device) model.eval() torch_gc() - def before_cache_text_fun(): - model.to(self.temp_device) - model.prior_text_encoder_to(self.train_device) - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - if not config.train_text_encoder_or_embedding(): - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=not config.train_text_encoder_or_embedding(), + before_cache_image_fun=before_cache_image_fun + ) - def _output_modules(self, config: TrainConfig, model: WuerstchenModel): + def _output_modules(self, config: TrainConfig, model: WuerstchenModel, model_setup: BaseWuerstchenSetup): output_names = [ 'image_path', 'latent_image', 'prompt', @@ -168,6 +113,7 @@ def before_cache_image_fun(): torch_gc() return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, before_cache_image_fun=before_cache_image_fun, @@ -191,49 
+137,25 @@ def _debug_modules(self, config: TrainConfig, model: WuerstchenModel): modules = [] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: WuerstchenModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 128) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a sorting operation - ], - train_progress, - is_validation, + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=128, + supports_inpainting=False, ) + +factory.register(BaseDataLoader, WuerstchenBaseDataLoader, ModelType.WUERSTCHEN_2) +factory.register(BaseDataLoader, WuerstchenBaseDataLoader, ModelType.STABLE_CASCADE_1) diff --git a/modules/dataLoader/ZImageBaseDataLoader.py b/modules/dataLoader/ZImageBaseDataLoader.py index dccae2709..3b9450024 100644 --- a/modules/dataLoader/ZImageBaseDataLoader.py +++ b/modules/dataLoader/ZImageBaseDataLoader.py @@ -1,17 +1,18 @@ -import copy import os from modules.dataLoader.BaseDataLoader import BaseDataLoader from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.BaseModel import BaseModel from modules.model.ZImageModel import PROMPT_MAX_LENGTH, ZImageModel, format_input +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.BaseZImageSetup import BaseZImageSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig -from modules.util.torch_util import torch_gc +from modules.util.enum.ModelType import ModelType from modules.util.TrainProgress import TrainProgress -from mgds.MGDS import MGDS, TrainDataLoader from mgds.pipelineModules.DecodeTokens import DecodeTokens from mgds.pipelineModules.DecodeVAE import DecodeVAE -from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.EncodeQwenText import EncodeQwenText from mgds.pipelineModules.EncodeVAE import EncodeVAE from mgds.pipelineModules.RescaleImageChannels import RescaleImageChannels @@ -20,48 +21,12 @@ from mgds.pipelineModules.SaveText import SaveText from mgds.pipelineModules.ScaleImage import ScaleImage from mgds.pipelineModules.Tokenize import Tokenize -from mgds.pipelineModules.VariationSorting import VariationSorting -import torch - -class ZImageBaseDataLoader( #TODO share code +class 
ZImageBaseDataLoader( BaseDataLoader, DataLoaderText2ImageMixin, ): - def __init__( - self, - train_device: torch.device, - temp_device: torch.device, - config: TrainConfig, - model: ZImageModel, - train_progress: TrainProgress, - is_validation: bool = False, - ): - super().__init__( - train_device, - temp_device, - ) - - if is_validation: - config = copy.copy(config) - config.batch_size = 1 - config.multi_gpu = False - - self.__ds = self.create_dataset( - config=config, - model=model, - train_progress=train_progress, - is_validation=is_validation, - ) - self.__dl = TrainDataLoader(self.__ds, config.batch_size) - - def get_data_set(self) -> MGDS: - return self.__ds - - def get_data_loader(self) -> TrainDataLoader: - return self.__dl - def _preparation_modules(self, config: TrainConfig, model: ZImageModel): rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) @@ -77,14 +42,13 @@ def _preparation_modules(self, config: TrainConfig, model: ZImageModel): modules = [rescale_image, encode_image, image_sample] - modules.append(tokenize_prompt) - if config.masked_training or config.model_type.has_mask_input(): #FIXME correct? + if config.masked_training or config.model_type.has_mask_input(): modules.append(downscale_mask) - modules.append(encode_prompt) + modules += [tokenize_prompt, encode_prompt] return modules - def _cache_modules(self, config: TrainConfig, model: ZImageModel): + def _cache_modules(self, config: TrainConfig, model: ZImageModel, model_setup: BaseZImageSetup): image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] if config.masked_training or config.model_type.has_mask_input(): @@ -99,48 +63,19 @@ def _cache_modules(self, config: TrainConfig, model: ZImageModel): 'concept' ] - text_split_names.append('tokens') - text_split_names.append('tokens_mask') - - text_split_names.append('text_encoder_hidden_state') - - image_cache_dir = os.path.join(config.cache_dir, "image") - text_cache_dir = os.path.join(config.cache_dir, "text") + text_split_names += ['tokens', 'tokens_mask', 'text_encoder_hidden_state'] - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - - def before_cache_text_fun(): - model.to(self.temp_device) - model.text_encoder_to(self.train_device) - model.eval() - torch_gc() - - image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) - text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) - - modules = [] - - if config.latent_caching: - 
modules.append(image_disk_cache) - - if config.latent_caching: - sort_names = [x for x in sort_names if x not in image_aggregate_names] - sort_names = [x for x in sort_names if x not in image_split_names] - - modules.append(text_disk_cache) - sort_names = [x for x in sort_names if x not in text_split_names] - - if len(sort_names) > 0: - variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') - modules.append(variation_sorting) - - return modules + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=True, + ) - def _output_modules(self, config: TrainConfig, model: ZImageModel): + def _output_modules(self, config: TrainConfig, model: ZImageModel, model_setup: BaseZImageSetup): output_names = [ 'image_path', 'latent_image', 'prompt', @@ -154,16 +89,10 @@ def _output_modules(self, config: TrainConfig, model: ZImageModel): output_names.append('text_encoder_hidden_state') - def before_cache_image_fun(): - model.to(self.temp_device) - model.vae_to(self.train_device) - model.eval() - torch_gc() - return self._output_modules_from_out_names( + model, model_setup, output_names=output_names, config=config, - before_cache_image_fun=before_cache_image_fun, use_conditioning_image=False, vae=model.vae, autocast_context=[model.autocast_context], @@ -188,57 +117,26 @@ def before_save_fun(): # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), - modules = [] - - modules.append(decode_image) - modules.append(save_image) + modules = [decode_image, save_image] if config.masked_training or config.model_type.has_mask_input(): - modules.append(upscale_mask) - modules.append(save_mask) + modules += [upscale_mask, save_mask] - modules.append(decode_prompt) - modules.append(save_prompt) + modules += [decode_prompt, save_prompt] return modules - def create_dataset( + def _create_dataset( self, config: TrainConfig, - model: ZImageModel, + model: BaseModel, + model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation: bool = False, ): - enumerate_input = self._enumerate_input_modules(config) - load_input = self._load_input_modules(config, model.train_dtype) - mask_augmentation = self._mask_augmentation_modules(config) - aspect_bucketing_in = self._aspect_bucketing_in(config, 64) - crop_modules = self._crop_modules(config) - augmentation_modules = self._augmentation_modules(config) - inpainting_modules = self._inpainting_modules(config) - preparation_modules = self._preparation_modules(config, model) - cache_modules = self._cache_modules(config, model) - output_modules = self._output_modules(config, model) - - debug_modules = self._debug_modules(config, model) - - return self._create_mgds( - config, - [ - enumerate_input, - load_input, - mask_augmentation, - aspect_bucketing_in, - crop_modules, - augmentation_modules, - inpainting_modules, - preparation_modules, - cache_modules, - output_modules, - - debug_modules if config.debug_mode else None, - # inserted before output_modules, which contains a 
sorting operation - ], - train_progress, - is_validation + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, ) + +factory.register(BaseDataLoader, ZImageBaseDataLoader, ModelType.Z_IMAGE) diff --git a/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py b/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py index bc9d69d56..78b74f8ea 100644 --- a/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py +++ b/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py @@ -1,10 +1,17 @@ +import os import re +from abc import ABCMeta, abstractmethod from collections.abc import Callable import modules.util.multi_gpu_util as multi +from modules.model.BaseModel import BaseModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.mixin.ModelSetupText2ImageMixin import ModelSetupText2ImageMixin from modules.util import path_util from modules.util.config.TrainConfig import TrainConfig from modules.util.enum.DataType import DataType +from modules.util.torch_util import torch_gc +from modules.util.TrainProgress import TrainProgress from mgds.OutputPipelineModule import OutputPipelineModule from mgds.pipelineModules.AspectBatchSorting import AspectBatchSorting @@ -12,6 +19,7 @@ from mgds.pipelineModules.CalcAspect import CalcAspect from mgds.pipelineModules.CapitalizeTags import CapitalizeTags from mgds.pipelineModules.CollectPaths import CollectPaths +from mgds.pipelineModules.DiskCache import DiskCache from mgds.pipelineModules.DistributedSampler import DistributedSampler from mgds.pipelineModules.DownloadHuggingfaceDatasets import DownloadHuggingfaceDatasets from mgds.pipelineModules.DropTags import DropTags @@ -40,16 +48,14 @@ from mgds.pipelineModules.SelectRandomText import SelectRandomText from mgds.pipelineModules.ShuffleTags import ShuffleTags from mgds.pipelineModules.SingleAspectCalculation import SingleAspectCalculation +from mgds.pipelineModules.VariationSorting import VariationSorting import torch from diffusers import AutoencoderKL -class DataLoaderText2ImageMixin: - def __init__(self): - pass - +class DataLoaderText2ImageMixin(metaclass=ABCMeta): def _enumerate_input_modules(self, config: TrainConfig, allow_videos: bool = False) -> list: supported_extensions = set() supported_extensions |= path_util.supported_image_extensions() @@ -256,6 +262,8 @@ def _inpainting_modules(self, config: TrainConfig): def _output_modules_from_out_names( self, + model: BaseModel, + model_setup: ModelSetupText2ImageMixin, output_names: list[str | tuple[str, str]], config: TrainConfig, before_cache_image_fun: Callable[[], None] | None = None, @@ -264,6 +272,11 @@ def _output_modules_from_out_names( autocast_context: list[torch.autocast | None] = None, train_dtype: DataType | None = None, ): + if before_cache_image_fun is None: + def prepare_vae(): + model_setup.prepare_vae(model) + before_cache_image_fun = prepare_vae + sort_names = output_names + ['concept'] output_names = output_names + [ @@ -306,3 +319,121 @@ def _output_modules_from_out_names( modules.append(output) return modules + + def _cache_modules_from_names( + self, + model: BaseModel, + model_setup: ModelSetupText2ImageMixin, + image_split_names: list[str], + image_aggregate_names: list[str], + text_split_names: list[str], + sort_names: list[str], + config: TrainConfig, + text_caching: bool, + before_cache_image_fun: Callable[[], None] | None = None, + ): + image_cache_dir = os.path.join(config.cache_dir, "image") + 
text_cache_dir = os.path.join(config.cache_dir, "text") + + if before_cache_image_fun is None: + def prepare_vae(): + model.to(self.temp_device) + model.vae_to(self.train_device) + model.eval() + torch_gc() + before_cache_image_fun = prepare_vae + + def before_cache_text_fun(): + model_setup.prepare_text_caching(model, config) + + image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', + balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], + group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun) + + text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', + variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun) + + modules = [] + + if config.latent_caching: + modules.append(image_disk_cache) + + sort_names = [x for x in sort_names if x not in image_aggregate_names] + sort_names = [x for x in sort_names if x not in image_split_names] + + if text_caching: + modules.append(text_disk_cache) + sort_names = [x for x in sort_names if x not in text_split_names] + + if len(sort_names) > 0: + variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', + variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled') + + modules.append(variation_sorting) + + return modules + + + def _create_dataset( + self, + config: TrainConfig, + model: BaseModel, + model_setup: ModelSetupText2ImageMixin, + train_progress: TrainProgress, + is_validation: bool, + aspect_bucketing_quantization: int, + frame_dim_enabled: bool=False, + allow_video_files: bool=False, + allow_video: bool=False, #TODO workaround for Qwen - is it the same as frame_dim_enabled? 
+ supports_inpainting: bool=True, #TODO many models probably don't support inpainting, but this has been enabled in most dataloaders before refactoring, too + ): + enumerate_input = self._enumerate_input_modules(config, allow_videos=allow_video_files) + load_input = self._load_input_modules(config, model.train_dtype, allow_video=allow_video) + mask_augmentation = self._mask_augmentation_modules(config) + aspect_bucketing_in = self._aspect_bucketing_in(config, aspect_bucketing_quantization, frame_dim_enabled) + crop_modules = self._crop_modules(config) + augmentation_modules = self._augmentation_modules(config) + if supports_inpainting: + inpainting_modules = self._inpainting_modules(config) + preparation_modules = self._preparation_modules(config, model) + cache_modules = self._cache_modules(config, model, model_setup) + output_modules = self._output_modules(config, model, model_setup) + + debug_modules = self._debug_modules(config, model) + + return self._create_mgds( + config, + [ + enumerate_input, + load_input, + mask_augmentation, + aspect_bucketing_in, + crop_modules, + augmentation_modules + ] + ([inpainting_modules] if supports_inpainting else []) + [ + preparation_modules, + cache_modules, + output_modules, + + debug_modules if config.debug_mode else None, + # inserted before output_modules, which contains a sorting operation + ], + train_progress, + is_validation + ) + + @abstractmethod + def _preparation_modules(self, config: TrainConfig, model: BaseModel): + pass + + @abstractmethod + def _cache_modules(self, config: TrainConfig, model: BaseModel, model_setup: BaseModelSetup): + pass + + @abstractmethod + def _output_modules(self, config: TrainConfig, model: BaseModel, model_setup: BaseModelSetup): + pass + + @abstractmethod + def _debug_modules(self, config: TrainConfig, model: BaseModel): + pass diff --git a/modules/model/ChromaModel.py b/modules/model/ChromaModel.py index d62a8d21b..6de84887e 100644 --- a/modules/model/ChromaModel.py +++ b/modules/model/ChromaModel.py @@ -268,3 +268,8 @@ def unpack_latents(self, latents, height: int, width: int): latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) return latents + + +from modules.util import factory + +factory.register(BaseModel, ChromaModel, ModelType.CHROMA_1) diff --git a/modules/model/Flux2Model.py b/modules/model/Flux2Model.py new file mode 100644 index 000000000..83c58403b --- /dev/null +++ b/modules/model/Flux2Model.py @@ -0,0 +1,298 @@ +import math +from contextlib import nullcontext +from random import Random + +from modules.model.BaseModel import BaseModel +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util.convert_util import add_prefix, lora_qkv_fusion, qkv_fusion, remove_prefix, swap_chunks +from modules.util.enum.ModelType import ModelType +from modules.util.LayerOffloadConductor import LayerOffloadConductor + +import torch +from torch import Tensor + +from diffusers import ( + AutoencoderKLFlux2, + DiffusionPipeline, + FlowMatchEulerDiscreteScheduler, + Flux2Pipeline, + Flux2Transformer2DModel, +) +from diffusers.pipelines.flux2.pipeline_flux2 import format_input +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." 
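+# The hidden states from these intermediate Mistral decoder layers are stacked in encode_text()
+# below and concatenated per token along the feature dimension, giving a conditioning tensor of
+# shape (batch, seq_len, len(HIDDEN_STATES_LAYERS) * hidden_dim).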
+HIDDEN_STATES_LAYERS = [10, 20, 30] + +def diffusers_to_original(qkv_fusion): + return [ + ("context_embedder", "txt_in"), + ("x_embedder", "img_in"), + ("time_guidance_embed.timestep_embedder", "time_in", [ + ("linear_1", "in_layer"), + ("linear_2", "out_layer"), + ]), + ("time_guidance_embed.guidance_embedder", "guidance_in", [ + ("linear_1", "in_layer"), + ("linear_2", "out_layer"), + ]), + ("double_stream_modulation_img.linear", "double_stream_modulation_img.lin"), + ("double_stream_modulation_txt.linear", "double_stream_modulation_txt.lin"), + ("single_stream_modulation.linear", "single_stream_modulation.lin"), + ("proj_out", "final_layer.linear"), + ("norm_out.linear", "final_layer.adaLN_modulation.1", swap_chunks, swap_chunks), + ("transformer_blocks.{i}", "double_blocks.{i}", + qkv_fusion("attn.to_q", "attn.to_k", "attn.to_v", "img_attn.qkv") + \ + qkv_fusion("attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj", "txt_attn.qkv") + [ + ("attn.norm_k.weight", "img_attn.norm.key_norm.scale"), + ("attn.norm_q.weight", "img_attn.norm.query_norm.scale"), + ("attn.to_out.0", "img_attn.proj"), + ("ff.linear_in", "img_mlp.0"), + ("ff.linear_out", "img_mlp.2"), + ("attn.norm_added_k.weight", "txt_attn.norm.key_norm.scale"), + ("attn.norm_added_q.weight", "txt_attn.norm.query_norm.scale"), + ("attn.to_add_out", "txt_attn.proj"), + ("ff_context.linear_in", "txt_mlp.0"), + ("ff_context.linear_out", "txt_mlp.2"), + ]), + ("single_transformer_blocks.{i}", "single_blocks.{i}", [ + ("attn.to_qkv_mlp_proj", "linear1"), + ("attn.to_out", "linear2"), + ("attn.norm_k.weight", "norm.key_norm.scale"), + ("attn.norm_q.weight", "norm.query_norm.scale"), + ]), + ] + +diffusers_lora_to_original = diffusers_to_original(lora_qkv_fusion) +diffusers_checkpoint_to_original = diffusers_to_original(qkv_fusion) +diffusers_lora_to_comfy = [remove_prefix("transformer"), diffusers_to_original(lora_qkv_fusion), add_prefix("diffusion_model")] + + +class Flux2Model(BaseModel): + # base model data + tokenizer: AutoProcessor | None + noise_scheduler: FlowMatchEulerDiscreteScheduler | None + text_encoder: Mistral3ForConditionalGeneration | None + vae: AutoencoderKLFlux2 | None + transformer: Flux2Transformer2DModel | None + + # autocast context + text_encoder_autocast_context: torch.autocast | nullcontext + + text_encoder_offload_conductor: LayerOffloadConductor | None + transformer_offload_conductor: LayerOffloadConductor | None + + transformer_lora: LoRAModuleWrapper | None + lora_state_dict: dict | None + + def __init__( + self, + model_type: ModelType, + ): + super().__init__( + model_type=model_type, + ) + + self.tokenizer = None + self.noise_scheduler = None + self.text_encoder = None + self.vae = None + self.transformer = None + + self.text_encoder_autocast_context = nullcontext() + + self.text_encoder_offload_conductor = None + self.transformer_offload_conductor = None + + self.transformer_lora = None + self.lora_state_dict = None + + def adapters(self) -> list[LoRAModuleWrapper]: + return [a for a in [ + self.transformer_lora, + ] if a is not None] + + def vae_to(self, device: torch.device): + self.vae.to(device=device) + + def text_encoder_to(self, device: torch.device): + if self.text_encoder is not None: + if self.text_encoder_offload_conductor is not None and \ + self.text_encoder_offload_conductor.layer_offload_activated(): + self.text_encoder_offload_conductor.to(device) + else: + self.text_encoder.to(device=device) + + def transformer_to(self, device: torch.device): + if self.transformer_offload_conductor 
is not None and \ + self.transformer_offload_conductor.layer_offload_activated(): + self.transformer_offload_conductor.to(device) + else: + self.transformer.to(device=device) + + if self.transformer_lora is not None: + self.transformer_lora.to(device) + + def to(self, device: torch.device): + self.vae_to(device) + self.text_encoder_to(device) + self.transformer_to(device) + + def eval(self): + self.vae.eval() + if self.text_encoder is not None: + self.text_encoder.eval() + self.transformer.eval() + + def create_pipeline(self) -> DiffusionPipeline: + return Flux2Pipeline( + transformer=self.transformer, + scheduler=self.noise_scheduler, + vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + ) + + def encode_text( + self, + train_device: torch.device, + batch_size: int = 1, + rand: Random | None = None, + text: str = None, + tokens: Tensor = None, + tokens_mask: Tensor = None, + text_encoder_sequence_length: int | None = None, + text_encoder_dropout_probability: float | None = None, + text_encoder_output: Tensor = None, + ) -> Tensor: + if tokens is None and text is not None: + if isinstance(text, str): + text = [text] + + messages = format_input(prompts=text, system_message=SYSTEM_MESSAGE) + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + + tokenizer_output = self.tokenizer( + text, + max_length=text_encoder_sequence_length, #max length includes the system message + padding='max_length', + truncation=True, + return_tensors="pt" + ) + tokens = tokenizer_output.input_ids.to(self.text_encoder.device) + tokens_mask = tokenizer_output.attention_mask.to(self.text_encoder.device) + + if text_encoder_output is None and self.text_encoder is not None: + with self.text_encoder_autocast_context: + text_encoder_output = self.text_encoder( + tokens, + attention_mask=tokens_mask.float(), + output_hidden_states=True, + use_cache=False, + ) + + text_encoder_output = torch.stack([text_encoder_output.hidden_states[k] for k in HIDDEN_STATES_LAYERS], dim=1) + batch_size, num_channels, seq_len, hidden_dim = text_encoder_output.shape + assert seq_len == text_encoder_sequence_length + text_encoder_output = text_encoder_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + if text_encoder_dropout_probability is not None and text_encoder_dropout_probability > 0.0: + raise NotImplementedError #https://github.com/Nerogar/OneTrainer/issues/957 + + return text_encoder_output + + + #code adapted from https://github.com/huggingface/diffusers/blob/c8656ed73c638e51fc2e777a5fd355d69fa5220f/src/diffusers/pipelines/flux2/pipeline_flux2.py + @staticmethod + def prepare_latent_image_ids(latents: torch.Tensor) -> torch.Tensor: + batch_size, _, height, width = latents.shape + + t = torch.arange(1, device=latents.device) + h = torch.arange(height, device=latents.device) + w = torch.arange(width, device=latents.device) + l_ = torch.arange(1, device=latents.device) + + latent_ids = torch.cartesian_prod(t, h, w, l_) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + #packing and unpacking on patchified latents + @staticmethod + def pack_latents(latents) -> Tensor: + batch_size, num_channels, height, width = latents.shape + return latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + @staticmethod + def unpack_latents(latents, height: int, width: int) -> Tensor: + batch_size, seq_len, num_channels = latents.shape + return
latents.reshape(batch_size, height, width, num_channels).permute(0, 3, 1, 2) + + #TODO inference code uses an empirical mu, but that cannot be used during training because it depends on the number of inference steps + # is dynamic timestep shifting during training still applicable? + #takes the unpatchified latent height and width + def calculate_timestep_shift(self, latent_height: int, latent_width: int) -> float: + base_seq_len = self.noise_scheduler.config.base_image_seq_len + max_seq_len = self.noise_scheduler.config.max_image_seq_len + base_shift = self.noise_scheduler.config.base_shift + max_shift = self.noise_scheduler.config.max_shift + patch_size = 2 + + image_seq_len = (latent_width // patch_size) * (latent_height // patch_size) + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return math.exp(mu) + + @staticmethod + def prepare_text_ids(x: torch.Tensor) -> torch.Tensor: + B, L, _ = x.shape + out_ids = [] + + for _ in range(B): #TODO why iterate? can text ids have different lengths? according to diffusers and the original inference code: no + t = torch.arange(1, device=x.device) + h = torch.arange(1, device=x.device) + w = torch.arange(1, device=x.device) + l_ = torch.arange(L, device=x.device) + + coords = torch.cartesian_prod(t, h, w, l_) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + def patchify_latents(latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + #scaling on patchified latents + def scale_latents(self, latents: Tensor) -> Tensor: + #TODO moves the stats to the device on every call - is that necessary? cache them in the model?
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + return (latents - latents_bn_mean) / latents_bn_std + + + def unscale_latents(self, latents: Tensor) -> Tensor: + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + return latents * latents_bn_std + latents_bn_mean diff --git a/modules/model/FluxModel.py b/modules/model/FluxModel.py index c9ec81c88..4b85f1272 100644 --- a/modules/model/FluxModel.py +++ b/modules/model/FluxModel.py @@ -341,7 +341,7 @@ def unpack_latents(self, latents, height: int, width: int): return latents - def calculate_timestep_shift(self, latent_width: int, latent_height: int): + def calculate_timestep_shift(self, latent_height: int, latent_width: int): base_seq_len = self.noise_scheduler.config.base_image_seq_len max_seq_len = self.noise_scheduler.config.max_image_seq_len base_shift = self.noise_scheduler.config.base_shift diff --git a/modules/model/WuerstchenModel.py b/modules/model/WuerstchenModel.py index 1404663e0..b470cf019 100644 --- a/modules/model/WuerstchenModel.py +++ b/modules/model/WuerstchenModel.py @@ -170,6 +170,9 @@ def to(self, device: torch.device): self.prior_text_encoder_to(device) self.prior_prior_to(device) + def vae_to(self, device: torch.device): + raise NotImplementedError + def eval(self): if self.model_type.is_wuerstchen_v2(): self.decoder_text_encoder.eval() diff --git a/modules/modelLoader/Flux2ModelLoader.py b/modules/modelLoader/Flux2ModelLoader.py new file mode 100644 index 000000000..6329b1f95 --- /dev/null +++ b/modules/modelLoader/Flux2ModelLoader.py @@ -0,0 +1,227 @@ +import os +import traceback + +from modules.model.BaseModel import BaseModel +from modules.model.Flux2Model import Flux2Model +from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader +from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader +from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin +from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin +from modules.util.config.TrainConfig import QuantizationConfig + +#from omi_model_standards.convert.lora.convert_flux_lora import convert_flux_lora_key_sets #TODO +from modules.util.convert.lora.convert_lora_util import LoraConversionKeySet +from modules.util.enum.ModelType import ModelType +from modules.util.ModelNames import ModelNames +from modules.util.ModelWeightDtypes import ModelWeightDtypes + +import torch + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + Flux2Transformer2DModel, + GGUFQuantizationConfig, +) +from transformers import ( + Mistral3ForConditionalGeneration, + PixtralProcessor, +) + + +class Flux2ModelLoader( + HFModelLoaderMixin, +): + def __init__(self): + super().__init__() + + def __load_internal( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + if os.path.isfile(os.path.join(base_model_name, "meta.json")): + self.__load_diffusers( + model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quantization, 
+ ) + else: + raise Exception("not an internal model") + + def __load_diffusers( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + diffusers_sub = [] + transformers_sub = ["text_encoder"] + if not transformer_model_name: + diffusers_sub.append("transformer") + if not vae_model_name: + diffusers_sub.append("vae") + + self._prepare_sub_modules( + base_model_name, + diffusers_modules=diffusers_sub, + transformers_modules=transformers_sub, + ) + + tokenizer = PixtralProcessor.from_pretrained( + base_model_name, + subfolder="tokenizer", + ).tokenizer + + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + base_model_name, + subfolder="scheduler", + ) + + text_encoder = self._load_transformers_sub_module( + Mistral3ForConditionalGeneration, + weight_dtypes.text_encoder, + weight_dtypes.fallback_train_dtype, + base_model_name, + "text_encoder", + ) + + if vae_model_name: + vae = self._load_diffusers_sub_module( + AutoencoderKLFlux2, + weight_dtypes.vae, + weight_dtypes.train_dtype, + vae_model_name, + ) + else: + vae = self._load_diffusers_sub_module( + AutoencoderKLFlux2, + weight_dtypes.vae, + weight_dtypes.train_dtype, + base_model_name, + "vae", + ) + + if transformer_model_name: + transformer = Flux2Transformer2DModel.from_single_file( + transformer_model_name, + #avoid loading the transformer in float32: + torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype(), + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer.is_gguf() else None, + ) + transformer = self._convert_diffusers_sub_module_to_dtype( + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quantization, + ) + else: + transformer = self._load_diffusers_sub_module( + Flux2Transformer2DModel, + weight_dtypes.transformer, + weight_dtypes.train_dtype, + base_model_name, + "transformer", + quantization, + ) + + model.model_type = model_type + model.tokenizer = tokenizer + model.noise_scheduler = noise_scheduler + model.text_encoder = text_encoder + model.vae = vae + model.transformer = transformer + + def __load_safetensors( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + #no single file .safetensors for Flux2 available at the time of writing this code + raise NotImplementedError("Loading of single file Flux2 models is not supported. Use the diffusers model instead.
Optionally, transformer-only safetensor files can be loaded by overriding the transformer.") + + def load( + self, + model: Flux2Model, + model_type: ModelType, + model_names: ModelNames, + weight_dtypes: ModelWeightDtypes, + quantization: QuantizationConfig, + ): + stacktraces = [] + + try: + self.__load_internal( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + try: + self.__load_diffusers( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + try: + self.__load_safetensors( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + for stacktrace in stacktraces: + print(stacktrace) + raise Exception("could not load model: " + model_names.base_model) + + + +class Flux2LoRALoader( + LoRALoaderMixin +): + def __init__(self): + super().__init__() + + def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: + return None #TODO + #return convert_flux_lora_key_sets() + + def load( + self, + model: Flux2Model, + model_names: ModelNames, + ): + return self._load(model, model_names) + + +Flux2LoRAModelLoader = make_lora_model_loader( + model_spec_map={ + ModelType.FLUX_DEV_2: "resources/sd_model_spec/flux_dev_2.0-lora.json", + }, + model_class=Flux2Model, + model_loader_class=Flux2ModelLoader, + lora_loader_class=Flux2LoRALoader, + embedding_loader_class=None, +) + +Flux2FineTuneModelLoader = make_fine_tune_model_loader( + model_spec_map={ + ModelType.FLUX_DEV_2: "resources/sd_model_spec/flux_dev_2.0.json", + }, + model_class=Flux2Model, + model_loader_class=Flux2ModelLoader, + embedding_loader_class=None, +) diff --git a/modules/modelLoader/GenericEmbeddingModelLoader.py b/modules/modelLoader/GenericEmbeddingModelLoader.py index 560c7ff3e..019502fb4 100644 --- a/modules/modelLoader/GenericEmbeddingModelLoader.py +++ b/modules/modelLoader/GenericEmbeddingModelLoader.py @@ -2,8 +2,10 @@ from modules.modelLoader.BaseModelLoader import BaseModelLoader from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin +from modules.util import factory from modules.util.config.TrainConfig import QuantizationConfig from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -47,4 +49,7 @@ def load( embedding_loader.load(model, model_names.embedding.model_name, model_names) return model + + for model_type in model_spec_map: + factory.register(BaseModelLoader, GenericEmbeddingModelLoader, model_type, TrainingMethod.EMBEDDING) return GenericEmbeddingModelLoader diff --git a/modules/modelLoader/GenericFineTuneModelLoader.py b/modules/modelLoader/GenericFineTuneModelLoader.py index 936719211..09915388f 100644 --- a/modules/modelLoader/GenericFineTuneModelLoader.py +++ b/modules/modelLoader/GenericFineTuneModelLoader.py @@ -2,8 +2,10 @@ from modules.modelLoader.BaseModelLoader import BaseModelLoader from 
modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin +from modules.util import factory from modules.util.config.TrainConfig import QuantizationConfig from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -13,7 +15,11 @@ def make_fine_tune_model_loader( model_class: type[BaseModel], model_loader_class: type, embedding_loader_class: type | None, + training_methods: list[TrainingMethod] = None, ): + if training_methods is None: + training_methods = [TrainingMethod.FINE_TUNE] + class GenericFineTuneModelLoader( BaseModelLoader, ModelSpecModelLoaderMixin, @@ -50,4 +56,7 @@ def load( return model + for model_type in model_spec_map: + for method in training_methods: + factory.register(BaseModelLoader, GenericFineTuneModelLoader, model_type, method) return GenericFineTuneModelLoader diff --git a/modules/modelLoader/GenericLoRAModelLoader.py b/modules/modelLoader/GenericLoRAModelLoader.py index efbe83982..d120eb008 100644 --- a/modules/modelLoader/GenericLoRAModelLoader.py +++ b/modules/modelLoader/GenericLoRAModelLoader.py @@ -2,8 +2,10 @@ from modules.modelLoader.BaseModelLoader import BaseModelLoader from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin +from modules.util import factory from modules.util.config.TrainConfig import QuantizationConfig from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModelNames import ModelNames from modules.util.ModelWeightDtypes import ModelWeightDtypes @@ -53,4 +55,6 @@ def load( return model + for model_type in model_spec_map: + factory.register(BaseModelLoader, GenericLoRAModelLoader, model_type, TrainingMethod.LORA) return GenericLoRAModelLoader diff --git a/modules/modelLoader/StableDiffusionFineTuneModelLoader.py b/modules/modelLoader/StableDiffusionFineTuneModelLoader.py index 1307eedd4..bf18f3f5e 100644 --- a/modules/modelLoader/StableDiffusionFineTuneModelLoader.py +++ b/modules/modelLoader/StableDiffusionFineTuneModelLoader.py @@ -3,6 +3,7 @@ from modules.modelLoader.stableDiffusion.StableDiffusionEmbeddingLoader import StableDiffusionEmbeddingLoader from modules.modelLoader.stableDiffusion.StableDiffusionModelLoader import StableDiffusionModelLoader from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod StableDiffusionFineTuneModelLoader = make_fine_tune_model_loader( model_spec_map={ @@ -18,4 +19,5 @@ model_class=StableDiffusionModel, model_loader_class=StableDiffusionModelLoader, embedding_loader_class=StableDiffusionEmbeddingLoader, + training_methods=[TrainingMethod.FINE_TUNE, TrainingMethod.FINE_TUNE_VAE], ) diff --git a/modules/modelSampler/ChromaSampler.py b/modules/modelSampler/ChromaSampler.py index 3e72bba12..4b6033623 100644 --- a/modules/modelSampler/ChromaSampler.py +++ b/modules/modelSampler/ChromaSampler.py @@ -4,6 +4,7 @@ from modules.model.ChromaModel import ChromaModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat 
from modules.util.enum.FileType import FileType @@ -188,3 +189,5 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, ChromaSampler, ModelType.CHROMA_1) diff --git a/modules/modelSampler/Flux2Sampler.py b/modules/modelSampler/Flux2Sampler.py new file mode 100644 index 000000000..33075e2fd --- /dev/null +++ b/modules/modelSampler/Flux2Sampler.py @@ -0,0 +1,190 @@ +import copy +import inspect +from collections.abc import Callable + +from modules.model.Flux2Model import Flux2Model +from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory +from modules.util.config.SampleConfig import SampleConfig +from modules.util.enum.AudioFormat import AudioFormat +from modules.util.enum.FileType import FileType +from modules.util.enum.ImageFormat import ImageFormat +from modules.util.enum.ModelType import ModelType +from modules.util.enum.NoiseScheduler import NoiseScheduler +from modules.util.enum.VideoFormat import VideoFormat +from modules.util.torch_util import torch_gc + +import torch + +from diffusers.pipelines.flux2.pipeline_flux2 import compute_empirical_mu + +import numpy as np +from tqdm import tqdm + + +class Flux2Sampler(BaseModelSampler): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + model: Flux2Model, + model_type: ModelType, + ): + super().__init__(train_device, temp_device) + + self.model = model + self.model_type = model_type + self.pipeline = model.create_pipeline() + + @torch.no_grad() + def __sample_base( + self, + prompt: str, + height: int, + width: int, + seed: int, + random_seed: bool, + diffusion_steps: int, + cfg_scale: float, + noise_scheduler: NoiseScheduler, + text_encoder_sequence_length: int | None = None, + on_update_progress: Callable[[int, int], None] = lambda _, __: None, + ) -> ModelSamplerOutput: + with self.model.autocast_context: + generator = torch.Generator(device=self.train_device) + if random_seed: + generator.seed() + else: + generator.manual_seed(seed) + + noise_scheduler = copy.deepcopy(self.model.noise_scheduler) + image_processor = self.pipeline.image_processor + transformer = self.pipeline.transformer + vae = self.pipeline.vae + + vae_scale_factor = 8 + num_latent_channels = 32 + patch_size = 2 + + # prepare prompt + self.model.text_encoder_to(self.train_device) + + prompt_embedding = self.model.encode_text( + text=prompt, + train_device=self.train_device, + text_encoder_sequence_length=text_encoder_sequence_length, + ) + + self.model.text_encoder_to(self.temp_device) + torch_gc() + + # prepare latent image + latent_image = torch.randn( + size=(1, num_latent_channels, height // vae_scale_factor, width // vae_scale_factor), + generator=generator, + device=self.train_device, + dtype=torch.float32, + ) + + latent_image = self.model.patchify_latents(latent_image) + image_ids = self.model.prepare_latent_image_ids(latent_image) + + #TODO test dynamic timestep shifting instead of empirical + #shift = self.model.calculate_timestep_shift(latent_image.shape[-2], latent_image.shape[-1]) + #mu = math.log(shift) + + latent_image = self.model.pack_latents(latent_image) + image_seq_len = latent_image.shape[1] + mu = compute_empirical_mu(image_seq_len, diffusion_steps) + + # prepare timesteps + #TODO for other models, too? 
This is different than with sigmas=None + sigmas = np.linspace(1.0, 1 / diffusion_steps, diffusion_steps) + noise_scheduler.set_timesteps(diffusion_steps, device=self.train_device, mu=mu, sigmas=sigmas) + timesteps = noise_scheduler.timesteps + + # denoising loop + extra_step_kwargs = {} #TODO remove + if "generator" in set(inspect.signature(noise_scheduler.step).parameters.keys()): + extra_step_kwargs["generator"] = generator + + text_ids = self.model.prepare_text_ids(prompt_embedding) + + + self.model.transformer_to(self.train_device) + for i, timestep in enumerate(tqdm(timesteps, desc="sampling")): + latent_model_input = torch.cat([latent_image]) + expanded_timestep = timestep.expand(latent_model_input.shape[0]) + + guidance = torch.tensor([cfg_scale], device=self.train_device) + + noise_pred = transformer( + hidden_states=latent_model_input.to(dtype=self.model.train_dtype.torch_dtype()), + timestep=expanded_timestep / 1000, + guidance=guidance.to(dtype=self.model.train_dtype.torch_dtype()), + encoder_hidden_states=prompt_embedding.to(dtype=self.model.train_dtype.torch_dtype()), + txt_ids=text_ids, + img_ids=image_ids, + joint_attention_kwargs=None, + return_dict=True + ).sample + + latent_image = noise_scheduler.step(noise_pred, timestep, latent_image, return_dict=False, **extra_step_kwargs)[0] + + on_update_progress(i + 1, len(timesteps)) + + self.model.transformer_to(self.temp_device) + torch_gc() + self.model.vae_to(self.train_device) + + latent_image = self.model.unpack_latents( + latent_image, + height // vae_scale_factor // patch_size, + width // vae_scale_factor // patch_size, + ) + latents = self.model.unscale_latents(latent_image) + latents = self.model.unpatchify_latents(latents) + + image = vae.decode(latents, return_dict=False)[0] + + image = image_processor.postprocess(image, output_type='pil') + + self.model.vae_to(self.temp_device) + torch_gc() + + return ModelSamplerOutput( + file_type=FileType.IMAGE, + data=image[0], + ) + + def sample( + self, + sample_config: SampleConfig, + destination: str, + image_format: ImageFormat | None = None, + video_format: VideoFormat | None = None, + audio_format: AudioFormat | None = None, + on_sample: Callable[[ModelSamplerOutput], None] = lambda _: None, + on_update_progress: Callable[[int, int], None] = lambda _, __: None, + ): + sampler_output = self.__sample_base( + prompt=sample_config.prompt, + height=self.quantize_resolution(sample_config.height, 64), + width=self.quantize_resolution(sample_config.width, 64), + seed=sample_config.seed, + random_seed=sample_config.random_seed, + diffusion_steps=sample_config.diffusion_steps, + cfg_scale=sample_config.cfg_scale, + noise_scheduler=sample_config.noise_scheduler, + text_encoder_sequence_length=sample_config.text_encoder_1_sequence_length, + on_update_progress=on_update_progress, + ) + + self.save_sampler_output( + sampler_output, destination, + image_format, video_format, audio_format, + ) + + on_sample(sampler_output) + +factory.register(BaseModelSampler, Flux2Sampler, ModelType.FLUX_DEV_2) diff --git a/modules/modelSampler/FluxSampler.py b/modules/modelSampler/FluxSampler.py index 4fd07393b..93f8837ba 100644 --- a/modules/modelSampler/FluxSampler.py +++ b/modules/modelSampler/FluxSampler.py @@ -5,6 +5,7 @@ from modules.model.FluxModel import FluxModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat 
from modules.util.enum.FileType import FileType @@ -146,7 +147,6 @@ def __sample_base( self.model.transformer_to(self.temp_device) torch_gc() - latent_image = self.model.unpack_latents( latent_image, height // vae_scale_factor, @@ -159,7 +159,7 @@ def __sample_base( latents = (latent_image / vae.config.scaling_factor) + vae.config.shift_factor image = vae.decode(latents, return_dict=False)[0] - do_denormalize = [True] * image.shape[0] + do_denormalize = [True] * image.shape[0] #TODO remove and test, from Flux and other models. True is the default image = image_processor.postprocess(image, output_type='pil', do_denormalize=do_denormalize) self.model.vae_to(self.temp_device) @@ -450,3 +450,6 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, FluxSampler, ModelType.FLUX_DEV_1) +factory.register(BaseModelSampler, FluxSampler, ModelType.FLUX_FILL_DEV_1) diff --git a/modules/modelSampler/HiDreamSampler.py b/modules/modelSampler/HiDreamSampler.py index af3a78291..d02197a35 100644 --- a/modules/modelSampler/HiDreamSampler.py +++ b/modules/modelSampler/HiDreamSampler.py @@ -4,6 +4,7 @@ from modules.model.HiDreamModel import HiDreamModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -191,3 +192,5 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, HiDreamSampler, ModelType.HI_DREAM_FULL) diff --git a/modules/modelSampler/HunyuanVideoSampler.py b/modules/modelSampler/HunyuanVideoSampler.py index c7074fb7e..020bbe6d3 100644 --- a/modules/modelSampler/HunyuanVideoSampler.py +++ b/modules/modelSampler/HunyuanVideoSampler.py @@ -4,6 +4,7 @@ from modules.model.HunyuanVideoModel import HunyuanVideoModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -205,3 +206,5 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, HunyuanVideoSampler, ModelType.HUNYUAN_VIDEO) diff --git a/modules/modelSampler/PixArtAlphaSampler.py b/modules/modelSampler/PixArtAlphaSampler.py index f90bced10..7bec3f59a 100644 --- a/modules/modelSampler/PixArtAlphaSampler.py +++ b/modules/modelSampler/PixArtAlphaSampler.py @@ -3,7 +3,7 @@ from modules.model.PixArtAlphaModel import PixArtAlphaModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput -from modules.util import create +from modules.util import create, factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -191,3 +191,6 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, PixArtAlphaSampler, ModelType.PIXART_ALPHA) +factory.register(BaseModelSampler, PixArtAlphaSampler, ModelType.PIXART_SIGMA) diff --git a/modules/modelSampler/QwenSampler.py b/modules/modelSampler/QwenSampler.py index 798bc54cd..715e67528 100644 --- a/modules/modelSampler/QwenSampler.py +++ b/modules/modelSampler/QwenSampler.py @@ -5,6 +5,7 @@ from modules.model.QwenModel import QwenModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from 
modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -198,3 +199,5 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, QwenSampler, ModelType.QWEN) diff --git a/modules/modelSampler/SanaSampler.py b/modules/modelSampler/SanaSampler.py index 8e0d02d9f..edf7510cc 100644 --- a/modules/modelSampler/SanaSampler.py +++ b/modules/modelSampler/SanaSampler.py @@ -4,6 +4,7 @@ from modules.model.SanaModel import SanaModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -176,3 +177,5 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, SanaSampler, ModelType.SANA) diff --git a/modules/modelSampler/StableDiffusion3Sampler.py b/modules/modelSampler/StableDiffusion3Sampler.py index 44953c5a4..4cbbe4cc1 100644 --- a/modules/modelSampler/StableDiffusion3Sampler.py +++ b/modules/modelSampler/StableDiffusion3Sampler.py @@ -4,6 +4,7 @@ from modules.model.StableDiffusion3Model import StableDiffusion3Model from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -190,3 +191,6 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, StableDiffusion3Sampler, ModelType.STABLE_DIFFUSION_3) +factory.register(BaseModelSampler, StableDiffusion3Sampler, ModelType.STABLE_DIFFUSION_35) diff --git a/modules/modelSampler/StableDiffusionSampler.py b/modules/modelSampler/StableDiffusionSampler.py index aea4cd868..e16abeda2 100644 --- a/modules/modelSampler/StableDiffusionSampler.py +++ b/modules/modelSampler/StableDiffusionSampler.py @@ -3,7 +3,7 @@ from modules.model.StableDiffusionModel import StableDiffusionModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput -from modules.util import create +from modules.util import create, factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -429,3 +429,12 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, StableDiffusionSampler, ModelType.STABLE_DIFFUSION_15) +factory.register(BaseModelSampler, StableDiffusionSampler, ModelType.STABLE_DIFFUSION_15_INPAINTING) +factory.register(BaseModelSampler, StableDiffusionSampler, ModelType.STABLE_DIFFUSION_20) +factory.register(BaseModelSampler, StableDiffusionSampler, ModelType.STABLE_DIFFUSION_20_BASE) +factory.register(BaseModelSampler, StableDiffusionSampler, ModelType.STABLE_DIFFUSION_20_INPAINTING) +factory.register(BaseModelSampler, StableDiffusionSampler, ModelType.STABLE_DIFFUSION_20_DEPTH) +factory.register(BaseModelSampler, StableDiffusionSampler, ModelType.STABLE_DIFFUSION_21) +factory.register(BaseModelSampler, StableDiffusionSampler, ModelType.STABLE_DIFFUSION_21_BASE) diff --git a/modules/modelSampler/StableDiffusionVaeSampler.py b/modules/modelSampler/StableDiffusionVaeSampler.py index 9b8fab3f5..afb39bcff 100644 --- a/modules/modelSampler/StableDiffusionVaeSampler.py +++ 
b/modules/modelSampler/StableDiffusionVaeSampler.py @@ -2,11 +2,13 @@ from modules.model.StableDiffusionModel import StableDiffusionModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType from modules.util.enum.ImageFormat import ImageFormat from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.enum.VideoFormat import VideoFormat import torch @@ -76,3 +78,12 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, StableDiffusionVaeSampler, ModelType.STABLE_DIFFUSION_15, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSampler, StableDiffusionVaeSampler, ModelType.STABLE_DIFFUSION_15_INPAINTING, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSampler, StableDiffusionVaeSampler, ModelType.STABLE_DIFFUSION_20, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSampler, StableDiffusionVaeSampler, ModelType.STABLE_DIFFUSION_20_BASE, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSampler, StableDiffusionVaeSampler, ModelType.STABLE_DIFFUSION_20_INPAINTING, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSampler, StableDiffusionVaeSampler, ModelType.STABLE_DIFFUSION_20_DEPTH, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSampler, StableDiffusionVaeSampler, ModelType.STABLE_DIFFUSION_21, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSampler, StableDiffusionVaeSampler, ModelType.STABLE_DIFFUSION_21_BASE, TrainingMethod.FINE_TUNE_VAE) diff --git a/modules/modelSampler/StableDiffusionXLSampler.py b/modules/modelSampler/StableDiffusionXLSampler.py index 0b5e26dc6..e157d3b0a 100644 --- a/modules/modelSampler/StableDiffusionXLSampler.py +++ b/modules/modelSampler/StableDiffusionXLSampler.py @@ -3,7 +3,7 @@ from modules.model.StableDiffusionXLModel import StableDiffusionXLModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput -from modules.util import create +from modules.util import create, factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -501,3 +501,6 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, StableDiffusionXLSampler, ModelType.STABLE_DIFFUSION_XL_10_BASE) +factory.register(BaseModelSampler, StableDiffusionXLSampler, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING) diff --git a/modules/modelSampler/WuerstchenSampler.py b/modules/modelSampler/WuerstchenSampler.py index 9aa23e3ad..061a1e22b 100644 --- a/modules/modelSampler/WuerstchenSampler.py +++ b/modules/modelSampler/WuerstchenSampler.py @@ -3,6 +3,7 @@ from modules.model.WuerstchenModel import WuerstchenModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -364,3 +365,6 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, WuerstchenSampler, ModelType.WUERSTCHEN_2) +factory.register(BaseModelSampler, WuerstchenSampler, ModelType.STABLE_CASCADE_1) diff --git a/modules/modelSampler/ZImageSampler.py 
b/modules/modelSampler/ZImageSampler.py index 14a373537..4ac837649 100644 --- a/modules/modelSampler/ZImageSampler.py +++ b/modules/modelSampler/ZImageSampler.py @@ -4,6 +4,7 @@ from modules.model.ZImageModel import ZImageModel from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory from modules.util.config.SampleConfig import SampleConfig from modules.util.enum.AudioFormat import AudioFormat from modules.util.enum.FileType import FileType @@ -162,3 +163,5 @@ def sample( ) on_sample(sampler_output) + +factory.register(BaseModelSampler, ZImageSampler, ModelType.Z_IMAGE) diff --git a/modules/modelSaver/ChromaEmbeddingModelSaver.py b/modules/modelSaver/ChromaEmbeddingModelSaver.py index 032a3dfc6..142fe2a34 100644 --- a/modules/modelSaver/ChromaEmbeddingModelSaver.py +++ b/modules/modelSaver/ChromaEmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.ChromaModel import ChromaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.util.enum.ModelType import ModelType -import torch - - -class ChromaEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: ChromaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - embedding_model_saver = ChromaEmbeddingSaver() - - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +ChromaEmbeddingModelSaver = make_embedding_model_saver( + ModelType.CHROMA_1, + model_class=ChromaModel, + embedding_saver_class=ChromaEmbeddingSaver, +) diff --git a/modules/modelSaver/ChromaFineTuneModelSaver.py b/modules/modelSaver/ChromaFineTuneModelSaver.py index b7e89745d..d3d1df6c0 100644 --- a/modules/modelSaver/ChromaFineTuneModelSaver.py +++ b/modules/modelSaver/ChromaFineTuneModelSaver.py @@ -1,34 +1,12 @@ from modules.model.ChromaModel import ChromaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver from modules.modelSaver.chroma.ChromaModelSaver import ChromaModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.util.enum.ModelType import ModelType -import torch - - -class ChromaFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: ChromaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - base_model_saver = ChromaModelSaver() - embedding_model_saver = ChromaEmbeddingSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_multiple(model, 
output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +ChromaFineTuneModelSaver = make_fine_tune_model_saver( + ModelType.CHROMA_1, + model_class=ChromaModel, + model_saver_class=ChromaModelSaver, + embedding_saver_class=ChromaEmbeddingSaver, +) diff --git a/modules/modelSaver/ChromaLoRAModelSaver.py b/modules/modelSaver/ChromaLoRAModelSaver.py index 557747384..9f6af6d57 100644 --- a/modules/modelSaver/ChromaLoRAModelSaver.py +++ b/modules/modelSaver/ChromaLoRAModelSaver.py @@ -1,35 +1,12 @@ from modules.model.ChromaModel import ChromaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver from modules.modelSaver.chroma.ChromaLoRASaver import ChromaLoRASaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.util.enum.ModelType import ModelType -import torch - - -class ChromaLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: ChromaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - lora_model_saver = ChromaLoRASaver() - embedding_model_saver = ChromaEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +ChromaLoRAModelSaver = make_lora_model_saver( + ModelType.CHROMA_1, + model_class=ChromaModel, + lora_saver_class=ChromaLoRASaver, + embedding_saver_class=ChromaEmbeddingSaver, +) diff --git a/modules/modelSaver/Flux2FineTuneModelSaver.py b/modules/modelSaver/Flux2FineTuneModelSaver.py new file mode 100644 index 000000000..4d86200da --- /dev/null +++ b/modules/modelSaver/Flux2FineTuneModelSaver.py @@ -0,0 +1,11 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSaver.flux2.Flux2ModelSaver import Flux2ModelSaver +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver +from modules.util.enum.ModelType import ModelType + +Flux2FineTuneModelSaver = make_fine_tune_model_saver( + ModelType.FLUX_DEV_2, + model_class=Flux2Model, + model_saver_class=Flux2ModelSaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/Flux2LoRAModelSaver.py b/modules/modelSaver/Flux2LoRAModelSaver.py new file mode 100644 index 000000000..7a3cbaf3c --- /dev/null +++ b/modules/modelSaver/Flux2LoRAModelSaver.py @@ -0,0 +1,11 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSaver.flux2.Flux2LoRASaver import Flux2LoRASaver +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver +from modules.util.enum.ModelType import ModelType + +Flux2LoRAModelSaver = make_lora_model_saver( + ModelType.FLUX_DEV_2, + model_class=Flux2Model, + lora_saver_class=Flux2LoRASaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/FluxEmbeddingModelSaver.py b/modules/modelSaver/FluxEmbeddingModelSaver.py index 
445073de8..96fe530cf 100644 --- a/modules/modelSaver/FluxEmbeddingModelSaver.py +++ b/modules/modelSaver/FluxEmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.FluxModel import FluxModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.util.enum.ModelType import ModelType -import torch - - -class FluxEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: FluxModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - embedding_model_saver = FluxEmbeddingSaver() - - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +FluxEmbeddingModelSaver = make_embedding_model_saver( + [ModelType.FLUX_DEV_1, ModelType.FLUX_FILL_DEV_1], + model_class=FluxModel, + embedding_saver_class=FluxEmbeddingSaver, +) diff --git a/modules/modelSaver/FluxFineTuneModelSaver.py b/modules/modelSaver/FluxFineTuneModelSaver.py index 014c12184..e7c9ffc3e 100644 --- a/modules/modelSaver/FluxFineTuneModelSaver.py +++ b/modules/modelSaver/FluxFineTuneModelSaver.py @@ -1,34 +1,12 @@ from modules.model.FluxModel import FluxModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver from modules.modelSaver.flux.FluxModelSaver import FluxModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.util.enum.ModelType import ModelType -import torch - - -class FluxFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: FluxModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - base_model_saver = FluxModelSaver() - embedding_model_saver = FluxEmbeddingSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +FluxFineTuneModelSaver = make_fine_tune_model_saver( + [ModelType.FLUX_DEV_1, ModelType.FLUX_FILL_DEV_1], + model_class=FluxModel, + model_saver_class=FluxModelSaver, + embedding_saver_class=FluxEmbeddingSaver, +) diff --git a/modules/modelSaver/FluxLoRAModelSaver.py b/modules/modelSaver/FluxLoRAModelSaver.py index 3e54160b3..a19eb3868 100644 --- a/modules/modelSaver/FluxLoRAModelSaver.py +++ b/modules/modelSaver/FluxLoRAModelSaver.py @@ -1,35 +1,12 @@ from modules.model.FluxModel import FluxModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver from modules.modelSaver.flux.FluxEmbeddingSaver 
import FluxEmbeddingSaver from modules.modelSaver.flux.FluxLoRASaver import FluxLoRASaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.util.enum.ModelType import ModelType -import torch - - -class FluxLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: FluxModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - lora_model_saver = FluxLoRASaver() - embedding_model_saver = FluxEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +FluxLoRAModelSaver = make_lora_model_saver( + [ModelType.FLUX_DEV_1, ModelType.FLUX_FILL_DEV_1], + model_class=FluxModel, + lora_saver_class=FluxLoRASaver, + embedding_saver_class=FluxEmbeddingSaver, +) diff --git a/modules/modelSaver/GenericEmbeddingModelSaver.py b/modules/modelSaver/GenericEmbeddingModelSaver.py new file mode 100644 index 000000000..094db0a7c --- /dev/null +++ b/modules/modelSaver/GenericEmbeddingModelSaver.py @@ -0,0 +1,46 @@ +from modules.model.BaseModel import BaseModel +from modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.util import factory +from modules.util.enum.ModelFormat import ModelFormat +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod + +import torch + + +def make_embedding_model_saver( + model_types: list[ModelType] | ModelType, + model_class: type[BaseModel], + embedding_saver_class: type, +): + if not isinstance(model_types, list): + model_types = [model_types] + + class GenericEmbeddingModelSaver( + BaseModelSaver, + InternalModelSaverMixin, + ): + def __init__(self): + super().__init__() + + def save( + self, + model: model_class, + model_type: ModelType, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + embedding_model_saver = embedding_saver_class() + + embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) + embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) + + if output_model_format == ModelFormat.INTERNAL: + self._save_internal_data(model, output_model_destination) + + for model_type in model_types: + factory.register(BaseModelSaver, GenericEmbeddingModelSaver, model_type, TrainingMethod.EMBEDDING) + + return GenericEmbeddingModelSaver diff --git a/modules/modelSaver/GenericFineTuneModelSaver.py b/modules/modelSaver/GenericFineTuneModelSaver.py new file mode 100644 index 000000000..09267c61e --- /dev/null +++ b/modules/modelSaver/GenericFineTuneModelSaver.py @@ -0,0 +1,49 @@ +from modules.model.BaseModel import BaseModel +from modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.util import factory +from 
modules.util.enum.ModelFormat import ModelFormat +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod + +import torch + + +def make_fine_tune_model_saver( + model_types: list[ModelType] | ModelType, + model_class: type[BaseModel], + model_saver_class: type, + embedding_saver_class: type | None, +): + if not isinstance(model_types, list): + model_types = [model_types] + + class GenericFineTuneModelSaver( + BaseModelSaver, + InternalModelSaverMixin, + ): + def __init__(self): + super().__init__() + + def save( + self, + model: model_class, + model_type: ModelType, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + base_model_saver = model_saver_class() + base_model_saver.save(model, output_model_format, output_model_destination, dtype) + + if embedding_saver_class is not None: + embedding_model_saver = embedding_saver_class() + embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) + + if output_model_format == ModelFormat.INTERNAL: + self._save_internal_data(model, output_model_destination) + + for model_type in model_types: + factory.register(BaseModelSaver, GenericFineTuneModelSaver, model_type, TrainingMethod.FINE_TUNE) + + return GenericFineTuneModelSaver diff --git a/modules/modelSaver/GenericLoRAModelSaver.py b/modules/modelSaver/GenericLoRAModelSaver.py new file mode 100644 index 000000000..59955dc8b --- /dev/null +++ b/modules/modelSaver/GenericLoRAModelSaver.py @@ -0,0 +1,50 @@ +from modules.model.BaseModel import BaseModel +from modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.util import factory +from modules.util.enum.ModelFormat import ModelFormat +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod + +import torch + + +def make_lora_model_saver( + model_types: list[ModelType] | ModelType, + model_class: type[BaseModel], + lora_saver_class: type, + embedding_saver_class: type | None, +): + if not isinstance(model_types, list): + model_types = [model_types] + + class GenericLoRAModelSaver( + BaseModelSaver, + InternalModelSaverMixin, + ): + def __init__(self): + super().__init__() + + def save( + self, + model: model_class, + model_type: ModelType, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + lora_model_saver = lora_saver_class() + lora_model_saver.save(model, output_model_format, output_model_destination, dtype) + + if embedding_saver_class is not None: + embedding_model_saver = embedding_saver_class() + if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: + embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) + + if output_model_format == ModelFormat.INTERNAL: + self._save_internal_data(model, output_model_destination) + + for model_type in model_types: + factory.register(BaseModelSaver, GenericLoRAModelSaver, model_type, TrainingMethod.LORA) + + return GenericLoRAModelSaver diff --git a/modules/modelSaver/HiDreamEmbeddingModelSaver.py b/modules/modelSaver/HiDreamEmbeddingModelSaver.py index a6ab49cdd..68ab7514e 100644 --- a/modules/modelSaver/HiDreamEmbeddingModelSaver.py +++ b/modules/modelSaver/HiDreamEmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.HiDreamModel import HiDreamModel -from 
modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.modelSaver.hidream.HiDreamEmbeddingSaver import HiDreamEmbeddingSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class HiDreamEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: HiDreamModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - embedding_model_saver = HiDreamEmbeddingSaver() - - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +HiDreamEmbeddingModelSaver = make_embedding_model_saver( + ModelType.HI_DREAM_FULL, + model_class=HiDreamModel, + embedding_saver_class=HiDreamEmbeddingSaver, +) diff --git a/modules/modelSaver/HiDreamLoRAModelSaver.py b/modules/modelSaver/HiDreamLoRAModelSaver.py index 3081d8365..6f5ad8f6d 100644 --- a/modules/modelSaver/HiDreamLoRAModelSaver.py +++ b/modules/modelSaver/HiDreamLoRAModelSaver.py @@ -1,35 +1,12 @@ from modules.model.HiDreamModel import HiDreamModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.hidream.HiDreamEmbeddingSaver import HiDreamEmbeddingSaver from modules.modelSaver.hidream.HiDreamLoRASaver import HiDreamLoRASaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class HiDreamLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: HiDreamModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - lora_model_saver = HiDreamLoRASaver() - embedding_model_saver = HiDreamEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +HiDreamLoRAModelSaver = make_lora_model_saver( + ModelType.HI_DREAM_FULL, + model_class=HiDreamModel, + lora_saver_class=HiDreamLoRASaver, + embedding_saver_class=HiDreamEmbeddingSaver, +) diff --git a/modules/modelSaver/HunyuanVideoEmbeddingModelSaver.py b/modules/modelSaver/HunyuanVideoEmbeddingModelSaver.py index b99e97138..56cf30560 100644 --- a/modules/modelSaver/HunyuanVideoEmbeddingModelSaver.py +++ b/modules/modelSaver/HunyuanVideoEmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.HunyuanVideoModel import HunyuanVideoModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from 
modules.modelSaver.hunyuanVideo.HunyuanVideoEmbeddingSaver import HunyuanVideoEmbeddingSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class HunyuanVideoEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: HunyuanVideoModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - embedding_model_saver = HunyuanVideoEmbeddingSaver() - - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +HunyuanVideoEmbeddingModelSaver = make_embedding_model_saver( + ModelType.HUNYUAN_VIDEO, + model_class=HunyuanVideoModel, + embedding_saver_class=HunyuanVideoEmbeddingSaver, +) diff --git a/modules/modelSaver/HunyuanVideoFineTuneModelSaver.py b/modules/modelSaver/HunyuanVideoFineTuneModelSaver.py index dc2b21d84..53902e7f9 100644 --- a/modules/modelSaver/HunyuanVideoFineTuneModelSaver.py +++ b/modules/modelSaver/HunyuanVideoFineTuneModelSaver.py @@ -1,34 +1,12 @@ from modules.model.HunyuanVideoModel import HunyuanVideoModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.modelSaver.hunyuanVideo.HunyuanVideoEmbeddingSaver import HunyuanVideoEmbeddingSaver from modules.modelSaver.hunyuanVideo.HunyuanVideoModelSaver import HunyuanVideoModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class HunyuanVideoFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: HunyuanVideoModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - base_model_saver = HunyuanVideoModelSaver() - embedding_model_saver = HunyuanVideoEmbeddingSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +HunyuanVideoFineTuneModelSaver = make_fine_tune_model_saver( + ModelType.HUNYUAN_VIDEO, + model_class=HunyuanVideoModel, + model_saver_class=HunyuanVideoModelSaver, + embedding_saver_class=HunyuanVideoEmbeddingSaver, +) diff --git a/modules/modelSaver/HunyuanVideoLoRAModelSaver.py b/modules/modelSaver/HunyuanVideoLoRAModelSaver.py index 61862a12d..be0ee1820 100644 --- a/modules/modelSaver/HunyuanVideoLoRAModelSaver.py +++ b/modules/modelSaver/HunyuanVideoLoRAModelSaver.py @@ -1,36 +1,12 @@ - from modules.model.HunyuanVideoModel import HunyuanVideoModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.hunyuanVideo.HunyuanVideoEmbeddingSaver import 
HunyuanVideoEmbeddingSaver from modules.modelSaver.hunyuanVideo.HunyuanVideoLoRASaver import HunyuanVideoLoRASaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class HunyuanVideoLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: HunyuanVideoModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - lora_model_saver = HunyuanVideoLoRASaver() - embedding_model_saver = HunyuanVideoEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +HunyuanVideoLoRAModelSaver = make_lora_model_saver( + ModelType.HUNYUAN_VIDEO, + model_class=HunyuanVideoModel, + lora_saver_class=HunyuanVideoLoRASaver, + embedding_saver_class=HunyuanVideoEmbeddingSaver, +) diff --git a/modules/modelSaver/PixArtAlphaEmbeddingModelSaver.py b/modules/modelSaver/PixArtAlphaEmbeddingModelSaver.py index b7a4100ee..b42b142e4 100644 --- a/modules/modelSaver/PixArtAlphaEmbeddingModelSaver.py +++ b/modules/modelSaver/PixArtAlphaEmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.PixArtAlphaModel import PixArtAlphaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.modelSaver.pixartAlpha.PixArtAlphaEmbeddingSaver import PixArtAlphaEmbeddingSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class PixArtAlphaEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: PixArtAlphaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - embedding_model_saver = PixArtAlphaEmbeddingSaver() - - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +PixArtAlphaEmbeddingModelSaver = make_embedding_model_saver( + [ModelType.PIXART_ALPHA, ModelType.PIXART_SIGMA], + model_class=PixArtAlphaModel, + embedding_saver_class=PixArtAlphaEmbeddingSaver, +) diff --git a/modules/modelSaver/PixArtAlphaFineTuneModelSaver.py b/modules/modelSaver/PixArtAlphaFineTuneModelSaver.py index 8498111f7..ab8556d9d 100644 --- a/modules/modelSaver/PixArtAlphaFineTuneModelSaver.py +++ b/modules/modelSaver/PixArtAlphaFineTuneModelSaver.py @@ -1,34 +1,12 @@ from modules.model.PixArtAlphaModel import PixArtAlphaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from 
modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.modelSaver.pixartAlpha.PixArtAlphaEmbeddingSaver import PixArtAlphaEmbeddingSaver from modules.modelSaver.pixartAlpha.PixArtAlphaModelSaver import PixArtAlphaModelSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class PixArtAlphaFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: PixArtAlphaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype, - ): - base_model_saver = PixArtAlphaModelSaver() - embedding_model_saver = PixArtAlphaEmbeddingSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +PixArtAlphaFineTuneModelSaver = make_fine_tune_model_saver( + [ModelType.PIXART_ALPHA, ModelType.PIXART_SIGMA], + model_class=PixArtAlphaModel, + model_saver_class=PixArtAlphaModelSaver, + embedding_saver_class=PixArtAlphaEmbeddingSaver, +) diff --git a/modules/modelSaver/PixArtAlphaLoRAModelSaver.py b/modules/modelSaver/PixArtAlphaLoRAModelSaver.py index 397b357af..797d5e219 100644 --- a/modules/modelSaver/PixArtAlphaLoRAModelSaver.py +++ b/modules/modelSaver/PixArtAlphaLoRAModelSaver.py @@ -1,35 +1,12 @@ from modules.model.PixArtAlphaModel import PixArtAlphaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.pixartAlpha.PixArtAlphaEmbeddingSaver import PixArtAlphaEmbeddingSaver from modules.modelSaver.pixartAlpha.PixArtAlphaLoRASaver import PixArtAlphaLoRASaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class PixArtAlphaLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: PixArtAlphaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype, - ): - lora_model_saver = PixArtAlphaLoRASaver() - embedding_model_saver = PixArtAlphaEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +PixArtAlphaLoRAModelSaver = make_lora_model_saver( + [ModelType.PIXART_ALPHA, ModelType.PIXART_SIGMA], + model_class=PixArtAlphaModel, + lora_saver_class=PixArtAlphaLoRASaver, + embedding_saver_class=PixArtAlphaEmbeddingSaver, +) diff --git a/modules/modelSaver/QwenFineTuneModelSaver.py b/modules/modelSaver/QwenFineTuneModelSaver.py index dc726252e..25ac3160e 100644 --- a/modules/modelSaver/QwenFineTuneModelSaver.py +++ b/modules/modelSaver/QwenFineTuneModelSaver.py @@ -1,31 +1,11 @@ from modules.model.QwenModel import QwenModel -from modules.modelSaver.BaseModelSaver 
import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.modelSaver.qwen.QwenModelSaver import QwenModelSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class QwenFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: QwenModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - base_model_saver = QwenModelSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +QwenFineTuneModelSaver = make_fine_tune_model_saver( + ModelType.QWEN, + model_class=QwenModel, + model_saver_class=QwenModelSaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/QwenLoRAModelSaver.py b/modules/modelSaver/QwenLoRAModelSaver.py index 4eb82dc5a..10090c3c1 100644 --- a/modules/modelSaver/QwenLoRAModelSaver.py +++ b/modules/modelSaver/QwenLoRAModelSaver.py @@ -1,30 +1,11 @@ from modules.model.QwenModel import QwenModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.qwen.QwenLoRASaver import QwenLoRASaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class QwenLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: QwenModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - lora_model_saver = QwenLoRASaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +QwenLoRAModelSaver = make_lora_model_saver( + ModelType.QWEN, + model_class=QwenModel, + lora_saver_class=QwenLoRASaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/SanaEmbeddingModelSaver.py b/modules/modelSaver/SanaEmbeddingModelSaver.py index e2f7b931e..75a9a586a 100644 --- a/modules/modelSaver/SanaEmbeddingModelSaver.py +++ b/modules/modelSaver/SanaEmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.SanaModel import SanaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class SanaEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: SanaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - embedding_model_saver = SanaEmbeddingSaver() - - 
embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +SanaEmbeddingModelSaver = make_embedding_model_saver( + ModelType.SANA, + model_class=SanaModel, + embedding_saver_class=SanaEmbeddingSaver, +) diff --git a/modules/modelSaver/SanaFineTuneModelSaver.py b/modules/modelSaver/SanaFineTuneModelSaver.py index f03c8ecd5..f77d71c7b 100644 --- a/modules/modelSaver/SanaFineTuneModelSaver.py +++ b/modules/modelSaver/SanaFineTuneModelSaver.py @@ -1,34 +1,12 @@ from modules.model.SanaModel import SanaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver from modules.modelSaver.sana.SanaModelSaver import SanaModelSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class SanaFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: SanaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype, - ): - base_model_saver = SanaModelSaver() - embedding_model_saver = SanaEmbeddingSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +SanaFineTuneModelSaver = make_fine_tune_model_saver( + ModelType.SANA, + model_class=SanaModel, + model_saver_class=SanaModelSaver, + embedding_saver_class=SanaEmbeddingSaver, +) diff --git a/modules/modelSaver/SanaLoRAModelSaver.py b/modules/modelSaver/SanaLoRAModelSaver.py index 41f8f31fb..ee3722814 100644 --- a/modules/modelSaver/SanaLoRAModelSaver.py +++ b/modules/modelSaver/SanaLoRAModelSaver.py @@ -1,35 +1,12 @@ from modules.model.SanaModel import SanaModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.sana.SanaEmbeddingSaver import SanaEmbeddingSaver from modules.modelSaver.sana.SanaLoRASaver import SanaLoRASaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class SanaLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: SanaModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype, - ): - lora_model_saver = SanaLoRASaver() - embedding_model_saver = SanaEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if 
output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +SanaLoRAModelSaver = make_lora_model_saver( + ModelType.SANA, + model_class=SanaModel, + lora_saver_class=SanaLoRASaver, + embedding_saver_class=SanaEmbeddingSaver, +) diff --git a/modules/modelSaver/StableDiffusion3EmbeddingModelSaver.py b/modules/modelSaver/StableDiffusion3EmbeddingModelSaver.py index da2e7d963..a03e1361d 100644 --- a/modules/modelSaver/StableDiffusion3EmbeddingModelSaver.py +++ b/modules/modelSaver/StableDiffusion3EmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.StableDiffusion3Model import StableDiffusion3Model -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.modelSaver.stableDiffusion3.StableDiffusion3EmbeddingSaver import StableDiffusion3EmbeddingSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class StableDiffusion3EmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: StableDiffusion3Model, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - embedding_model_saver = StableDiffusion3EmbeddingSaver() - - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +StableDiffusion3EmbeddingModelSaver = make_embedding_model_saver( + [ModelType.STABLE_DIFFUSION_3, ModelType.STABLE_DIFFUSION_35], + model_class=StableDiffusion3Model, + embedding_saver_class=StableDiffusion3EmbeddingSaver, +) diff --git a/modules/modelSaver/StableDiffusion3FineTuneModelSaver.py b/modules/modelSaver/StableDiffusion3FineTuneModelSaver.py index bc69856b7..13083daec 100644 --- a/modules/modelSaver/StableDiffusion3FineTuneModelSaver.py +++ b/modules/modelSaver/StableDiffusion3FineTuneModelSaver.py @@ -1,34 +1,12 @@ from modules.model.StableDiffusion3Model import StableDiffusion3Model -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.modelSaver.stableDiffusion3.StableDiffusion3EmbeddingSaver import StableDiffusion3EmbeddingSaver from modules.modelSaver.stableDiffusion3.StableDiffusion3ModelSaver import StableDiffusion3ModelSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class StableDiffusion3FineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: StableDiffusion3Model, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - base_model_saver = StableDiffusion3ModelSaver() - embedding_model_saver = StableDiffusion3EmbeddingSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_multiple(model, 
output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +StableDiffusion3FineTuneModelSaver = make_fine_tune_model_saver( + [ModelType.STABLE_DIFFUSION_3, ModelType.STABLE_DIFFUSION_35], + model_class=StableDiffusion3Model, + model_saver_class=StableDiffusion3ModelSaver, + embedding_saver_class=StableDiffusion3EmbeddingSaver, +) diff --git a/modules/modelSaver/StableDiffusion3LoRAModelSaver.py b/modules/modelSaver/StableDiffusion3LoRAModelSaver.py index 7d2966bbc..0ac11052b 100644 --- a/modules/modelSaver/StableDiffusion3LoRAModelSaver.py +++ b/modules/modelSaver/StableDiffusion3LoRAModelSaver.py @@ -1,35 +1,12 @@ from modules.model.StableDiffusion3Model import StableDiffusion3Model -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.stableDiffusion3.StableDiffusion3EmbeddingSaver import StableDiffusion3EmbeddingSaver from modules.modelSaver.stableDiffusion3.StableDiffusion3LoRASaver import StableDiffusion3LoRASaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class StableDiffusion3LoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: StableDiffusion3Model, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - lora_model_saver = StableDiffusion3LoRASaver() - embedding_model_saver = StableDiffusion3EmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +StableDiffusion3LoRAModelSaver = make_lora_model_saver( + [ModelType.STABLE_DIFFUSION_3, ModelType.STABLE_DIFFUSION_35], + model_class=StableDiffusion3Model, + lora_saver_class=StableDiffusion3LoRASaver, + embedding_saver_class=StableDiffusion3EmbeddingSaver, +) diff --git a/modules/modelSaver/StableDiffusionEmbeddingModelSaver.py b/modules/modelSaver/StableDiffusionEmbeddingModelSaver.py index 7149aa844..cb77ac049 100644 --- a/modules/modelSaver/StableDiffusionEmbeddingModelSaver.py +++ b/modules/modelSaver/StableDiffusionEmbeddingModelSaver.py @@ -1,32 +1,11 @@ from modules.model.StableDiffusionModel import StableDiffusionModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.modelSaver.stableDiffusion.StableDiffusionEmbeddingSaver import StableDiffusionEmbeddingSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class StableDiffusionEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: StableDiffusionModel, - model_type: ModelType, - output_model_format: ModelFormat, - 
output_model_destination: str,
-        dtype: torch.dtype | None,
-    ):
-        embedding_model_saver = StableDiffusionEmbeddingSaver()
-
-        embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
-        embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype)
-
-        if output_model_format == ModelFormat.INTERNAL:
-            self._save_internal_data(model, output_model_destination)
+StableDiffusionEmbeddingModelSaver = make_embedding_model_saver(
+    [ModelType.STABLE_DIFFUSION_15, ModelType.STABLE_DIFFUSION_15_INPAINTING, ModelType.STABLE_DIFFUSION_20, ModelType.STABLE_DIFFUSION_20_BASE,
+     ModelType.STABLE_DIFFUSION_20_INPAINTING, ModelType.STABLE_DIFFUSION_20_DEPTH, ModelType.STABLE_DIFFUSION_21, ModelType.STABLE_DIFFUSION_21_BASE],
+    model_class=StableDiffusionModel,
+    embedding_saver_class=StableDiffusionEmbeddingSaver,
+)
diff --git a/modules/modelSaver/StableDiffusionFineTuneModelSaver.py b/modules/modelSaver/StableDiffusionFineTuneModelSaver.py
index 7ae92de3f..66e8f6510 100644
--- a/modules/modelSaver/StableDiffusionFineTuneModelSaver.py
+++ b/modules/modelSaver/StableDiffusionFineTuneModelSaver.py
@@ -1,34 +1,23 @@
 from modules.model.StableDiffusionModel import StableDiffusionModel
 from modules.modelSaver.BaseModelSaver import BaseModelSaver
-from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
+from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver
 from modules.modelSaver.stableDiffusion.StableDiffusionEmbeddingSaver import StableDiffusionEmbeddingSaver
 from modules.modelSaver.stableDiffusion.StableDiffusionModelSaver import StableDiffusionModelSaver
-from modules.util.enum.ModelFormat import ModelFormat
+from modules.util import factory
 from modules.util.enum.ModelType import ModelType
+from modules.util.enum.TrainingMethod import TrainingMethod

-import torch
+model_types = [ModelType.STABLE_DIFFUSION_15, ModelType.STABLE_DIFFUSION_15_INPAINTING, ModelType.STABLE_DIFFUSION_20, ModelType.STABLE_DIFFUSION_20_BASE,
+               ModelType.STABLE_DIFFUSION_20_INPAINTING, ModelType.STABLE_DIFFUSION_20_DEPTH, ModelType.STABLE_DIFFUSION_21, ModelType.STABLE_DIFFUSION_21_BASE]

-class StableDiffusionFineTuneModelSaver(
-    BaseModelSaver,
-    InternalModelSaverMixin,
-):
-    def __init__(self):
-        super().__init__()
+StableDiffusionFineTuneModelSaver = make_fine_tune_model_saver(
+    model_types,
+    model_class=StableDiffusionModel,
+    model_saver_class=StableDiffusionModelSaver,
+    embedding_saver_class=StableDiffusionEmbeddingSaver,
+)

-    def save(
-            self,
-            model: StableDiffusionModel,
-            model_type: ModelType,
-            output_model_format: ModelFormat,
-            output_model_destination: str,
-            dtype: torch.dtype | None,
-    ):
-        base_model_saver = StableDiffusionModelSaver()
-        embedding_model_saver = StableDiffusionEmbeddingSaver()
-
-        base_model_saver.save(model, model_type, output_model_format, output_model_destination, dtype)
-        embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
-
-        if output_model_format == ModelFormat.INTERNAL:
-            self._save_internal_data(model, output_model_destination)
+# make_fine_tune_model_saver only registers for FINE_TUNE; register the same saver for FINE_TUNE_VAE as well:
+for model_type in model_types:
+    factory.register(BaseModelSaver, StableDiffusionFineTuneModelSaver, model_type, TrainingMethod.FINE_TUNE_VAE)
diff --git a/modules/modelSaver/StableDiffusionLoRAModelSaver.py b/modules/modelSaver/StableDiffusionLoRAModelSaver.py
index 337a0e902..fe4c37611 100644
---
a/modules/modelSaver/StableDiffusionLoRAModelSaver.py +++ b/modules/modelSaver/StableDiffusionLoRAModelSaver.py @@ -1,35 +1,13 @@ from modules.model.StableDiffusionModel import StableDiffusionModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.stableDiffusion.StableDiffusionEmbeddingSaver import StableDiffusionEmbeddingSaver from modules.modelSaver.stableDiffusion.StableDiffusionLoRASaver import StableDiffusionLoRASaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class StableDiffusionLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: StableDiffusionModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - lora_model_saver = StableDiffusionLoRASaver() - embedding_model_saver = StableDiffusionEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +StableDiffusionLoRAModelSaver = make_lora_model_saver( + [ModelType.STABLE_DIFFUSION_15, ModelType.STABLE_DIFFUSION_15_INPAINTING, ModelType.STABLE_DIFFUSION_20, ModelType.STABLE_DIFFUSION_20_BASE, + ModelType.STABLE_DIFFUSION_20_INPAINTING, ModelType.STABLE_DIFFUSION_20_DEPTH, ModelType.STABLE_DIFFUSION_21, ModelType.STABLE_DIFFUSION_21_BASE], + model_class=StableDiffusionModel, + lora_saver_class=StableDiffusionLoRASaver, + embedding_saver_class=StableDiffusionEmbeddingSaver, +) diff --git a/modules/modelSaver/StableDiffusionXLEmbeddingModelSaver.py b/modules/modelSaver/StableDiffusionXLEmbeddingModelSaver.py index 24035c552..62ecb53cc 100644 --- a/modules/modelSaver/StableDiffusionXLEmbeddingModelSaver.py +++ b/modules/modelSaver/StableDiffusionXLEmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.StableDiffusionXLModel import StableDiffusionXLModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.modelSaver.stableDiffusionXL.StableDiffusionXLEmbeddingSaver import StableDiffusionXLEmbeddingSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class StableDiffusionXLEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: StableDiffusionXLModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - embedding_model_saver = StableDiffusionXLEmbeddingSaver() - - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - 
self._save_internal_data(model, output_model_destination) +StableDiffusionXLEmbeddingModelSaver = make_embedding_model_saver( + [ModelType.STABLE_DIFFUSION_XL_10_BASE, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING], + model_class=StableDiffusionXLModel, + embedding_saver_class=StableDiffusionXLEmbeddingSaver, +) diff --git a/modules/modelSaver/StableDiffusionXLFineTuneModelSaver.py b/modules/modelSaver/StableDiffusionXLFineTuneModelSaver.py index c0417f076..714fdda6e 100644 --- a/modules/modelSaver/StableDiffusionXLFineTuneModelSaver.py +++ b/modules/modelSaver/StableDiffusionXLFineTuneModelSaver.py @@ -1,34 +1,12 @@ from modules.model.StableDiffusionXLModel import StableDiffusionXLModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.modelSaver.stableDiffusionXL.StableDiffusionXLEmbeddingSaver import StableDiffusionXLEmbeddingSaver from modules.modelSaver.stableDiffusionXL.StableDiffusionXLModelSaver import StableDiffusionXLModelSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class StableDiffusionXLFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: StableDiffusionXLModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - base_model_saver = StableDiffusionXLModelSaver() - embedding_model_saver = StableDiffusionXLEmbeddingSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +StableDiffusionXLFineTuneModelSaver = make_fine_tune_model_saver( + [ModelType.STABLE_DIFFUSION_XL_10_BASE, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING], + model_class=StableDiffusionXLModel, + model_saver_class=StableDiffusionXLModelSaver, + embedding_saver_class=StableDiffusionXLEmbeddingSaver, +) diff --git a/modules/modelSaver/StableDiffusionXLLoRAModelSaver.py b/modules/modelSaver/StableDiffusionXLLoRAModelSaver.py index fe7ebf679..a963f3cbe 100644 --- a/modules/modelSaver/StableDiffusionXLLoRAModelSaver.py +++ b/modules/modelSaver/StableDiffusionXLLoRAModelSaver.py @@ -1,35 +1,12 @@ from modules.model.StableDiffusionXLModel import StableDiffusionXLModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.stableDiffusionXL.StableDiffusionXLEmbeddingSaver import StableDiffusionXLEmbeddingSaver from modules.modelSaver.stableDiffusionXL.StableDiffusionXLLoRASaver import StableDiffusionXLLoRASaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class StableDiffusionXLLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: StableDiffusionXLModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): 
- lora_model_saver = StableDiffusionXLLoRASaver() - embedding_model_saver = StableDiffusionXLEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +StableDiffusionXLLoRAModelSaver = make_lora_model_saver( + [ModelType.STABLE_DIFFUSION_XL_10_BASE, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING], + model_class=StableDiffusionXLModel, + lora_saver_class=StableDiffusionXLLoRASaver, + embedding_saver_class=StableDiffusionXLEmbeddingSaver, +) diff --git a/modules/modelSaver/WuerstchenEmbeddingModelSaver.py b/modules/modelSaver/WuerstchenEmbeddingModelSaver.py index 93dfb3d5e..a4ddc31e7 100644 --- a/modules/modelSaver/WuerstchenEmbeddingModelSaver.py +++ b/modules/modelSaver/WuerstchenEmbeddingModelSaver.py @@ -1,32 +1,10 @@ from modules.model.WuerstchenModel import WuerstchenModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver from modules.modelSaver.wuerstchen.WuerstchenEmbeddingSaver import WuerstchenEmbeddingSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class WuerstchenEmbeddingModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: WuerstchenModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype, - ): - embedding_model_saver = WuerstchenEmbeddingSaver() - - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +WuerstchenEmbeddingModelSaver = make_embedding_model_saver( + [ModelType.WUERSTCHEN_2, ModelType.STABLE_CASCADE_1], + model_class=WuerstchenModel, + embedding_saver_class=WuerstchenEmbeddingSaver, +) diff --git a/modules/modelSaver/WuerstchenFineTuneModelSaver.py b/modules/modelSaver/WuerstchenFineTuneModelSaver.py index 23d1aeee1..6233a6b33 100644 --- a/modules/modelSaver/WuerstchenFineTuneModelSaver.py +++ b/modules/modelSaver/WuerstchenFineTuneModelSaver.py @@ -1,34 +1,12 @@ from modules.model.WuerstchenModel import WuerstchenModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.modelSaver.wuerstchen.WuerstchenEmbeddingSaver import WuerstchenEmbeddingSaver from modules.modelSaver.wuerstchen.WuerstchenModelSaver import WuerstchenModelSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class WuerstchenFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: WuerstchenModel, - model_type: ModelType, - 
output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype, - ): - base_model_saver = WuerstchenModelSaver() - embedding_model_saver = WuerstchenEmbeddingSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +WuerstchenFineTuneModelSaver = make_fine_tune_model_saver( + [ModelType.WUERSTCHEN_2, ModelType.STABLE_CASCADE_1], + model_class=WuerstchenModel, + model_saver_class=WuerstchenModelSaver, + embedding_saver_class=WuerstchenEmbeddingSaver, +) diff --git a/modules/modelSaver/WuerstchenLoRAModelSaver.py b/modules/modelSaver/WuerstchenLoRAModelSaver.py index a123429c9..ffb0567f0 100644 --- a/modules/modelSaver/WuerstchenLoRAModelSaver.py +++ b/modules/modelSaver/WuerstchenLoRAModelSaver.py @@ -1,35 +1,12 @@ from modules.model.WuerstchenModel import WuerstchenModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.wuerstchen.WuerstchenEmbeddingSaver import WuerstchenEmbeddingSaver from modules.modelSaver.wuerstchen.WuerstchenLoRASaver import WuerstchenLoRASaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class WuerstchenLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: WuerstchenModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype, - ): - lora_model_saver = WuerstchenLoRASaver() - embedding_model_saver = WuerstchenEmbeddingSaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL: - embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +WuerstchenLoRAModelSaver = make_lora_model_saver( + [ModelType.WUERSTCHEN_2, ModelType.STABLE_CASCADE_1], + model_class=WuerstchenModel, + lora_saver_class=WuerstchenLoRASaver, + embedding_saver_class=WuerstchenEmbeddingSaver, +) diff --git a/modules/modelSaver/ZImageFineTuneModelSaver.py b/modules/modelSaver/ZImageFineTuneModelSaver.py index 02660e7e1..0bafe31ad 100644 --- a/modules/modelSaver/ZImageFineTuneModelSaver.py +++ b/modules/modelSaver/ZImageFineTuneModelSaver.py @@ -1,31 +1,11 @@ from modules.model.ZImageModel import ZImageModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver from modules.modelSaver.zImage.ZImageModelSaver import ZImageModelSaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class ZImageFineTuneModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: ZImageModel, - model_type: ModelType, - 
output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - base_model_saver = ZImageModelSaver() - - base_model_saver.save(model, output_model_format, output_model_destination, dtype) - - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +ZImageFineTuneModelSaver = make_fine_tune_model_saver( + ModelType.Z_IMAGE, + model_class=ZImageModel, + model_saver_class=ZImageModelSaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/ZImageLoRAModelSaver.py b/modules/modelSaver/ZImageLoRAModelSaver.py index 67292be53..ae760d2a3 100644 --- a/modules/modelSaver/ZImageLoRAModelSaver.py +++ b/modules/modelSaver/ZImageLoRAModelSaver.py @@ -1,30 +1,11 @@ from modules.model.ZImageModel import ZImageModel -from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver from modules.modelSaver.zImage.ZImageLoRASaver import ZImageLoRASaver -from modules.util.enum.ModelFormat import ModelFormat from modules.util.enum.ModelType import ModelType -import torch - - -class ZImageLoRAModelSaver( - BaseModelSaver, - InternalModelSaverMixin, -): - def __init__(self): - super().__init__() - - def save( - self, - model: ZImageModel, - model_type: ModelType, - output_model_format: ModelFormat, - output_model_destination: str, - dtype: torch.dtype | None, - ): - lora_model_saver = ZImageLoRASaver() - - lora_model_saver.save(model, output_model_format, output_model_destination, dtype) - if output_model_format == ModelFormat.INTERNAL: - self._save_internal_data(model, output_model_destination) +ZImageLoRAModelSaver = make_lora_model_saver( + ModelType.Z_IMAGE, + model_class=ZImageModel, + lora_saver_class=ZImageLoRASaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/flux2/Flux2LoRASaver.py b/modules/modelSaver/flux2/Flux2LoRASaver.py new file mode 100644 index 000000000..15471a82c --- /dev/null +++ b/modules/modelSaver/flux2/Flux2LoRASaver.py @@ -0,0 +1,52 @@ +import os +from pathlib import Path + +from modules.model.Flux2Model import Flux2Model, diffusers_lora_to_comfy +from modules.modelSaver.mixin.LoRASaverMixin import LoRASaverMixin +from modules.util.convert.lora.convert_lora_util import LoraConversionKeySet +from modules.util.convert_util import convert +from modules.util.enum.ModelFormat import ModelFormat + +import torch +from torch import Tensor + +from safetensors.torch import save_file + + +class Flux2LoRASaver( + LoRASaverMixin, +): + def __init__(self): + super().__init__() + + def _get_convert_key_sets(self, model: Flux2Model) -> list[LoraConversionKeySet] | None: + return None + + def _get_state_dict( + self, + model: Flux2Model, + ) -> dict[str, Tensor]: + state_dict = {} + if model.transformer_lora is not None: + state_dict |= model.transformer_lora.state_dict() + if model.lora_state_dict is not None: + state_dict |= model.lora_state_dict + + return state_dict + + def save( + self, + model: Flux2Model, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + if output_model_format == ModelFormat.COMFY_LORA: + state_dict = self._get_state_dict(model) + save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) + save_state_dict = convert(save_state_dict, diffusers_lora_to_comfy) + + os.makedirs(Path(output_model_destination).parent.absolute(), exist_ok=True) + 
save_file(save_state_dict, output_model_destination, self._create_safetensors_header(model, save_state_dict)) + else: + self._save(model, output_model_format, output_model_destination, dtype) diff --git a/modules/modelSaver/flux2/Flux2ModelSaver.py b/modules/modelSaver/flux2/Flux2ModelSaver.py new file mode 100644 index 000000000..e2976244e --- /dev/null +++ b/modules/modelSaver/flux2/Flux2ModelSaver.py @@ -0,0 +1,85 @@ +import copy +import os.path +from pathlib import Path + +from modules.model.Flux2Model import Flux2Model, diffusers_checkpoint_to_original +from modules.modelSaver.mixin.DtypeModelSaverMixin import DtypeModelSaverMixin +from modules.util.convert_util import convert +from modules.util.enum.ModelFormat import ModelFormat + +import torch + +from safetensors.torch import save_file + + +class Flux2ModelSaver( + DtypeModelSaverMixin, +): + def __init__(self): + super().__init__() + + def __save_diffusers( + self, + model: Flux2Model, + destination: str, + dtype: torch.dtype | None, + ): + # Copy the model to cpu by first moving the original model to cpu. This preserves some VRAM. + pipeline = model.create_pipeline() + pipeline.to("cpu") + if dtype is not None: #TODO necessary? + # replace the tokenizers __deepcopy__ before calling deepcopy, to prevent a copy being made. + # the tokenizer tries to reload from the file system otherwise + tokenizer = pipeline.tokenizer + tokenizer.__deepcopy__ = lambda memo: tokenizer + + save_pipeline = copy.deepcopy(pipeline) + save_pipeline.to(device="cpu", dtype=dtype, silence_dtype_warnings=True) + + delattr(tokenizer, '__deepcopy__') + else: + save_pipeline = pipeline + + os.makedirs(Path(destination).absolute(), exist_ok=True) + save_pipeline.save_pretrained(destination) + + if dtype is not None: + del save_pipeline + + def __save_safetensors( + self, + model: Flux2Model, + destination: str, + dtype: torch.dtype | None, + ): + state_dict = model.transformer.state_dict() + state_dict = convert(state_dict, diffusers_checkpoint_to_original) + + save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) + self._convert_state_dict_to_contiguous(save_state_dict) + + os.makedirs(Path(destination).parent.absolute(), exist_ok=True) + + save_file(save_state_dict, destination, self._create_safetensors_header(model, save_state_dict)) + + def __save_internal( + self, + model: Flux2Model, + destination: str, + ): + self.__save_diffusers(model, destination, None) + + def save( + self, + model: Flux2Model, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + match output_model_format: + case ModelFormat.DIFFUSERS: + self.__save_diffusers(model, output_model_destination, dtype) + case ModelFormat.SAFETENSORS: + self.__save_safetensors(model, output_model_destination, dtype) + case ModelFormat.INTERNAL: + self.__save_internal(model, output_model_destination) diff --git a/modules/modelSetup/BaseChromaSetup.py b/modules/modelSetup/BaseChromaSetup.py index 7a7847df7..cf1d5b68d 100644 --- a/modules/modelSetup/BaseChromaSetup.py +++ b/modules/modelSetup/BaseChromaSetup.py @@ -19,18 +19,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-mlp": ["attn", "ff.net"], - "attn-only": ["attn"], - 
"blocks": ["transformer_block"], - "full": [], -} - #TODO share more code with Flux and other models class BaseChromaSetup( @@ -42,6 +36,12 @@ class BaseChromaSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta ): + LAYER_PRESETS = { + "attn-mlp": ["attn", "ff.net"], + "attn-only": ["attn"], + "blocks": ["transformer_block"], + "full": [], + } def setup_optimizations( self, @@ -254,65 +254,16 @@ def predict( 'target': flow, } - if config.debug_mode: #TODO simplify + if config.debug_mode: with torch.no_grad(): - self._save_text( - self._decode_tokens(batch['tokens'], model.tokenizer), - config.debug_dir + "/training_batches", - "7-prompt", - train_progress.global_step, - ) - - # noise - self._save_image( - self._project_latent_to_image(latent_noise), - config.debug_dir + "/training_batches", - "1-noise", - train_progress.global_step, - ) - - # noisy image - self._save_image( - self._project_latent_to_image(scaled_noisy_latent_image), - config.debug_dir + "/training_batches", - "2-noisy_image", - train_progress.global_step, - ) - - # predicted flow - self._save_image( - self._project_latent_to_image(predicted_flow), - config.debug_dir + "/training_batches", - "3-predicted_flow", - train_progress.global_step, - ) - - # flow - flow = latent_noise - scaled_latent_image - self._save_image( - self._project_latent_to_image(flow), - config.debug_dir + "/training_batches", - "4-flow", - train_progress.global_step, - ) - predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma - - # predicted image - self._save_image( - self._project_latent_to_image(predicted_scaled_latent_image), - config.debug_dir + "/training_batches", - "5-predicted_image", - train_progress.global_step, - ) - - # image - self._save_image( - self._project_latent_to_image(scaled_latent_image), - config.debug_dir + "/training_batches", - "6-image", - model.train_progress.global_step, - ) + self._save_tokens("7-prompt", batch['tokens'], model.tokenizer, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) return model_output_data @@ -330,3 +281,13 @@ def calculate_loss( train_device=self.train_device, sigmas=model.noise_scheduler.sigmas, ).mean() + + + def prepare_text_caching(self, model: ChromaModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseFlux2Setup.py b/modules/modelSetup/BaseFlux2Setup.py new file mode 100644 index 000000000..ecd64efe7 --- /dev/null +++ b/modules/modelSetup/BaseFlux2Setup.py @@ -0,0 +1,198 @@ +from abc import ABCMeta +from random import Random + +import modules.util.multi_gpu_util as multi +from modules.model.Flux2Model import Flux2Model +from modules.model.FluxModel import FluxModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.mixin.ModelSetupDebugMixin import ModelSetupDebugMixin +from modules.modelSetup.mixin.ModelSetupDiffusionLossMixin import ModelSetupDiffusionLossMixin +from modules.modelSetup.mixin.ModelSetupEmbeddingMixin import 
ModelSetupEmbeddingMixin +from modules.modelSetup.mixin.ModelSetupFlowMatchingMixin import ModelSetupFlowMatchingMixin +from modules.modelSetup.mixin.ModelSetupNoiseMixin import ModelSetupNoiseMixin +from modules.util.checkpointing_util import ( + enable_checkpointing_for_flux2_transformer, + enable_checkpointing_for_mistral_encoder_layers, +) +from modules.util.config.TrainConfig import TrainConfig +from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc +from modules.util.TrainProgress import TrainProgress + +import torch +from torch import Tensor + + +class BaseFlux2Setup( + BaseModelSetup, + ModelSetupDiffusionLossMixin, + ModelSetupDebugMixin, + ModelSetupNoiseMixin, + ModelSetupFlowMatchingMixin, + ModelSetupEmbeddingMixin, + metaclass=ABCMeta +): + LAYER_PRESETS = { + "blocks": ["transformer_block"], + "full": [], + } + + def setup_optimizations( + self, + model: Flux2Model, + config: TrainConfig, + ): + if config.gradient_checkpointing.enabled(): + model.transformer_offload_conductor = \ + enable_checkpointing_for_flux2_transformer(model.transformer, config) + if model.text_encoder is not None: + model.text_encoder_offload_conductor = \ + enable_checkpointing_for_mistral_encoder_layers(model.text_encoder, config) + + if config.force_circular_padding: + raise NotImplementedError #TODO applies to Flux2? +# apply_circular_padding_to_conv2d(model.vae) +# apply_circular_padding_to_conv2d(model.transformer) +# if model.transformer_lora is not None: +# apply_circular_padding_to_conv2d(model.transformer_lora) + + model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ + config.weight_dtypes().transformer, + config.weight_dtypes().text_encoder, + config.weight_dtypes().vae, + config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, + ], config.enable_autocast_cache) + + model.text_encoder_autocast_context, model.text_encoder_train_dtype = \ + disable_fp16_autocast_context( + self.train_device, + config.train_dtype, + config.fallback_train_dtype, + [ + config.weight_dtypes().text_encoder, + config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, + ], + config.enable_autocast_cache, + ) + + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) + + def predict( + self, + model: Flux2Model, + batch: dict, + config: TrainConfig, + train_progress: TrainProgress, + *, + deterministic: bool = False, + ) -> dict: + with model.autocast_context: + batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank() + generator = torch.Generator(device=config.train_device) + generator.manual_seed(batch_seed) + rand = Random(batch_seed) + + text_encoder_output = model.encode_text( + train_device=self.train_device, + batch_size=batch['latent_image'].shape[0], + rand=rand, + tokens=batch.get("tokens"), + tokens_mask=batch.get("tokens_mask"), + text_encoder_sequence_length=config.text_encoder_sequence_length, + text_encoder_output=batch.get('text_encoder_hidden_state'), + text_encoder_dropout_probability=config.text_encoder.dropout_probability, + ) + latent_image = 
model.patchify_latents(batch['latent_image'].float()) + latent_height = latent_image.shape[-2] + latent_width = latent_image.shape[-1] + scaled_latent_image = model.scale_latents(latent_image) + + latent_noise = self._create_noise(scaled_latent_image, config, generator) + + shift = model.calculate_timestep_shift(latent_height, latent_width) + timestep = self._get_timestep_discrete( + model.noise_scheduler.config['num_train_timesteps'], + deterministic, + generator, + scaled_latent_image.shape[0], + config, + shift = shift if config.dynamic_timestep_shifting else config.timestep_shift, + ) + + scaled_noisy_latent_image, sigma = self._add_noise_discrete( + scaled_latent_image, + latent_noise, + timestep, + model.noise_scheduler.timesteps, + ) + latent_input = scaled_noisy_latent_image + + guidance = torch.tensor([config.transformer.guidance_scale], device=self.train_device) + guidance = guidance.expand(latent_input.shape[0]) + + text_ids = model.prepare_text_ids(text_encoder_output) + image_ids = model.prepare_latent_image_ids(latent_input) + packed_latent_input = model.pack_latents(latent_input) + + packed_predicted_flow = model.transformer( + hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()), + timestep=timestep / 1000, + guidance=guidance.to(dtype=model.train_dtype.torch_dtype()), + encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()), + txt_ids=text_ids, + img_ids=image_ids, + joint_attention_kwargs=None, + return_dict=True + ).sample + + predicted_flow = model.unpack_latents( + packed_predicted_flow, + latent_input.shape[2], + latent_input.shape[3], + ) + + flow = latent_noise - scaled_latent_image + model_output_data = { + 'loss_type': 'target', + 'timestep': timestep, + 'predicted': predicted_flow, + 'target': flow, + } + + if config.debug_mode: + with torch.no_grad(): + predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma + self._save_tokens("7-prompt", batch['tokens'], model.tokenizer, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) + + return model_output_data + + def calculate_loss( + self, + model: Flux2Model, + batch: dict, + data: dict, + config: TrainConfig, + ) -> Tensor: + return self._flow_matching_losses( + batch=batch, + data=data, + config=config, + train_device=self.train_device, + sigmas=model.noise_scheduler.sigmas, + ).mean() + + def prepare_text_caching(self, model: FluxModel, config: TrainConfig): + model.to(self.temp_device) + model.text_encoder_to(self.train_device) + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseFluxSetup.py b/modules/modelSetup/BaseFluxSetup.py index 1865f382e..02a4e9740 100644 --- a/modules/modelSetup/BaseFluxSetup.py +++ b/modules/modelSetup/BaseFluxSetup.py @@ -20,17 +20,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch 
import Tensor -PRESETS = { - "attn-mlp": ["attn", "ff.net"], - "attn-only": ["attn"], - "blocks": ["transformer_block"], - "full": [], -} class BaseFluxSetup( BaseModelSetup, @@ -41,6 +36,12 @@ class BaseFluxSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta ): + LAYER_PRESETS = { + "attn-mlp": ["attn", "ff.net"], + "attn-only": ["attn"], + "blocks": ["transformer_block"], + "full": [], + } def setup_optimizations( self, @@ -316,63 +317,14 @@ def predict( if config.debug_mode: with torch.no_grad(): - self._save_text( - self._decode_tokens(batch['tokens_1'], model.tokenizer_1), - config.debug_dir + "/training_batches", - "7-prompt", - train_progress.global_step, - ) - - # noise - self._save_image( - self._project_latent_to_image(latent_noise), - config.debug_dir + "/training_batches", - "1-noise", - train_progress.global_step, - ) - - # noisy image - self._save_image( - self._project_latent_to_image(scaled_noisy_latent_image), - config.debug_dir + "/training_batches", - "2-noisy_image", - train_progress.global_step, - ) - - # predicted flow - self._save_image( - self._project_latent_to_image(predicted_flow), - config.debug_dir + "/training_batches", - "3-predicted_flow", - train_progress.global_step, - ) - - # flow - flow = latent_noise - scaled_latent_image - self._save_image( - self._project_latent_to_image(flow), - config.debug_dir + "/training_batches", - "4-flow", - train_progress.global_step, - ) - predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma - - # predicted image - self._save_image( - self._project_latent_to_image(predicted_scaled_latent_image), - config.debug_dir + "/training_batches", - "5-predicted_image", - train_progress.global_step, - ) - - # image - self._save_image( - self._project_latent_to_image(scaled_latent_image), - config.debug_dir + "/training_batches", - "6-image", - model.train_progress.global_step, - ) + self._save_tokens("7-prompt", batch['tokens_1'], model.tokenizer_1, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) return model_output_data @@ -390,3 +342,15 @@ def calculate_loss( train_device=self.train_device, sigmas=model.noise_scheduler.sigmas, ).mean() + + def prepare_text_caching(self, model: FluxModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + if not config.train_text_encoder_2_or_embedding(): + model.text_encoder_2_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseHiDreamSetup.py b/modules/modelSetup/BaseHiDreamSetup.py index 48e691ecd..0abb27d3e 100644 --- a/modules/modelSetup/BaseHiDreamSetup.py +++ b/modules/modelSetup/BaseHiDreamSetup.py @@ -19,18 +19,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-mlp": ["attn1", 
"ff_i"], - "attn-only": ["attn1"], - "blocks": ["stream_block"], - "full": [], -} - class BaseHiDreamSetup( BaseModelSetup, @@ -41,6 +35,12 @@ class BaseHiDreamSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta ): + LAYER_PRESETS = { + "attn-mlp": ["attn1", "ff_i"], + "attn-only": ["attn1"], + "blocks": ["stream_block"], + "full": [], + } def setup_optimizations( self, @@ -403,63 +403,14 @@ def predict( if config.debug_mode: with torch.no_grad(): - self._save_text( - self._decode_tokens(batch['tokens_1'], model.tokenizer_1), - config.debug_dir + "/training_batches", - "7-prompt", - train_progress.global_step, - ) - - # noise - self._save_image( - self._project_latent_to_image(latent_noise), - config.debug_dir + "/training_batches", - "1-noise", - train_progress.global_step, - ) - - # noisy image - self._save_image( - self._project_latent_to_image(scaled_noisy_latent_image), - config.debug_dir + "/training_batches", - "2-noisy_image", - train_progress.global_step, - ) - - # predicted flow - self._save_image( - self._project_latent_to_image(predicted_flow), - config.debug_dir + "/training_batches", - "3-predicted_flow", - train_progress.global_step, - ) - - # flow - flow = latent_noise - scaled_latent_image - self._save_image( - self._project_latent_to_image(flow), - config.debug_dir + "/training_batches", - "4-flow", - train_progress.global_step, - ) - predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma - - # predicted image - self._save_image( - self._project_latent_to_image(predicted_scaled_latent_image), - config.debug_dir + "/training_batches", - "5-predicted_image", - train_progress.global_step, - ) - - # image - self._save_image( - self._project_latent_to_image(scaled_latent_image), - config.debug_dir + "/training_batches", - "6-image", - model.train_progress.global_step, - ) + self._save_tokens("7-prompt", batch['tokens_1'], model.tokenizer_1, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) return model_output_data @@ -477,3 +428,21 @@ def calculate_loss( train_device=self.train_device, sigmas=model.noise_scheduler.sigmas, ).mean() + + def prepare_text_caching(self, model: HiDreamModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + if not config.train_text_encoder_2_or_embedding(): + model.text_encoder_2_to(self.train_device) + + if not config.train_text_encoder_3_or_embedding(): + model.text_encoder_3_to(self.train_device) + + if not config.train_text_encoder_4_or_embedding(): + model.text_encoder_4_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseHunyuanVideoSetup.py b/modules/modelSetup/BaseHunyuanVideoSetup.py index bbb90b71a..16bf5e5f7 100644 --- a/modules/modelSetup/BaseHunyuanVideoSetup.py +++ b/modules/modelSetup/BaseHunyuanVideoSetup.py @@ -20,18 +20,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers 
+from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-mlp": ["attn", "ff.net"], - "attn-only": ["attn"], - "blocks": ["transformer_block"], - "full": [], -} - class BaseHunyuanVideoSetup( BaseModelSetup, @@ -42,6 +36,12 @@ class BaseHunyuanVideoSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta ): + LAYER_PRESETS = { + "attn-mlp": ["attn", "ff.net"], + "attn-only": ["attn"], + "blocks": ["transformer_block"], + "full": [], + } def setup_optimizations( self, @@ -286,62 +286,14 @@ def predict( if config.debug_mode: with torch.no_grad(): - self._save_text( - self._decode_tokens(batch['tokens_1'], model.tokenizer_1), - config.debug_dir + "/training_batches", - "7-prompt", - train_progress.global_step, - ) - - # noise - self._save_image( - self._project_latent_to_image(latent_noise), - config.debug_dir + "/training_batches", - "1-noise", - train_progress.global_step, - ) - - # noisy image - self._save_image( - self._project_latent_to_image(scaled_noisy_latent_image), - config.debug_dir + "/training_batches", - "2-noisy_image", - train_progress.global_step, - ) - - # predicted flow - self._save_image( - self._project_latent_to_image(predicted_flow), - config.debug_dir + "/training_batches", - "3-predicted_flow", - train_progress.global_step, - ) - - # flow - self._save_image( - self._project_latent_to_image(flow), - config.debug_dir + "/training_batches", - "4-flow", - train_progress.global_step, - ) - predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma - - # predicted image - self._save_image( - self._project_latent_to_image(predicted_scaled_latent_image), - config.debug_dir + "/training_batches", - "5-predicted_image", - train_progress.global_step, - ) - - # image - self._save_image( - self._project_latent_to_image(scaled_latent_image), - config.debug_dir + "/training_batches", - "6-image", - model.train_progress.global_step, - ) + self._save_tokens("7-prompt", batch['tokens_1'], model.tokenizer_1, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) return model_output_data @@ -359,3 +311,15 @@ def calculate_loss( train_device=self.train_device, sigmas=model.noise_scheduler.sigmas, ).mean() + + def prepare_text_caching(self, model: HunyuanVideoModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + if not config.train_text_encoder_2_or_embedding(): + model.text_encoder_2_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BasePixArtAlphaSetup.py b/modules/modelSetup/BasePixArtAlphaSetup.py index 8240fb5f4..c4642b656 100644 --- a/modules/modelSetup/BasePixArtAlphaSetup.py +++ b/modules/modelSetup/BasePixArtAlphaSetup.py @@ -19,17 +19,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from 
modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-mlp": ["attn1", "attn2", "ff.net"], - "attn-only": ["attn1", "attn2"], - "blocks": ["transformer_block"], - "full": [], -} class BasePixArtAlphaSetup( BaseModelSetup, @@ -40,6 +35,12 @@ class BasePixArtAlphaSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta, ): + LAYER_PRESETS = { + "attn-mlp": ["attn1", "attn2", "ff.net"], + "attn-only": ["attn1", "attn2"], + "blocks": ["transformer_block"], + "full": [], + } def __init__(self, train_device: torch.device, temp_device: torch.device, debug_mode: bool): super().__init__(train_device, temp_device, debug_mode) @@ -342,3 +343,12 @@ def calculate_loss( train_device=self.train_device, betas=model.noise_scheduler.betas, ).mean() + + def prepare_text_caching(self, model: PixArtAlphaModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py index dc7115274..dd2d97d0b 100644 --- a/modules/modelSetup/BaseQwenSetup.py +++ b/modules/modelSetup/BaseQwenSetup.py @@ -17,18 +17,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-mlp": ["attn", "img_mlp", "txt_mlp"], - "attn-only": ["attn"], - "blocks": ["transformer_block"], - "full": [], -} - #TODO share more code with other models class BaseQwenSetup( @@ -39,6 +33,12 @@ class BaseQwenSetup( ModelSetupFlowMatchingMixin, metaclass=ABCMeta ): + LAYER_PRESETS = { + "attn-mlp": ["attn", "img_mlp", "txt_mlp"], + "attn-only": ["attn"], + "blocks": ["transformer_block"], + "full": [], + } def setup_optimizations( self, @@ -176,65 +176,16 @@ def predict( 'target': flow, } - if config.debug_mode: #TODO simplify + if config.debug_mode: with torch.no_grad(): - self._save_text( - self._decode_tokens(batch['tokens'], model.tokenizer), - config.debug_dir + "/training_batches", - "7-prompt", - train_progress.global_step, - ) - - # noise - self._save_image( - self._project_latent_to_image(latent_noise), - config.debug_dir + "/training_batches", - "1-noise", - train_progress.global_step, - ) - - # noisy image - self._save_image( - self._project_latent_to_image(scaled_noisy_latent_image), - config.debug_dir + "/training_batches", - "2-noisy_image", - train_progress.global_step, - ) - - # predicted flow - self._save_image( - self._project_latent_to_image(predicted_flow), - config.debug_dir + "/training_batches", - "3-predicted_flow", - train_progress.global_step, - ) - - # flow - flow = latent_noise - scaled_latent_image - self._save_image( - self._project_latent_to_image(flow), - config.debug_dir + "/training_batches", - "4-flow", - train_progress.global_step, - ) - predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma - - # predicted image - self._save_image( - self._project_latent_to_image(predicted_scaled_latent_image), - config.debug_dir + "/training_batches", - "5-predicted_image", - train_progress.global_step, - ) - - # image - self._save_image( - self._project_latent_to_image(scaled_latent_image), - config.debug_dir + "/training_batches", - "6-image", - 
model.train_progress.global_step, - ) + self._save_tokens("7-prompt", batch['tokens'], model.tokenizer, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) return model_output_data @@ -252,3 +203,12 @@ def calculate_loss( train_device=self.train_device, sigmas=model.noise_scheduler.sigmas, ).mean() + + def prepare_text_caching(self, model: QwenModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseSanaSetup.py b/modules/modelSetup/BaseSanaSetup.py index 84078ff6f..c2189fd20 100644 --- a/modules/modelSetup/BaseSanaSetup.py +++ b/modules/modelSetup/BaseSanaSetup.py @@ -19,18 +19,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-mlp": ["attn1", "attn2", "ff."], - "attn-only": ["attn1", "attn2"], - "blocks": ["transformer_block"], - "full": [], -} - class BaseSanaSetup( BaseModelSetup, @@ -41,6 +35,12 @@ class BaseSanaSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta, ): + LAYER_PRESETS = { + "attn-mlp": ["attn1", "attn2", "ff."], + "attn-only": ["attn1", "attn2"], + "blocks": ["transformer_block"], + "full": [], + } def __init__(self, train_device: torch.device, temp_device: torch.device, debug_mode: bool): super().__init__(train_device, temp_device, debug_mode) @@ -241,63 +241,14 @@ def predict( if self.debug_mode: with torch.no_grad(): - self._save_text( - self._decode_tokens(batch['tokens'], model.tokenizer), - config.debug_dir + "/training_batches", - "7-prompt", - train_progress.global_step, - ) - - # noise - self._save_image( - self._project_latent_to_image(latent_noise), - config.debug_dir + "/training_batches", - "1-noise", - train_progress.global_step, - ) - - # noisy image - self._save_image( - self._project_latent_to_image(scaled_noisy_latent_image), - config.debug_dir + "/training_batches", - "2-noisy_image", - train_progress.global_step, - ) - - # predicted flow - self._save_image( - self._project_latent_to_image(predicted_flow), - config.debug_dir + "/training_batches", - "3-predicted_flow", - train_progress.global_step, - ) - - # flow - flow = latent_noise - scaled_latent_image - self._save_image( - self._project_latent_to_image(flow), - config.debug_dir + "/training_batches", - "4-flow", - train_progress.global_step, - ) - predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma - - # predicted image - self._save_image( - self._project_latent_to_image(predicted_scaled_latent_image), - config.debug_dir + "/training_batches", - "5-predicted_image", - train_progress.global_step, - ) - - # image - self._save_image( - self._project_latent_to_image(scaled_latent_image), - config.debug_dir + "/training_batches", - "6-image", - 
model.train_progress.global_step, - ) + self._save_tokens("7-prompt", batch['tokens'], model.tokenizer, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) return model_output_data @@ -315,3 +266,12 @@ def calculate_loss( train_device=self.train_device, betas=model.noise_scheduler.betas, ).mean() + + def prepare_text_caching(self, model: SanaModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseStableDiffusion3Setup.py b/modules/modelSetup/BaseStableDiffusion3Setup.py index 5015b21af..14f876ed1 100644 --- a/modules/modelSetup/BaseStableDiffusion3Setup.py +++ b/modules/modelSetup/BaseStableDiffusion3Setup.py @@ -20,16 +20,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-only": ["attn"], - "blocks": ["transformer_block"], - "full": [], -} class BaseStableDiffusion3Setup( BaseModelSetup, @@ -40,6 +36,11 @@ class BaseStableDiffusion3Setup( ModelSetupEmbeddingMixin, metaclass=ABCMeta ): + LAYER_PRESETS = { + "attn-only": ["attn"], + "blocks": ["transformer_block"], + "full": [], + } def setup_optimizations( self, @@ -338,63 +339,14 @@ def predict( if config.debug_mode: with torch.no_grad(): - self._save_text( - self._decode_tokens(batch['tokens_1'], model.tokenizer_1), - config.debug_dir + "/training_batches", - "7-prompt", - train_progress.global_step, - ) - - # noise - self._save_image( - self._project_latent_to_image(latent_noise), - config.debug_dir + "/training_batches", - "1-noise", - train_progress.global_step, - ) - - # noisy image - self._save_image( - self._project_latent_to_image(scaled_noisy_latent_image), - config.debug_dir + "/training_batches", - "2-noisy_image", - train_progress.global_step, - ) - - # predicted flow - self._save_image( - self._project_latent_to_image(predicted_flow), - config.debug_dir + "/training_batches", - "3-predicted_flow", - train_progress.global_step, - ) - - # flow - flow = latent_noise - scaled_latent_image - self._save_image( - self._project_latent_to_image(flow), - config.debug_dir + "/training_batches", - "4-flow", - train_progress.global_step, - ) - predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma - - # predicted image - self._save_image( - self._project_latent_to_image(predicted_scaled_latent_image), - config.debug_dir + "/training_batches", - "5-predicted_image", - train_progress.global_step, - ) - - # image - self._save_image( - self._project_latent_to_image(scaled_latent_image), - config.debug_dir + "/training_batches", - "6-image", - model.train_progress.global_step, - ) + self._save_tokens("7-prompt", batch['tokens_1'], model.tokenizer_1, config, train_progress) + self._save_latent("1-noise", 
latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) return model_output_data @@ -412,3 +364,18 @@ def calculate_loss( train_device=self.train_device, sigmas=model.noise_scheduler.sigmas, ).mean() + + def prepare_text_caching(self, model: StableDiffusion3Model, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + if not config.train_text_encoder_2_or_embedding(): + model.text_encoder_2_to(self.train_device) + + if not config.train_text_encoder_3_or_embedding(): + model.text_encoder_3_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseStableDiffusionSetup.py b/modules/modelSetup/BaseStableDiffusionSetup.py index 0fc6ed0df..555f97d7e 100644 --- a/modules/modelSetup/BaseStableDiffusionSetup.py +++ b/modules/modelSetup/BaseStableDiffusionSetup.py @@ -19,16 +19,12 @@ from modules.util.dtype_util import create_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-mlp": ["attentions"], - "attn-only": ["attn"], - "full": [], -} class BaseStableDiffusionSetup( BaseModelSetup, @@ -39,6 +35,11 @@ class BaseStableDiffusionSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta, ): + LAYER_PRESETS = { + "attn-mlp": ["attentions"], + "attn-only": ["attn"], + "full": [], + } def __init__(self, train_device: torch.device, temp_device: torch.device, debug_mode: bool): super().__init__(train_device, temp_device, debug_mode) @@ -335,3 +336,12 @@ def calculate_loss( train_device=self.train_device, betas=model.noise_scheduler.betas, ).mean() + + def prepare_text_caching(self, model: StableDiffusionModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseStableDiffusionXLSetup.py b/modules/modelSetup/BaseStableDiffusionXLSetup.py index 37121951a..a60f2551f 100644 --- a/modules/modelSetup/BaseStableDiffusionXLSetup.py +++ b/modules/modelSetup/BaseStableDiffusionXLSetup.py @@ -19,16 +19,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "attn-mlp": ["attentions"], - "attn-only": ["attn"], - "full": [], -} class BaseStableDiffusionXLSetup( BaseModelSetup, @@ -39,6 +35,11 @@ class BaseStableDiffusionXLSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta ): + LAYER_PRESETS = { + "attn-mlp": ["attentions"], + "attn-only": ["attn"], + "full": [], + } def setup_optimizations( self, @@ -383,3 +384,15 @@ def calculate_loss( train_device=self.train_device, betas=model.noise_scheduler.betas, ).mean() + + def 
prepare_text_caching(self, model: StableDiffusionXLModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + if not config.train_text_encoder_2_or_embedding(): + model.text_encoder_2_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseWuerstchenSetup.py b/modules/modelSetup/BaseWuerstchenSetup.py index 23b3440a5..f8f758910 100644 --- a/modules/modelSetup/BaseWuerstchenSetup.py +++ b/modules/modelSetup/BaseWuerstchenSetup.py @@ -23,21 +23,12 @@ ) from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -# This is correct for the latest cascade, but other Wuerstchen models may have -# different names. I honestly don't know what makes a good preset here so I'm -# just guessing. -PRESETS = { - "attn-only": ["attention"], - "full": [], - "down-blocks": ["down_blocks"], - "up-blocks": ["up_blocks"], - "mapper-only": ["mapper"], -} class BaseWuerstchenSetup( BaseModelSetup, @@ -48,6 +39,16 @@ class BaseWuerstchenSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta, ): + # This is correct for the latest cascade, but other Wuerstchen models may have + # different names. I honestly don't know what makes a good preset here so I'm + # just guessing. + LAYER_PRESETS = { + "attn-only": ["attention"], + "full": [], + "down-blocks": ["down_blocks"], + "up-blocks": ["up_blocks"], + "mapper-only": ["mapper"], + } def setup_optimizations( self, @@ -357,3 +358,12 @@ def calculate_loss( train_device=self.train_device, alphas_cumprod_fun=self.__alpha_cumprod, ).mean() + + def prepare_text_caching(self, model: WuerstchenModel, config: TrainConfig): + model.to(self.temp_device) + + if not config.train_text_encoder_or_embedding(): + model.text_encoder_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseZImageSetup.py b/modules/modelSetup/BaseZImageSetup.py index 727122f8b..e792500c1 100644 --- a/modules/modelSetup/BaseZImageSetup.py +++ b/modules/modelSetup/BaseZImageSetup.py @@ -17,17 +17,12 @@ from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc from modules.util.TrainProgress import TrainProgress import torch from torch import Tensor -PRESETS = { - "full": [], - "blocks": ["layers"], - "attn-mlp": {'patterns': ["^(?=.*attention)(?!.*refiner).*", "^(?=.*feed_forward)(?!.*refiner).*"], 'regex': True}, - "attn-only": {'patterns': ["^(?=.*attention)(?!.*refiner).*"], 'regex': True}, -} class BaseZImageSetup( BaseModelSetup, @@ -38,6 +33,12 @@ class BaseZImageSetup( ModelSetupEmbeddingMixin, metaclass=ABCMeta ): + LAYER_PRESETS = { + "full": [], + "blocks": ["layers"], + "attn-mlp": {'patterns': ["^(?=.*attention)(?!.*refiner).*", "^(?=.*feed_forward)(?!.*refiner).*"], 'regex': True}, + "attn-only": {'patterns': ["^(?=.*attention)(?!.*refiner).*"], 'regex': True}, + } def setup_optimizations( self, @@ -149,63 +150,14 @@ def predict( if config.debug_mode: with torch.no_grad(): - self._save_text( #TODO share code - self._decode_tokens(batch['tokens'], model.tokenizer), - config.debug_dir + "/training_batches", - "7-prompt", - train_progress.global_step, - ) 
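The module-level `PRESETS` dicts are moved onto each setup class as a `LAYER_PRESETS` attribute; entries are either a list of name substrings or, for Z-Image, a dict of regex patterns. A hypothetical sketch of how such a preset could be resolved against a model's module names is shown below; it is illustrative only and is not the repository's `ModuleFilter` or `LoRAModuleWrapper` logic.

```python
import re

from torch import nn


def resolve_layer_preset(model: nn.Module, preset) -> list[str]:
    # Substring presets (e.g. ["attn", "ff.net"]) match anywhere in the qualified
    # module name; regex presets ({'patterns': [...], 'regex': True}) use re.search;
    # an empty list is the "full" preset and matches every module.
    if isinstance(preset, dict) and preset.get('regex'):
        patterns = [re.compile(p) for p in preset['patterns']]

        def matches(name: str) -> bool:
            return any(p.search(name) for p in patterns)
    elif preset:
        def matches(name: str) -> bool:
            return any(key in name for key in preset)
    else:
        def matches(name: str) -> bool:
            return True

    return [name for name, _ in model.named_modules() if name and matches(name)]
```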
- - # noise - self._save_image( - self._project_latent_to_image(latent_noise), - config.debug_dir + "/training_batches", - "1-noise", - train_progress.global_step, - ) - - # noisy image - self._save_image( - self._project_latent_to_image(scaled_noisy_latent_image), - config.debug_dir + "/training_batches", - "2-noisy_image", - train_progress.global_step, - ) - - # predicted flow - self._save_image( - self._project_latent_to_image(predicted_flow), - config.debug_dir + "/training_batches", - "3-predicted_flow", - train_progress.global_step, - ) - - # flow - flow = latent_noise - scaled_latent_image - self._save_image( - self._project_latent_to_image(flow), - config.debug_dir + "/training_batches", - "4-flow", - train_progress.global_step, - ) - predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma - - # predicted image - self._save_image( - self._project_latent_to_image(predicted_scaled_latent_image), - config.debug_dir + "/training_batches", - "5-predicted_image", - train_progress.global_step, - ) - - # image - self._save_image( - self._project_latent_to_image(scaled_latent_image), - config.debug_dir + "/training_batches", - "6-image", - model.train_progress.global_step, - ) + self._save_tokens("7-prompt", batch['tokens'], model.tokenizer, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) return model_output_data @@ -223,3 +175,10 @@ def calculate_loss( train_device=self.train_device, sigmas=model.noise_scheduler.sigmas, ).mean() + + def prepare_text_caching(self, model: ZImageModel, config: TrainConfig): + model.to(self.temp_device) + model.text_encoder_to(self.train_device) + + model.eval() + torch_gc() diff --git a/modules/modelSetup/ChromaEmbeddingSetup.py b/modules/modelSetup/ChromaEmbeddingSetup.py index 033ca7505..724795b3d 100644 --- a/modules/modelSetup/ChromaEmbeddingSetup.py +++ b/modules/modelSetup/ChromaEmbeddingSetup.py @@ -1,6 +1,10 @@ from modules.model.ChromaModel import ChromaModel from modules.modelSetup.BaseChromaSetup import BaseChromaSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -91,3 +95,5 @@ def after_optimizer_step( if model.embedding_wrapper is not None: model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, ChromaEmbeddingSetup, ModelType.CHROMA_1, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/ChromaFineTuneSetup.py b/modules/modelSetup/ChromaFineTuneSetup.py index 82897b89a..bff5a041d 100644 --- a/modules/modelSetup/ChromaFineTuneSetup.py +++ b/modules/modelSetup/ChromaFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.ChromaModel import ChromaModel from modules.modelSetup.BaseChromaSetup import 
BaseChromaSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -111,3 +115,5 @@ def after_optimizer_step( if model.embedding_wrapper is not None: model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, ChromaFineTuneSetup, ModelType.CHROMA_1, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/ChromaLoRASetup.py b/modules/modelSetup/ChromaLoRASetup.py index 9802c7b00..444e7724d 100644 --- a/modules/modelSetup/ChromaLoRASetup.py +++ b/modules/modelSetup/ChromaLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.ChromaModel import ChromaModel from modules.modelSetup.BaseChromaSetup import BaseChromaSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -139,3 +143,5 @@ def after_optimizer_step( if model.embedding_wrapper is not None: model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, ChromaLoRASetup, ModelType.CHROMA_1, TrainingMethod.LORA) diff --git a/modules/modelSetup/Flux2FineTuneSetup.py b/modules/modelSetup/Flux2FineTuneSetup.py new file mode 100644 index 000000000..06b0ff2ad --- /dev/null +++ b/modules/modelSetup/Flux2FineTuneSetup.py @@ -0,0 +1,88 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.ModuleFilter import ModuleFilter +from modules.util.NamedParameterGroup import NamedParameterGroupCollection +from modules.util.optimizer_util import init_model_parameters +from modules.util.TrainProgress import TrainProgress + +import torch + + +class Flux2FineTuneSetup( + BaseFlux2Setup, +): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + debug_mode: bool, + ): + super().__init__( + train_device=train_device, + temp_device=temp_device, + debug_mode=debug_mode, + ) + + def create_parameters( + self, + model: Flux2Model, + config: TrainConfig, + ) -> NamedParameterGroupCollection: + parameter_group_collection = NamedParameterGroupCollection() + + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, + freeze=ModuleFilter.create(config), debug=config.debug_mode) + return parameter_group_collection + + def __setup_requires_grad( + self, + model: Flux2Model, + config: TrainConfig, + ): + self._setup_model_part_requires_grad("transformer", 
model.transformer, config.transformer, model.train_progress) + model.vae.requires_grad_(False) + model.text_encoder.requires_grad_(False) + + + def setup_model( + self, + model: Flux2Model, + config: TrainConfig, + ): + self.__setup_requires_grad(model, config) + init_model_parameters(model, self.create_parameters(model, config), self.train_device) + + def setup_train_device( + self, + model: Flux2Model, + config: TrainConfig, + ): + vae_on_train_device = not config.latent_caching + text_encoder_on_train_device = not config.latent_caching + + model.text_encoder_to(self.train_device if text_encoder_on_train_device else self.temp_device) + model.vae_to(self.train_device if vae_on_train_device else self.temp_device) + model.transformer_to(self.train_device) + + model.text_encoder.eval() + model.vae.eval() + + if config.transformer.train: + model.transformer.train() + else: + model.transformer.eval() + + def after_optimizer_step( + self, + model: Flux2Model, + config: TrainConfig, + train_progress: TrainProgress + ): + self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, Flux2FineTuneSetup, ModelType.FLUX_DEV_2, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/Flux2LoRASetup.py b/modules/modelSetup/Flux2LoRASetup.py new file mode 100644 index 000000000..d9f9b428c --- /dev/null +++ b/modules/modelSetup/Flux2LoRASetup.py @@ -0,0 +1,101 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.NamedParameterGroup import NamedParameterGroupCollection +from modules.util.optimizer_util import init_model_parameters +from modules.util.TrainProgress import TrainProgress + +import torch + + +class Flux2LoRASetup( + BaseFlux2Setup, +): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + debug_mode: bool, + ): + super().__init__( + train_device=train_device, + temp_device=temp_device, + debug_mode=debug_mode, + ) + + def create_parameters( + self, + model: Flux2Model, + config: TrainConfig, + ) -> NamedParameterGroupCollection: + parameter_group_collection = NamedParameterGroupCollection() + + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) + return parameter_group_collection + + def __setup_requires_grad( + self, + model: Flux2Model, + config: TrainConfig, + ): + model.text_encoder.requires_grad_(False) + model.transformer.requires_grad_(False) + model.vae.requires_grad_(False) + + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) + + def setup_model( + self, + model: Flux2Model, + config: TrainConfig, + ): + model.transformer_lora = LoRAModuleWrapper( + model.transformer, "lora_transformer", config, config.layer_filter.split(",") + ) + + if model.lora_state_dict: + model.transformer_lora.load_state_dict(model.lora_state_dict) + model.lora_state_dict = None + + model.transformer_lora.set_dropout(config.dropout_probability) + model.transformer_lora.to(dtype=config.lora_weight_dtype.torch_dtype()) + model.transformer_lora.hook_to_module() + + self.__setup_requires_grad(model, config) + + 
init_model_parameters(model, self.create_parameters(model, config), self.train_device) + + def setup_train_device( + self, + model: Flux2Model, + config: TrainConfig, + ): + vae_on_train_device = not config.latent_caching + text_encoder_on_train_device = not config.latent_caching + + model.text_encoder_to(self.train_device if text_encoder_on_train_device else self.temp_device) + model.vae_to(self.train_device if vae_on_train_device else self.temp_device) + model.transformer_to(self.train_device) + + model.text_encoder.eval() + model.vae.eval() + + if config.transformer.train: + model.transformer.train() + else: + model.transformer.eval() + + def after_optimizer_step( + self, + model: Flux2Model, + config: TrainConfig, + train_progress: TrainProgress + ): + self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, Flux2LoRASetup, ModelType.FLUX_DEV_2, TrainingMethod.LORA) diff --git a/modules/modelSetup/FluxEmbeddingSetup.py b/modules/modelSetup/FluxEmbeddingSetup.py index 7a71a98cc..7622c6c17 100644 --- a/modules/modelSetup/FluxEmbeddingSetup.py +++ b/modules/modelSetup/FluxEmbeddingSetup.py @@ -1,6 +1,10 @@ from modules.model.FluxModel import FluxModel from modules.modelSetup.BaseFluxSetup import BaseFluxSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -107,3 +111,6 @@ def after_optimizer_step( if model.embedding_wrapper_2 is not None: model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, FluxEmbeddingSetup, ModelType.FLUX_DEV_1, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, FluxEmbeddingSetup, ModelType.FLUX_FILL_DEV_1, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/FluxFineTuneSetup.py b/modules/modelSetup/FluxFineTuneSetup.py index 55ce6e4e7..8e09a9894 100644 --- a/modules/modelSetup/FluxFineTuneSetup.py +++ b/modules/modelSetup/FluxFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.FluxModel import FluxModel from modules.modelSetup.BaseFluxSetup import BaseFluxSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -135,3 +139,6 @@ def after_optimizer_step( if model.embedding_wrapper_2 is not None: model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, FluxFineTuneSetup, ModelType.FLUX_DEV_1, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, FluxFineTuneSetup, ModelType.FLUX_FILL_DEV_1, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/FluxLoRASetup.py b/modules/modelSetup/FluxLoRASetup.py index 0b58af8f4..2c9b2b12f 100644 --- a/modules/modelSetup/FluxLoRASetup.py +++ b/modules/modelSetup/FluxLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.FluxModel import FluxModel from 
modules.modelSetup.BaseFluxSetup import BaseFluxSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -177,3 +181,6 @@ def after_optimizer_step( if model.embedding_wrapper_2 is not None: model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, FluxLoRASetup, ModelType.FLUX_DEV_1, TrainingMethod.LORA) +factory.register(BaseModelSetup, FluxLoRASetup, ModelType.FLUX_FILL_DEV_1, TrainingMethod.LORA) diff --git a/modules/modelSetup/HiDreamEmbeddingSetup.py b/modules/modelSetup/HiDreamEmbeddingSetup.py index bea3304af..beb3b18ee 100644 --- a/modules/modelSetup/HiDreamEmbeddingSetup.py +++ b/modules/modelSetup/HiDreamEmbeddingSetup.py @@ -1,7 +1,10 @@ - from modules.model.HiDreamModel import HiDreamModel from modules.modelSetup.BaseHiDreamSetup import BaseHiDreamSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -137,3 +140,5 @@ def after_optimizer_step( if model.embedding_wrapper_4 is not None: model.embedding_wrapper_4.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, HiDreamEmbeddingSetup, ModelType.HI_DREAM_FULL, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/HiDreamFineTuneSetup.py b/modules/modelSetup/HiDreamFineTuneSetup.py index b5ccf4698..9ed2fa1c8 100644 --- a/modules/modelSetup/HiDreamFineTuneSetup.py +++ b/modules/modelSetup/HiDreamFineTuneSetup.py @@ -2,7 +2,11 @@ from modules.model.HiDreamModel import HiDreamModel from modules.modelSetup.BaseHiDreamSetup import BaseHiDreamSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -187,3 +191,5 @@ def after_optimizer_step( if model.embedding_wrapper_4 is not None: model.embedding_wrapper_4.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, HiDreamFineTuneSetup, ModelType.HI_DREAM_FULL, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/HiDreamLoRASetup.py b/modules/modelSetup/HiDreamLoRASetup.py index 5e4fded42..976df3ee9 100644 --- a/modules/modelSetup/HiDreamLoRASetup.py +++ b/modules/modelSetup/HiDreamLoRASetup.py @@ -2,8 +2,12 @@ from modules.model.HiDreamModel import HiDreamModel from modules.modelSetup.BaseHiDreamSetup import BaseHiDreamSetup +from modules.modelSetup.BaseModelSetup 
import BaseModelSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -259,3 +263,5 @@ def after_optimizer_step( if model.embedding_wrapper_4 is not None: model.embedding_wrapper_4.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, HiDreamLoRASetup, ModelType.HI_DREAM_FULL, TrainingMethod.LORA) diff --git a/modules/modelSetup/HunyuanVideoEmbeddingSetup.py b/modules/modelSetup/HunyuanVideoEmbeddingSetup.py index 99c75d339..355a10cbd 100644 --- a/modules/modelSetup/HunyuanVideoEmbeddingSetup.py +++ b/modules/modelSetup/HunyuanVideoEmbeddingSetup.py @@ -1,7 +1,10 @@ - from modules.model.HunyuanVideoModel import HunyuanVideoModel from modules.modelSetup.BaseHunyuanVideoSetup import BaseHunyuanVideoSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -106,3 +109,5 @@ def after_optimizer_step( if model.embedding_wrapper_2 is not None: model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, HunyuanVideoEmbeddingSetup, ModelType.HUNYUAN_VIDEO, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/HunyuanVideoFineTuneSetup.py b/modules/modelSetup/HunyuanVideoFineTuneSetup.py index a1057f748..2d81ed795 100644 --- a/modules/modelSetup/HunyuanVideoFineTuneSetup.py +++ b/modules/modelSetup/HunyuanVideoFineTuneSetup.py @@ -1,7 +1,10 @@ - from modules.model.HunyuanVideoModel import HunyuanVideoModel from modules.modelSetup.BaseHunyuanVideoSetup import BaseHunyuanVideoSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -135,3 +138,5 @@ def after_optimizer_step( if model.embedding_wrapper_2 is not None: model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, HunyuanVideoFineTuneSetup, ModelType.HUNYUAN_VIDEO, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/HunyuanVideoLoRASetup.py b/modules/modelSetup/HunyuanVideoLoRASetup.py index 2652ba484..e856b539b 100644 --- a/modules/modelSetup/HunyuanVideoLoRASetup.py +++ b/modules/modelSetup/HunyuanVideoLoRASetup.py @@ -2,8 +2,12 @@ from modules.model.HunyuanVideoModel import HunyuanVideoModel from modules.modelSetup.BaseHunyuanVideoSetup import BaseHunyuanVideoSetup +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.module.LoRAModule 
import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -181,3 +185,5 @@ def after_optimizer_step( model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, HunyuanVideoLoRASetup, ModelType.HUNYUAN_VIDEO, TrainingMethod.LORA) diff --git a/modules/modelSetup/PixArtAlphaEmbeddingSetup.py b/modules/modelSetup/PixArtAlphaEmbeddingSetup.py index b9ebd6c2d..f84e2f105 100644 --- a/modules/modelSetup/PixArtAlphaEmbeddingSetup.py +++ b/modules/modelSetup/PixArtAlphaEmbeddingSetup.py @@ -1,6 +1,10 @@ from modules.model.PixArtAlphaModel import PixArtAlphaModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BasePixArtAlphaSetup import BasePixArtAlphaSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -86,3 +90,6 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, PixArtAlphaEmbeddingSetup, ModelType.PIXART_ALPHA, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, PixArtAlphaEmbeddingSetup, ModelType.PIXART_SIGMA, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/PixArtAlphaFineTuneSetup.py b/modules/modelSetup/PixArtAlphaFineTuneSetup.py index 4ad4911d3..be2aab4c6 100644 --- a/modules/modelSetup/PixArtAlphaFineTuneSetup.py +++ b/modules/modelSetup/PixArtAlphaFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.PixArtAlphaModel import PixArtAlphaModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BasePixArtAlphaSetup import BasePixArtAlphaSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -114,3 +118,6 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, PixArtAlphaFineTuneSetup, ModelType.PIXART_ALPHA, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, PixArtAlphaFineTuneSetup, ModelType.PIXART_SIGMA, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/PixArtAlphaLoRASetup.py b/modules/modelSetup/PixArtAlphaLoRASetup.py index 451a43b13..8570bcfd7 100644 --- a/modules/modelSetup/PixArtAlphaLoRASetup.py +++ b/modules/modelSetup/PixArtAlphaLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.PixArtAlphaModel import PixArtAlphaModel +from 
modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BasePixArtAlphaSetup import BasePixArtAlphaSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -134,3 +138,6 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, PixArtAlphaLoRASetup, ModelType.PIXART_ALPHA, TrainingMethod.LORA) +factory.register(BaseModelSetup, PixArtAlphaLoRASetup, ModelType.PIXART_SIGMA, TrainingMethod.LORA) diff --git a/modules/modelSetup/QwenFineTuneSetup.py b/modules/modelSetup/QwenFineTuneSetup.py index 2d305a1be..43d8962d7 100644 --- a/modules/modelSetup/QwenFineTuneSetup.py +++ b/modules/modelSetup/QwenFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.QwenModel import QwenModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseQwenSetup import BaseQwenSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -93,3 +97,5 @@ def after_optimizer_step( train_progress: TrainProgress ): self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, QwenFineTuneSetup, ModelType.QWEN, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/QwenLoRASetup.py b/modules/modelSetup/QwenLoRASetup.py index 644cd8656..8fcb5fdfb 100644 --- a/modules/modelSetup/QwenLoRASetup.py +++ b/modules/modelSetup/QwenLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.QwenModel import QwenModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseQwenSetup import BaseQwenSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -122,3 +126,5 @@ def after_optimizer_step( train_progress: TrainProgress ): self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, QwenLoRASetup, ModelType.QWEN, TrainingMethod.LORA) diff --git a/modules/modelSetup/SanaEmbeddingSetup.py b/modules/modelSetup/SanaEmbeddingSetup.py index 4d74d66b2..d73242680 100644 --- a/modules/modelSetup/SanaEmbeddingSetup.py +++ b/modules/modelSetup/SanaEmbeddingSetup.py @@ -1,6 +1,10 @@ from modules.model.SanaModel import SanaModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseSanaSetup import BaseSanaSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig 
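Each setup class now ends with a `factory.register(BaseModelSetup, SetupClass, ModelType, TrainingMethod)` call, so construction is driven by a registry instead of centralized dispatch. `modules/util/factory.py` itself is not part of this excerpt; the sketch below shows one plausible shape for such a registry, keyed by base class plus the remaining arguments, and should not be read as the actual implementation.

```python
# Hypothetical sketch of a (base_class, *key) -> implementation registry.
_REGISTRY: dict[tuple, type] = {}


def register(base_class: type, implementation: type, *key) -> None:
    _REGISTRY[(base_class, *key)] = implementation


def create(base_class: type, *key, **kwargs):
    try:
        implementation = _REGISTRY[(base_class, *key)]
    except KeyError:
        raise ValueError(f"no {base_class.__name__} registered for {key}") from None
    return implementation(**kwargs)


# Usage mirroring the calls added in this patch (names assumed from the diff):
# register(BaseModelSetup, FluxLoRASetup, ModelType.FLUX_DEV_1, TrainingMethod.LORA)
# setup = create(BaseModelSetup, ModelType.FLUX_DEV_1, TrainingMethod.LORA,
#                train_device=train_device, temp_device=temp_device, debug_mode=False)
```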
+from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -86,3 +90,5 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, SanaEmbeddingSetup, ModelType.SANA, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/SanaFineTuneSetup.py b/modules/modelSetup/SanaFineTuneSetup.py index 110d442e4..63597c856 100644 --- a/modules/modelSetup/SanaFineTuneSetup.py +++ b/modules/modelSetup/SanaFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.SanaModel import SanaModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseSanaSetup import BaseSanaSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -108,3 +112,5 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, SanaFineTuneSetup, ModelType.SANA, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/SanaLoRASetup.py b/modules/modelSetup/SanaLoRASetup.py index 0fd7c04c0..2bb8d266d 100644 --- a/modules/modelSetup/SanaLoRASetup.py +++ b/modules/modelSetup/SanaLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.SanaModel import SanaModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseSanaSetup import BaseSanaSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -134,3 +138,5 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, SanaLoRASetup, ModelType.SANA, TrainingMethod.LORA) diff --git a/modules/modelSetup/StableDiffusion3EmbeddingSetup.py b/modules/modelSetup/StableDiffusion3EmbeddingSetup.py index 800451a79..01c70ba39 100644 --- a/modules/modelSetup/StableDiffusion3EmbeddingSetup.py +++ b/modules/modelSetup/StableDiffusion3EmbeddingSetup.py @@ -1,6 +1,10 @@ from modules.model.StableDiffusion3Model import StableDiffusion3Model +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusion3Setup import BaseStableDiffusion3Setup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from 
modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -125,3 +129,6 @@ def after_optimizer_step( if model.embedding_wrapper_3 is not None: model.embedding_wrapper_3.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusion3EmbeddingSetup, ModelType.STABLE_DIFFUSION_3, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusion3EmbeddingSetup, ModelType.STABLE_DIFFUSION_35, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/StableDiffusion3FineTuneSetup.py b/modules/modelSetup/StableDiffusion3FineTuneSetup.py index d575a1f25..a3cc7a017 100644 --- a/modules/modelSetup/StableDiffusion3FineTuneSetup.py +++ b/modules/modelSetup/StableDiffusion3FineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.StableDiffusion3Model import StableDiffusion3Model +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusion3Setup import BaseStableDiffusion3Setup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -161,3 +165,6 @@ def after_optimizer_step( if model.embedding_wrapper_3 is not None: model.embedding_wrapper_3.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusion3FineTuneSetup, ModelType.STABLE_DIFFUSION_3, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, StableDiffusion3FineTuneSetup, ModelType.STABLE_DIFFUSION_35, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/StableDiffusion3LoRASetup.py b/modules/modelSetup/StableDiffusion3LoRASetup.py index 7adabac36..86f472c23 100644 --- a/modules/modelSetup/StableDiffusion3LoRASetup.py +++ b/modules/modelSetup/StableDiffusion3LoRASetup.py @@ -1,7 +1,11 @@ from modules.model.StableDiffusion3Model import StableDiffusion3Model +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusion3Setup import BaseStableDiffusion3Setup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -219,3 +223,6 @@ def after_optimizer_step( if model.embedding_wrapper_3 is not None: model.embedding_wrapper_3.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusion3LoRASetup, ModelType.STABLE_DIFFUSION_3, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusion3LoRASetup, ModelType.STABLE_DIFFUSION_35, TrainingMethod.LORA) diff --git a/modules/modelSetup/StableDiffusionEmbeddingSetup.py b/modules/modelSetup/StableDiffusionEmbeddingSetup.py index bf1d88414..218d1a0a5 100644 --- a/modules/modelSetup/StableDiffusionEmbeddingSetup.py +++ b/modules/modelSetup/StableDiffusionEmbeddingSetup.py @@ -1,6 +1,10 
@@ from modules.model.StableDiffusionModel import StableDiffusionModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusionSetup import BaseStableDiffusionSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -91,3 +95,12 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusionEmbeddingSetup, ModelType.STABLE_DIFFUSION_15, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusionEmbeddingSetup, ModelType.STABLE_DIFFUSION_15_INPAINTING, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusionEmbeddingSetup, ModelType.STABLE_DIFFUSION_20, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusionEmbeddingSetup, ModelType.STABLE_DIFFUSION_20_BASE, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusionEmbeddingSetup, ModelType.STABLE_DIFFUSION_20_INPAINTING, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusionEmbeddingSetup, ModelType.STABLE_DIFFUSION_20_DEPTH, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusionEmbeddingSetup, ModelType.STABLE_DIFFUSION_21, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusionEmbeddingSetup, ModelType.STABLE_DIFFUSION_21_BASE, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/StableDiffusionFineTuneSetup.py b/modules/modelSetup/StableDiffusionFineTuneSetup.py index 11dacdd96..bf8374e75 100644 --- a/modules/modelSetup/StableDiffusionFineTuneSetup.py +++ b/modules/modelSetup/StableDiffusionFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.StableDiffusionModel import StableDiffusionModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusionSetup import BaseStableDiffusionSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -117,3 +121,12 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusionFineTuneSetup, ModelType.STABLE_DIFFUSION_15, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, StableDiffusionFineTuneSetup, ModelType.STABLE_DIFFUSION_15_INPAINTING, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, StableDiffusionFineTuneSetup, ModelType.STABLE_DIFFUSION_20, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, StableDiffusionFineTuneSetup, ModelType.STABLE_DIFFUSION_20_BASE, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, StableDiffusionFineTuneSetup, ModelType.STABLE_DIFFUSION_20_INPAINTING, TrainingMethod.FINE_TUNE) 
+factory.register(BaseModelSetup, StableDiffusionFineTuneSetup, ModelType.STABLE_DIFFUSION_20_DEPTH, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, StableDiffusionFineTuneSetup, ModelType.STABLE_DIFFUSION_21, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, StableDiffusionFineTuneSetup, ModelType.STABLE_DIFFUSION_21_BASE, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/StableDiffusionFineTuneVaeSetup.py b/modules/modelSetup/StableDiffusionFineTuneVaeSetup.py index 375d3738e..0833528c5 100644 --- a/modules/modelSetup/StableDiffusionFineTuneVaeSetup.py +++ b/modules/modelSetup/StableDiffusionFineTuneVaeSetup.py @@ -1,6 +1,10 @@ from modules.model.StableDiffusionModel import StableDiffusionModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusionSetup import BaseStableDiffusionSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroup, NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -107,3 +111,12 @@ def after_optimizer_step( train_progress: TrainProgress ): pass + +factory.register(BaseModelSetup, StableDiffusionFineTuneVaeSetup, ModelType.STABLE_DIFFUSION_15, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSetup, StableDiffusionFineTuneVaeSetup, ModelType.STABLE_DIFFUSION_15_INPAINTING, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSetup, StableDiffusionFineTuneVaeSetup, ModelType.STABLE_DIFFUSION_20, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSetup, StableDiffusionFineTuneVaeSetup, ModelType.STABLE_DIFFUSION_20_BASE, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSetup, StableDiffusionFineTuneVaeSetup, ModelType.STABLE_DIFFUSION_20_INPAINTING, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSetup, StableDiffusionFineTuneVaeSetup, ModelType.STABLE_DIFFUSION_20_DEPTH, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSetup, StableDiffusionFineTuneVaeSetup, ModelType.STABLE_DIFFUSION_21, TrainingMethod.FINE_TUNE_VAE) +factory.register(BaseModelSetup, StableDiffusionFineTuneVaeSetup, ModelType.STABLE_DIFFUSION_21_BASE, TrainingMethod.FINE_TUNE_VAE) diff --git a/modules/modelSetup/StableDiffusionLoRASetup.py b/modules/modelSetup/StableDiffusionLoRASetup.py index 4c5ac2924..3bf45980b 100644 --- a/modules/modelSetup/StableDiffusionLoRASetup.py +++ b/modules/modelSetup/StableDiffusionLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.StableDiffusionModel import StableDiffusionModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusionSetup import BaseStableDiffusionSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -140,3 +144,12 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_text_encoder_embeddings()) model.embedding_wrapper.normalize_embeddings() 
self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusionLoRASetup, ModelType.STABLE_DIFFUSION_15, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusionLoRASetup, ModelType.STABLE_DIFFUSION_15_INPAINTING, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusionLoRASetup, ModelType.STABLE_DIFFUSION_20, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusionLoRASetup, ModelType.STABLE_DIFFUSION_20_BASE, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusionLoRASetup, ModelType.STABLE_DIFFUSION_20_INPAINTING, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusionLoRASetup, ModelType.STABLE_DIFFUSION_20_DEPTH, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusionLoRASetup, ModelType.STABLE_DIFFUSION_21, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusionLoRASetup, ModelType.STABLE_DIFFUSION_21_BASE, TrainingMethod.LORA) diff --git a/modules/modelSetup/StableDiffusionXLEmbeddingSetup.py b/modules/modelSetup/StableDiffusionXLEmbeddingSetup.py index 3fddc0e5a..38341043d 100644 --- a/modules/modelSetup/StableDiffusionXLEmbeddingSetup.py +++ b/modules/modelSetup/StableDiffusionXLEmbeddingSetup.py @@ -1,6 +1,10 @@ from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -104,3 +108,6 @@ def after_optimizer_step( model.embedding_wrapper_1.normalize_embeddings() model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusionXLEmbeddingSetup, ModelType.STABLE_DIFFUSION_XL_10_BASE, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, StableDiffusionXLEmbeddingSetup, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/StableDiffusionXLFineTuneSetup.py b/modules/modelSetup/StableDiffusionXLFineTuneSetup.py index 95b1044a5..91a17ce3b 100644 --- a/modules/modelSetup/StableDiffusionXLFineTuneSetup.py +++ b/modules/modelSetup/StableDiffusionXLFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -140,3 +144,6 @@ def after_optimizer_step( model.embedding_wrapper_1.normalize_embeddings() model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusionXLFineTuneSetup, ModelType.STABLE_DIFFUSION_XL_10_BASE, 
TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, StableDiffusionXLFineTuneSetup, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/StableDiffusionXLLoRASetup.py b/modules/modelSetup/StableDiffusionXLLoRASetup.py index 560f2b44d..7c80cbfc5 100644 --- a/modules/modelSetup/StableDiffusionXLLoRASetup.py +++ b/modules/modelSetup/StableDiffusionXLLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -171,3 +175,6 @@ def after_optimizer_step( model.embedding_wrapper_1.normalize_embeddings() model.embedding_wrapper_2.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, StableDiffusionXLLoRASetup, ModelType.STABLE_DIFFUSION_XL_10_BASE, TrainingMethod.LORA) +factory.register(BaseModelSetup, StableDiffusionXLLoRASetup, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING, TrainingMethod.LORA) diff --git a/modules/modelSetup/WuerstchenEmbeddingSetup.py b/modules/modelSetup/WuerstchenEmbeddingSetup.py index 06df76486..fd511d1b2 100644 --- a/modules/modelSetup/WuerstchenEmbeddingSetup.py +++ b/modules/modelSetup/WuerstchenEmbeddingSetup.py @@ -1,6 +1,10 @@ from modules.model.WuerstchenModel import WuerstchenModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseWuerstchenSetup import BaseWuerstchenSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -100,3 +104,6 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_prior_text_encoder_embeddings()) model.prior_embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, WuerstchenEmbeddingSetup, ModelType.WUERSTCHEN_2, TrainingMethod.EMBEDDING) +factory.register(BaseModelSetup, WuerstchenEmbeddingSetup, ModelType.STABLE_CASCADE_1, TrainingMethod.EMBEDDING) diff --git a/modules/modelSetup/WuerstchenFineTuneSetup.py b/modules/modelSetup/WuerstchenFineTuneSetup.py index 0f178db97..35631bc30 100644 --- a/modules/modelSetup/WuerstchenFineTuneSetup.py +++ b/modules/modelSetup/WuerstchenFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.WuerstchenModel import WuerstchenModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseWuerstchenSetup import BaseWuerstchenSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import 
ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -121,3 +125,6 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_prior_text_encoder_embeddings()) model.prior_embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, WuerstchenFineTuneSetup, ModelType.WUERSTCHEN_2, TrainingMethod.FINE_TUNE) +factory.register(BaseModelSetup, WuerstchenFineTuneSetup, ModelType.STABLE_CASCADE_1, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/WuerstchenLoRASetup.py b/modules/modelSetup/WuerstchenLoRASetup.py index 90db8b9f6..756475295 100644 --- a/modules/modelSetup/WuerstchenLoRASetup.py +++ b/modules/modelSetup/WuerstchenLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.WuerstchenModel import WuerstchenModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseWuerstchenSetup import BaseWuerstchenSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.torch_util import state_dict_has_prefix @@ -148,3 +152,6 @@ def after_optimizer_step( self._normalize_output_embeddings(model.all_prior_text_encoder_embeddings()) model.prior_embedding_wrapper.normalize_embeddings() self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, WuerstchenLoRASetup, ModelType.WUERSTCHEN_2, TrainingMethod.LORA) +factory.register(BaseModelSetup, WuerstchenLoRASetup, ModelType.STABLE_CASCADE_1, TrainingMethod.LORA) diff --git a/modules/modelSetup/ZImageFineTuneSetup.py b/modules/modelSetup/ZImageFineTuneSetup.py index b90770cd0..94dc885d6 100644 --- a/modules/modelSetup/ZImageFineTuneSetup.py +++ b/modules/modelSetup/ZImageFineTuneSetup.py @@ -1,6 +1,10 @@ from modules.model.ZImageModel import ZImageModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseZImageSetup import BaseZImageSetup +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ModuleFilter import ModuleFilter from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters @@ -80,3 +84,5 @@ def after_optimizer_step( train_progress: TrainProgress ): self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, ZImageFineTuneSetup, ModelType.Z_IMAGE, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/ZImageLoRASetup.py b/modules/modelSetup/ZImageLoRASetup.py index 4ce3974ea..9b357cd38 100644 --- a/modules/modelSetup/ZImageLoRASetup.py +++ b/modules/modelSetup/ZImageLoRASetup.py @@ -1,7 +1,11 @@ from modules.model.ZImageModel import ZImageModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.BaseZImageSetup import BaseZImageSetup from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from 
modules.util.enum.TrainingMethod import TrainingMethod from modules.util.NamedParameterGroup import NamedParameterGroupCollection from modules.util.optimizer_util import init_model_parameters from modules.util.TrainProgress import TrainProgress @@ -93,3 +97,5 @@ def after_optimizer_step( train_progress: TrainProgress ): self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, ZImageLoRASetup, ModelType.Z_IMAGE, TrainingMethod.LORA) diff --git a/modules/modelSetup/mixin/ModelSetupDebugMixin.py b/modules/modelSetup/mixin/ModelSetupDebugMixin.py index 82970cd94..958cc4d2e 100644 --- a/modules/modelSetup/mixin/ModelSetupDebugMixin.py +++ b/modules/modelSetup/mixin/ModelSetupDebugMixin.py @@ -1,6 +1,9 @@ import os from abc import ABCMeta +from modules.util.config.TrainConfig import TrainConfig +from modules.util.TrainProgress import TrainProgress + import torch from torch import Tensor from torchvision import transforms @@ -86,3 +89,11 @@ def _project_latent_to_image(self, latent_tensor: Tensor): result_max = result.max() result = (result - result_min) / (result_max - result_min) return result * 2 - 1 + + def _save_latent(self, name: str, latent: Tensor, config: TrainConfig, train_progress: TrainProgress): + directory = config.debug_dir + "/training_batches" + self._save_image(self._project_latent_to_image(latent), directory, name, train_progress.global_step) + + def _save_tokens(self, name: str, tokens: Tensor, tokenizer, config: TrainConfig, train_progress: TrainProgress): + directory = config.debug_dir + "/training_batches" + self._save_text(self._decode_tokens(tokens, tokenizer), directory, name, train_progress.global_step) diff --git a/modules/modelSetup/mixin/ModelSetupText2ImageMixin.py b/modules/modelSetup/mixin/ModelSetupText2ImageMixin.py new file mode 100644 index 000000000..75d17e69b --- /dev/null +++ b/modules/modelSetup/mixin/ModelSetupText2ImageMixin.py @@ -0,0 +1,23 @@ +from abc import ABCMeta, abstractmethod + +from modules.model.BaseModel import BaseModel +from modules.util.config.TrainConfig import TrainConfig + + +class ModelSetupText2ImageMixin(metaclass=ABCMeta): + @abstractmethod + def prepare_text_caching(model: BaseModel, config: TrainConfig): + pass + + #for future use in samplers etc. 
+ '''@abstractmethod + def prepare_training(model: BaseModel): + pass + + @abstractmethod + def prepare_text_inference(model: BaseModel): + pass + + @abstractmethod + def prepare_image_inference(model: BaseModel): + pass''' diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index b73745879..3da6342ca 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -164,8 +164,8 @@ def torch_backward(a, b): run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8") run_benchmark(lambda: mm_8bit(y_8, w_8), "triton mm backward int8") - run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int", compile=True) - run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int", compile=True) + run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale, bias=None, compute_dtype=torch.bfloat16), "torch forward int", compile=True) + run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale, bias=None, compute_dtype=torch.bfloat16), "triton backward int", compile=True) @torch.no_grad() diff --git a/modules/trainer/BaseTrainer.py b/modules/trainer/BaseTrainer.py index bef29dd78..20fc5eb7a 100644 --- a/modules/trainer/BaseTrainer.py +++ b/modules/trainer/BaseTrainer.py @@ -57,12 +57,13 @@ def create_model_setup(self) -> BaseModelSetup: self.config.debug_mode, ) - def create_data_loader(self, model: BaseModel, train_progress: TrainProgress, is_validation=False): + def create_data_loader(self, model: BaseModel, model_setup: BaseModelSetup, train_progress: TrainProgress, is_validation=False): return create.create_data_loader( self.train_device, self.temp_device, model, self.config.model_type, + model_setup, self.config.training_method, self.config, train_progress, diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index 13d1a7607..0607fe349 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -146,7 +146,7 @@ def start(self): self.callbacks.on_update_status("creating the data loader/caching") self.data_loader = self.create_data_loader( - self.model, self.model.train_progress + self.model, self.model_setup, self.model.train_progress ) self.model_saver = self.create_model_saver() diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 278369e2b..7d52d1748 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -1,21 +1,11 @@ from pathlib import Path -from modules.modelSetup.BaseChromaSetup import PRESETS as chroma_presets -from modules.modelSetup.BaseFluxSetup import PRESETS as flux_presets -from modules.modelSetup.BaseHiDreamSetup import PRESETS as hidream_presets -from modules.modelSetup.BaseHunyuanVideoSetup import PRESETS as hunyuan_video_presets -from modules.modelSetup.BasePixArtAlphaSetup import PRESETS as pixart_presets -from modules.modelSetup.BaseQwenSetup import PRESETS as qwen_presets -from modules.modelSetup.BaseSanaSetup import PRESETS as sana_presets -from modules.modelSetup.BaseStableDiffusion3Setup import PRESETS as sd3_presets -from modules.modelSetup.BaseStableDiffusionSetup import PRESETS as sd_presets -from modules.modelSetup.BaseStableDiffusionXLSetup import PRESETS as sdxl_presets -from modules.modelSetup.BaseWuerstchenSetup import PRESETS as sc_presets -from modules.modelSetup.BaseZImageSetup import PRESETS as z_image_presets +from modules.util import create from modules.util.config.TrainConfig import TrainConfig from modules.util.enum.ConfigPart import 
ConfigPart from modules.util.enum.DataType import DataType from modules.util.enum.ModelFormat import ModelFormat +from modules.util.enum.ModelType import PeftType from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ui import components from modules.util.ui.UIState import UIState @@ -66,8 +56,10 @@ def refresh_ui(self): self.__setup_wuerstchen_ui(base_frame) elif self.train_config.model_type.is_pixart(): self.__setup_pixart_alpha_ui(base_frame) - elif self.train_config.model_type.is_flux(): + elif self.train_config.model_type.is_flux_1(): self.__setup_flux_ui(base_frame) + elif self.train_config.model_type.is_flux_2(): + self.__setup_flux_2_ui(base_frame) elif self.train_config.model_type.is_z_image(): self.__setup_z_image_ui(base_frame) elif self.train_config.model_type.is_chroma(): @@ -142,6 +134,26 @@ def __setup_flux_ui(self, frame): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) + def __setup_flux_2_ui(self, frame): + row = 0 + row = self.__create_base_dtype_components(frame, row) + row = self.__create_base_components( + frame, + row, + has_transformer=True, + allow_override_transformer=True, + has_text_encoder_1=True, + has_vae=True, + ) + row = self.__create_output_components( + frame, + row, + allow_safetensors=True, + allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, + allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, + allow_comfy=self.train_config.training_method == TrainingMethod.LORA and self.train_config.peft_type == PeftType.LORA, + ) + def __setup_z_image_ui(self, frame): row = 0 row = self.__create_base_dtype_components(frame, row) @@ -431,33 +443,8 @@ def __create_base_components( row += 1 - presets = [] - if self.train_config.model_type.is_stable_diffusion(): #TODO simplify and de-duplicate with layer filter on training tab - presets = sd_presets - elif self.train_config.model_type.is_stable_diffusion_xl(): - presets = sdxl_presets - elif self.train_config.model_type.is_stable_diffusion_3(): - presets = sd3_presets - elif self.train_config.model_type.is_wuerstchen(): - presets = sc_presets - elif self.train_config.model_type.is_pixart(): - presets = pixart_presets - elif self.train_config.model_type.is_flux(): - presets = flux_presets - elif self.train_config.model_type.is_qwen(): - presets = qwen_presets - elif self.train_config.model_type.is_chroma(): - presets = chroma_presets - elif self.train_config.model_type.is_sana(): - presets = sana_presets - elif self.train_config.model_type.is_hunyuan_video(): - presets = hunyuan_video_presets - elif self.train_config.model_type.is_z_image(): - presets = z_image_presets - elif self.train_config.model_type.is_hi_dream(): - presets = hidream_presets - else: - presets = {"full": []} + cls = create.get_model_setup_class(self.train_config.model_type, self.train_config.training_method) + presets = cls.LAYER_PRESETS if cls is not None else {"full": []} components.label(frame, row, 0, "Quantization") components.layer_filter_entry(frame, row, 1, self.ui_state, @@ -626,6 +613,7 @@ def __create_output_components( allow_safetensors: bool = False, allow_diffusers: bool = False, allow_legacy_safetensors: bool = False, + allow_comfy: bool = False, ) -> int: # output model destination components.label(frame, row, 0, "Model Output Destination", @@ -653,6 +641,8 @@ def __create_output_components( formats.append(("Diffusers", ModelFormat.DIFFUSERS)) # if allow_legacy_safetensors: # formats.append(("Legacy Safetensors", 
ModelFormat.LEGACY_SAFETENSORS)) + if allow_comfy: + formats.append(("Comfy", ModelFormat.COMFY_LORA)) components.label(frame, row, 0, "Output Format", tooltip="Format to use when saving the output model") diff --git a/modules/ui/TopBar.py b/modules/ui/TopBar.py index c53ea4160..00c211114 100644 --- a/modules/ui/TopBar.py +++ b/modules/ui/TopBar.py @@ -92,8 +92,9 @@ def __init__( ("Stable Cascade", ModelType.STABLE_CASCADE_1), ("PixArt Alpha", ModelType.PIXART_ALPHA), ("PixArt Sigma", ModelType.PIXART_SIGMA), - ("Flux Dev", ModelType.FLUX_DEV_1), + ("Flux Dev.1", ModelType.FLUX_DEV_1), ("Flux Fill Dev", ModelType.FLUX_FILL_DEV_1), + ("Flux Dev.2", ModelType.FLUX_DEV_2), ("Sana", ModelType.SANA), ("Hunyuan Video", ModelType.HUNYUAN_VIDEO), ("HiDream Full", ModelType.HI_DREAM_FULL), diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index 9cc2bcec8..2c564e8d3 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -1,19 +1,8 @@ -from modules.modelSetup.BaseChromaSetup import PRESETS as chroma_presets -from modules.modelSetup.BaseFluxSetup import PRESETS as flux_presets -from modules.modelSetup.BaseHiDreamSetup import PRESETS as hidream_presets -from modules.modelSetup.BaseHunyuanVideoSetup import PRESETS as hunyuan_video_presets -from modules.modelSetup.BasePixArtAlphaSetup import PRESETS as pixart_presets -from modules.modelSetup.BaseQwenSetup import PRESETS as qwen_presets -from modules.modelSetup.BaseSanaSetup import PRESETS as sana_presets -from modules.modelSetup.BaseStableDiffusion3Setup import PRESETS as sd3_presets -from modules.modelSetup.BaseStableDiffusionSetup import PRESETS as sd_presets -from modules.modelSetup.BaseStableDiffusionXLSetup import PRESETS as sdxl_presets -from modules.modelSetup.BaseWuerstchenSetup import PRESETS as sc_presets -from modules.modelSetup.BaseZImageSetup import PRESETS as z_image_presets from modules.ui.OffloadingWindow import OffloadingWindow from modules.ui.OptimizerParamsWindow import OptimizerParamsWindow from modules.ui.SchedulerParamsWindow import SchedulerParamsWindow from modules.ui.TimestepDistributionWindow import TimestepDistributionWindow +from modules.util import create from modules.util.config.TrainConfig import TrainConfig from modules.util.enum.DataType import DataType from modules.util.enum.EMAMode import EMAMode @@ -80,8 +69,10 @@ def refresh_ui(self): self.__setup_wuerstchen_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_pixart(): self.__setup_pixart_alpha_ui(column_0, column_1, column_2) - elif self.train_config.model_type.is_flux(): + elif self.train_config.model_type.is_flux_1(): self.__setup_flux_ui(column_0, column_1, column_2) + elif self.train_config.model_type.is_flux_2(): + self.__setup_flux_2_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_chroma(): self.__setup_chroma_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_qwen(): @@ -178,6 +169,18 @@ def __setup_flux_ui(self, column_0, column_1, column_2): self.__create_loss_frame(column_2, 2) self.__create_layer_frame(column_2, 3) + def __setup_flux_2_ui(self, column_0, column_1, column_2): + self.__create_base_frame(column_0, 0) + self.__create_text_encoder_frame(column_0, 1, supports_clip_skip=False, supports_training=False, supports_sequence_length=True) + + self.__create_base2_frame(column_1, 0) + self.__create_transformer_frame(column_1, 1, supports_guidance_scale=True, supports_force_attention_mask=False) + self.__create_noise_frame(column_1, 2, 
supports_dynamic_timestep_shifting=True) + + self.__create_masked_frame(column_2, 1) + self.__create_loss_frame(column_2, 2) + self.__create_layer_frame(column_2, 3) + def __setup_chroma_ui(self, column_0, column_1, column_2): self.__create_base_frame(column_0, 0) self.__create_text_encoder_frame(column_0, 1) @@ -411,12 +414,11 @@ def __create_base2_frame(self, master, row, video_training_enabled: bool = False tooltip="Enables circular padding for all conv layers to better train seamless images") components.switch(frame, row, 1, self.ui_state, "force_circular_padding") - def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supports_training=True): + def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supports_training=True, supports_sequence_length=False): frame = ctk.CTkFrame(master=master, corner_radius=5) frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew") frame.grid_columnconfigure(0, weight=1) - # train text encoder if supports_training: components.label(frame, 0, 0, "Train Text Encoder", tooltip="Enables training the text encoder model") @@ -445,6 +447,13 @@ def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supp tooltip="The number of additional clip layers to skip. 0 = the model default") components.entry(frame, 4, 1, self.ui_state, "text_encoder_layer_skip") + if supports_sequence_length: + # text encoder sequence length + components.label(frame, row, 0, "Text Encoder Sequence Length", + tooltip="Number of tokens for captions") + components.entry(frame, row, 1, self.ui_state, "text_encoder_sequence_length") + row += 1 + def __create_text_encoder_n_frame( self, master, @@ -766,33 +775,8 @@ def __create_loss_frame(self, master, row, supports_vb_loss: bool = False): row += 1 def __create_layer_frame(self, master, row): - presets = [] - if self.train_config.model_type.is_stable_diffusion(): #TODO simplify - presets = sd_presets - elif self.train_config.model_type.is_stable_diffusion_xl(): - presets = sdxl_presets - elif self.train_config.model_type.is_stable_diffusion_3(): - presets = sd3_presets - elif self.train_config.model_type.is_wuerstchen(): - presets = sc_presets - elif self.train_config.model_type.is_pixart(): - presets = pixart_presets - elif self.train_config.model_type.is_flux(): - presets = flux_presets - elif self.train_config.model_type.is_qwen(): - presets = qwen_presets - elif self.train_config.model_type.is_chroma(): - presets = chroma_presets - elif self.train_config.model_type.is_sana(): - presets = sana_presets - elif self.train_config.model_type.is_hunyuan_video(): - presets = hunyuan_video_presets - elif self.train_config.model_type.is_hi_dream(): - presets = hidream_presets - elif self.train_config.model_type.is_z_image(): - presets = z_image_presets - else: - presets = {"full": []} + cls = create.get_model_setup_class(self.train_config.model_type, self.train_config.training_method) + presets = cls.LAYER_PRESETS if cls is not None else {"full": []} components.layer_filter_entry(master, row, 0, self.ui_state, preset_var_name="layer_filter_preset", presets=presets, preset_label="Layer Filter", diff --git a/modules/util/checkpointing_util.py b/modules/util/checkpointing_util.py index 133f97cf0..6ca75d8a3 100644 --- a/modules/util/checkpointing_util.py +++ b/modules/util/checkpointing_util.py @@ -25,6 +25,7 @@ from transformers.models.clip.modeling_clip import CLIPEncoderLayer from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer from 
transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLDecoderLayer from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer from transformers.models.t5.modeling_t5 import T5Block @@ -111,7 +112,6 @@ def __init__(self, orig_module: nn.Module, orig_forward, train_device: torch.dev self.layer_index = layer_index def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): - if self.layer_index == 0 and not torch.is_grad_enabled(): self.conductor.start_forward(True) @@ -131,7 +131,6 @@ def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): def forward(self, *args, **kwargs): call_id = _generate_call_index() args = _kwargs_to_args(self.orig_forward if self.checkpoint is None else self.checkpoint.forward, args, kwargs) - if torch.is_grad_enabled(): return torch.utils.checkpoint.checkpoint( self.__checkpointing_forward, @@ -306,6 +305,16 @@ def enable_checkpointing_for_llama_encoder_layers( (LlamaDecoderLayer, []), ]) +def enable_checkpointing_for_mistral_encoder_layers( + model: nn.Module, + config: TrainConfig, +) -> LayerOffloadConductor: + return enable_checkpointing(model, config, False, [ + (MistralDecoderLayer, []), + ]) + + + def enable_checkpointing_for_qwen_encoder_layers( model: nn.Module, config: TrainConfig, @@ -339,6 +348,15 @@ def enable_checkpointing_for_flux_transformer( (model.single_transformer_blocks, ["hidden_states" ]), ]) +def enable_checkpointing_for_flux2_transformer( + model: nn.Module, + config: TrainConfig, +) -> LayerOffloadConductor: + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + (model.single_transformer_blocks, ["hidden_states" ]), + ]) + def enable_checkpointing_for_chroma_transformer( model: nn.Module, diff --git a/modules/util/config/SampleConfig.py b/modules/util/config/SampleConfig.py index 4b72345f8..1b2aba652 100644 --- a/modules/util/config/SampleConfig.py +++ b/modules/util/config/SampleConfig.py @@ -19,6 +19,7 @@ class SampleConfig(BaseConfig): noise_scheduler: NoiseScheduler text_encoder_1_layer_skip: int + text_encoder_1_sequence_length: int | None text_encoder_2_layer_skip: int text_encoder_2_sequence_length: int | None text_encoder_3_layer_skip: int @@ -35,6 +36,7 @@ def __init__(self, data: list[(str, Any, type, bool)]): def from_train_config(self, train_config): self.text_encoder_1_layer_skip = train_config.text_encoder_layer_skip + self.text_encoder_1_sequence_length = train_config.text_encoder_sequence_length self.text_encoder_2_layer_skip = train_config.text_encoder_2_layer_skip self.text_encoder_2_sequence_length = train_config.text_encoder_2_sequence_length self.text_encoder_3_layer_skip = train_config.text_encoder_3_layer_skip @@ -60,6 +62,7 @@ def default_values(): data.append(("noise_scheduler", NoiseScheduler.DDIM, NoiseScheduler, False)) data.append(("text_encoder_1_layer_skip", 0, int, False)) + data.append(("text_encoder_1_sequence_length", None, int, True)) data.append(("text_encoder_2_layer_skip", 0, int, False)) data.append(("text_encoder_2_sequence_length", None, int, True)) data.append(("text_encoder_3_layer_skip", 0, int, False)) diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index ddaee4b89..3a24a10df 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ 
-273,7 +273,7 @@ class TrainModelPartConfig(BaseConfig): stop_training_after_unit: TimeUnit learning_rate: float weight_dtype: DataType - dropout_probability: float + dropout_probability: float #this is text encoder caption dropout! train_embedding: bool attention_mask: bool guidance_scale: float @@ -430,7 +430,7 @@ class TrainConfig(BaseConfig): vb_loss_strength: float loss_weight_fn: LossWeight loss_weight_strength: float - dropout_probability: float + dropout_probability: float #this is LoRA dropout! loss_scaler: LossScaler learning_rate_scaler: LearningRateScaler clip_grad_norm: float @@ -872,6 +872,12 @@ def train_text_encoder_4_or_embedding(self) -> bool: or ((self.text_encoder_4.train_embedding or not self.model_type.has_multiple_text_encoders()) and self.train_any_embedding()) + def train_any_text_encoder_or_embedding(self) -> bool: + return (self.train_text_encoder_or_embedding() + or self.train_text_encoder_2_or_embedding() + or self.train_text_encoder_3_or_embedding() + or self.train_text_encoder_4_or_embedding()) + def all_embedding_configs(self): if self.training_method == TrainingMethod.EMBEDDING: return self.additional_embeddings + [self.embedding] @@ -1069,6 +1075,7 @@ def default_values() -> 'TrainConfig': text_encoder.learning_rate = None data.append(("text_encoder", text_encoder, TrainModelPartConfig, False)) data.append(("text_encoder_layer_skip", 0, int, False)) + data.append(("text_encoder_sequence_length", 512, int, True)) # text encoder 2 text_encoder_2 = TrainModelPartConfig.default_values() diff --git a/modules/util/convert_util.py b/modules/util/convert_util.py new file mode 100644 index 000000000..033ae9c71 --- /dev/null +++ b/modules/util/convert_util.py @@ -0,0 +1,294 @@ +from collections.abc import Callable +from dataclasses import dataclass + +import torch + +import parse + + +@dataclass +class ConversionPattern: + from_patterns: list[str] + to_patterns: list[str] + convert_fn: Callable | None + reverse_convert_fn: Callable | None + children : list["ConversionPattern"] + + +def _convert_item(in_key: str, input: dict, conversions: list[ConversionPattern], in_prefix: str="", out_prefix: str="", in_separator='.', out_separator='.'): + for conversion in conversions: + if conversion.children: + if len(conversion.from_patterns) > 1: + raise RuntimeError("Only leafs can have multiple from-patterns") + if len(conversion.to_patterns) > 1: + raise RuntimeError("Only leafs can have multiple to-patterns") + + match = parse.parse(in_prefix + conversion.from_patterns[0] + in_separator + "{post__}", in_key) + if match is None: + continue + child_in_prefix = in_prefix + conversion.from_patterns[0].format(*match.fixed, **match.named) + in_separator + child_out_prefix = out_prefix + conversion.to_patterns[0].format(*match.fixed, **match.named) + out_separator + return _convert_item(in_key, input, conversion.children, in_prefix=child_in_prefix, out_prefix=child_out_prefix, in_separator=in_separator, out_separator=out_separator) + else: + for pattern in conversion.from_patterns: + match = parse.parse(in_prefix + pattern, in_key) + if match is not None: + break + + if match is None: + for pattern in conversion.from_patterns: + match = parse.parse(in_prefix + pattern + in_separator + "{post__}", in_key) + if match is not None: + break + if match is None: + continue + in_postfix = in_separator + match.named['post__'] + out_postfix = out_separator + match.named['post__'] + else: + in_postfix = "" + out_postfix = "" + + in_keys = [] + in_values = [] + try: + for pattern in 
conversion.from_patterns: + new_in_key = in_prefix + pattern.format(*match.fixed, **match.named) + in_postfix + in_keys.append(new_in_key) + in_values.append(input[new_in_key]) + except KeyError: + #not a match, because not all from_patterns were found: + continue + + out_keys = [out_prefix + pattern.format(*match.fixed, **match.named) + out_postfix for pattern in conversion.to_patterns] + if conversion.convert_fn is not None: + out_values = conversion.convert_fn(*in_values) + if not isinstance(out_values, tuple): + out_values = (out_values, ) + + if len(out_values) != len(out_keys): + raise RuntimeError("convert_fn returned invalid number of outputs, for key " + in_key) + return in_keys, dict(zip(out_keys, out_values, strict=True)) + else: + if len(out_keys) > 1: + raise RuntimeError("A convert_fn must be provided if there are multiple to-patterns") + if len(in_keys) > 1: + raise RuntimeError("A convert_fn must be provided if there are multiple in-patterns") + return in_keys, { + out_keys[0]: in_values[0], + } + + return [in_key], None + +def _is_conversion_pattern_list(conversions: list): + return all(isinstance(entry, ConversionPattern) for entry in conversions) + +def _is_tuple_list(input: list): + return isinstance(input, list) and all(isinstance(entry, tuple) for entry in input) + +def _create_conversions_list(conversion_input: list): + if _is_tuple_list(conversion_input): + conversion_input = [conversion_input] + output = [] + for entry in conversion_input: + if _is_tuple_list(entry): + entry = _create_conversion_from_tuple_list(entry) + if _is_conversion_pattern_list(entry): + output.append(entry) + else: + raise RuntimeError("conversion input is invalid") + return output + + +def convert(input_orig: dict, conversion_input: list[ConversionPattern] | list, strict: bool=True, in_separator='.', out_separator='.'): + conversions_list = _create_conversions_list(conversion_input) + + input = input_orig.copy() + for conversions in conversions_list: + output = {} + while len(input) > 0: + in_key = next(iter(input)) + input_keys, output_items = _convert_item(in_key, input, conversions, in_separator=in_separator, out_separator=out_separator) + if output_items is None: + if strict: + raise RuntimeError("No conversion found for key ", + in_key) + if in_key in output and not output[in_key].equal(input[in_key]): + raise RuntimeError(f"key {in_key} was generated twice during conversion and is not equal") + output[in_key] = input[in_key] + else: + for k, v in output_items.items(): + if k in output and not torch.equal(v, output[k]): + raise RuntimeError(f"key {k} was generated twice during conversion and is not equal") + + output |= output_items + for k in input_keys: + input.pop(k) + + assert len(input) == 0 + input = output + + return output + + +def reverse_conversion_pattern(input: ConversionPattern): + if input.convert_fn is not None and input.reverse_convert_fn is None: + raise RuntimeError("Conversion cannot be reversed: no reverse_convert_fn defined") + + return ConversionPattern( + from_patterns=input.to_patterns, + to_patterns=input.from_patterns, + convert_fn=input.reverse_convert_fn, + reverse_convert_fn=input.convert_fn, + children=reverse_conversion(input.children), + ) + +def reverse_conversion(input: list[ConversionPattern]): + return [reverse_conversion_pattern(entry) for entry in input] + +def _create_pattern_list(input: str | list[str]): + pattern = input + if isinstance(pattern, str): + pattern = [pattern] + if not isinstance(pattern, list) or any(not isinstance(f, str) for f in 
pattern): + raise ValueError("conversion pattern must either be a string, or a list of strings") + return pattern + + +def _create_conversion_pattern_from_tuple(input: tuple | ConversionPattern): + if isinstance(input, ConversionPattern): + return input + if not isinstance(input, tuple) or len(input) < 2: + raise ValueError("conversion entry must be a tuple of at least 2 items") + + from_patterns = _create_pattern_list(input[0]) + if isinstance(input[1], list) and all(isinstance(entry, tuple) for entry in input[1]): + children_in = input[1] + to_patterns = from_patterns + else: + to_patterns = _create_pattern_list(input[1]) + children_in = input[2] if len(input) > 2 and isinstance(input[2], list) else None + + convert_fn = None + reverse_convert_fn = None + children = None + if children_in is not None: + children = _create_conversion_from_tuple_list(children_in) + elif len(input) > 2: + convert_fn = input[2] + reverse_convert_fn = input[3] if len(input) > 3 else None + + if (len(from_patterns) > 1 or len(to_patterns) > 1) and convert_fn is None: + raise ValueError("conversion entries with more than one to- or from-pattern require a convert function") + + return ConversionPattern(from_patterns, to_patterns, convert_fn, reverse_convert_fn, children) + +def _create_conversion_from_tuple_list(input: list): + return [_create_conversion_pattern_from_tuple(entry) for entry in input] + +def fuse_qkv(q, k, v): + return torch.cat([q, k, v], dim=0) + +def fuse_qkv_mlp(q, k, v, mlp): + return torch.cat([q, k, v, mlp], dim=0) + + +def remove_prefix(prefix: str | None = None): + if prefix is None: + prefix = "prefix__" + return [("{" + prefix + "}.{key}", "{key}")] + +def add_prefix(prefix: str, separator='.'): + return [("{}", prefix + separator + "{}")] + +def lora_fuse_qkv(q_up, q_down, q_alpha, k_up, k_down, k_alpha, v_up, v_down, v_alpha): + dim, rank = q_up.shape + qkv_up = torch.zeros( + 3 * dim, + 3 * rank, + device=q_up.device, + dtype=q_up.dtype, + ) + qkv_up[dim*0:dim*1, rank*0:rank*1] = q_up + qkv_up[dim*1:dim*2, rank*1:rank*2] = k_up + qkv_up[dim*2:dim*3, rank*2:rank*3] = v_up + qkv_down = torch.cat([q_down, k_down, v_down], dim=0) + + qkv_alpha = q_alpha * 3 + if q_alpha != k_alpha or q_alpha != v_alpha: + raise NotImplementedError("fused layers must have the same alpha") + + return qkv_up, qkv_down, qkv_alpha + +def lora_fuse_qkv_mlp(q_up, q_down, q_alpha, k_up, k_down, k_alpha, v_up, v_down, v_alpha, mlp_up, mlp_down, mlp_alpha): + dim, rank = q_up.shape + mlp_dim = mlp_up.shape[0] + qkv_up = torch.zeros( + 3 * dim + mlp_dim, + 4 * rank, + device=q_up.device, + dtype=q_up.dtype, + ) + qkv_up[dim*0:dim*1, rank*0:rank*1] = q_up + qkv_up[dim*1:dim*2, rank*1:rank*2] = k_up + qkv_up[dim*2:dim*3, rank*2:rank*3] = v_up + qkv_up[dim*3:, rank*3:rank*4] = mlp_up + qkv_down = torch.cat([q_down, k_down, v_down, mlp_down], dim=0) + + qkv_alpha = q_alpha * 4 + if q_alpha != k_alpha or q_alpha != v_alpha or q_alpha != mlp_alpha: + raise NotImplementedError("fused layers must have the same alpha") + + return qkv_up, qkv_down, qkv_alpha + +def lora_fuse_qkv_to_qkv_mlp(q_up, q_down, q_alpha, k_up, k_down, k_alpha, v_up, v_down, v_alpha): + #TODO where to get output shape from, if there is no MLP dim? + raise NotImplementedError + +def lora_fuse_mlp_to_qkv_mlp(mlp_up, mlp_down, mlp_alpha): + #TODO where to get output shape from, if there is no qkv dim? 
+ raise NotImplementedError + +def swap_chunks(input: torch.Tensor, chunks: int=2, dim: int=0) -> torch.Tensor: + chunks = input.chunk(chunks, dim=dim) + return torch.cat(chunks, dim=dim) + +def lora_qkv_fusion(q: str, k: str, v: str, qkv: str): + return [ + ([f"{q}.lora_up.weight", f"{q}.lora_down.weight", f"{q}.alpha", + f"{k}.lora_up.weight", f"{k}.lora_down.weight", f"{k}.alpha", + f"{v}.lora_up.weight", f"{v}.lora_down.weight", f"{v}.alpha"], + [f"{qkv}.lora_up.weight", f"{qkv}.lora_down.weight", f"{qkv}.alpha"], lora_fuse_qkv), + ] + +def lora_qkv_mlp_fusion(q: str, k: str, v: str, mlp: str, qkv_mlp: str, separator: str='.'): + return [ + ([f"{q}.lora_up.weight", f"{q}.lora_down.weight", f"{q}.alpha", + f"{k}.lora_up.weight", f"{k}.lora_down.weight", f"{k}.alpha", + f"{v}.lora_up.weight", f"{v}.lora_down.weight", f"{v}.alpha", + f"{mlp}.lora_up.weight", f"{mlp}.lora_down.weight", f"{mlp}.alpha"], + [f"{qkv_mlp}.lora_up.weight", f"{qkv_mlp}.lora_down.weight", f"{qkv_mlp}.alpha"], lora_fuse_qkv_mlp + ), + + #qkv only, in case there are no mlp layers: + ([f"{q}.lora_up.weight", f"{q}.lora_down.weight", f"{q}.alpha", + f"{k}.lora_up.weight", f"{k}.lora_down.weight", f"{k}.alpha", + f"{v}.lora_up.weight", f"{v}.lora_down.weight", f"{v}.alpha"], + [f"{qkv_mlp}.lora_up.weight", f"{qkv_mlp}.lora_down.weight", f"{qkv_mlp}.alpha"], + lambda q_up, q_down, q_alpha, k_up, k_down, k_alpha, v_up, v_down, v_alpha: lora_fuse_qkv_to_qkv_mlp(q_up, q_down, q_alpha, k_up, k_down, k_alpha, v_up, v_down, v_alpha) + ), + + #mlp only, in case there are no qkv layers: + ([f"{mlp}.lora_up.weight", f"{mlp}.lora_down.weight", f"{mlp}.alpha"], + [f"{qkv_mlp}.lora_up.weight", f"{qkv_mlp}.lora_down.weight", f"{qkv_mlp}.alpha"], + lambda mlp_up, mlp_down, mlp_alpha: lora_fuse_mlp_to_qkv_mlp(mlp_up, mlp_down, mlp_alpha) + ), + ] + +def qkv_fusion(q: str, k: str, v: str, qkv: str, separator: str='.'): + return [ + ([q, k, v], qkv, fuse_qkv) + ] + +def qkv_mlp_fusion(q: str, k: str, v: str, mlp: str, qkv: str, separator: str='.'): + return [ + ([q, k, v, mlp], qkv, fuse_qkv_mlp) + ] diff --git a/modules/util/create.py b/modules/util/create.py index 4f0cd6e13..99cb74e1a 100644 --- a/modules/util/create.py +++ b/modules/util/create.py @@ -4,139 +4,13 @@ import modules.util.multi_gpu_util as multi from modules.dataLoader.BaseDataLoader import BaseDataLoader -from modules.dataLoader.ChromaBaseDataLoader import ChromaBaseDataLoader -from modules.dataLoader.FluxBaseDataLoader import FluxBaseDataLoader -from modules.dataLoader.HiDreamBaseDataLoader import HiDreamBaseDataLoader -from modules.dataLoader.HunyuanVideoBaseDataLoader import HunyuanVideoBaseDataLoader -from modules.dataLoader.PixArtAlphaBaseDataLoader import PixArtAlphaBaseDataLoader -from modules.dataLoader.QwenBaseDataLoader import QwenBaseDataLoader -from modules.dataLoader.SanaBaseDataLoader import SanaBaseDataLoader -from modules.dataLoader.StableDiffusion3BaseDataLoader import StableDiffusion3BaseDataLoader -from modules.dataLoader.StableDiffusionBaseDataLoader import StableDiffusionBaseDataLoader -from modules.dataLoader.StableDiffusionFineTuneVaeDataLoader import StableDiffusionFineTuneVaeDataLoader -from modules.dataLoader.StableDiffusionXLBaseDataLoader import StableDiffusionXLBaseDataLoader -from modules.dataLoader.WuerstchenBaseDataLoader import WuerstchenBaseDataLoader -from modules.dataLoader.ZImageBaseDataLoader import ZImageBaseDataLoader from modules.model.BaseModel import BaseModel from modules.modelLoader.BaseModelLoader import BaseModelLoader 
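The new modules/util/convert_util.py introduced above turns state-dict key conversion into declarative pattern tuples: a set of from-patterns that must all match, the to-patterns they map onto, and an optional convert function. A minimal usage sketch follows; the module path and helper names are taken from the diff, but the keys, shapes and the "blocks.{}.attn" naming are invented for illustration only.

import torch

from modules.util import convert_util

# hypothetical source state dict, keys and shapes invented for this sketch
state_dict = {
    "blocks.0.attn.to_q.weight": torch.randn(4, 4),
    "blocks.0.attn.to_k.weight": torch.randn(4, 4),
    "blocks.0.attn.to_v.weight": torch.randn(4, 4),
    "blocks.0.mlp.weight": torch.randn(8, 4),
}

conversions = [
    # fuse the three attention projections into a single qkv key (concatenated along dim 0)
    *convert_util.qkv_fusion("blocks.{}.attn.to_q", "blocks.{}.attn.to_k",
                             "blocks.{}.attn.to_v", "blocks.{}.attn.qkv"),
    # a plain rename only needs a (from_pattern, to_pattern) tuple
    ("blocks.{}.mlp", "blocks.{}.ff"),
]

converted = convert_util.convert(state_dict, conversions)
# converted now holds "blocks.0.attn.qkv.weight" with shape (12, 4) and "blocks.0.ff.weight";
# with strict=True (the default), any key without a matching pattern raises a RuntimeError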
-from modules.modelLoader.ChromaEmbeddingModelLoader import ChromaEmbeddingModelLoader -from modules.modelLoader.ChromaFineTuneModelLoader import ChromaFineTuneModelLoader -from modules.modelLoader.ChromaLoRAModelLoader import ChromaLoRAModelLoader -from modules.modelLoader.FluxEmbeddingModelLoader import FluxEmbeddingModelLoader -from modules.modelLoader.FluxFineTuneModelLoader import FluxFineTuneModelLoader -from modules.modelLoader.FluxLoRAModelLoader import FluxLoRAModelLoader -from modules.modelLoader.HiDreamEmbeddingModelLoader import HiDreamEmbeddingModelLoader -from modules.modelLoader.HiDreamFineTuneModelLoader import HiDreamFineTuneModelLoader -from modules.modelLoader.HiDreamLoRAModelLoader import HiDreamLoRAModelLoader -from modules.modelLoader.HunyuanVideoEmbeddingModelLoader import HunyuanVideoEmbeddingModelLoader -from modules.modelLoader.HunyuanVideoFineTuneModelLoader import HunyuanVideoFineTuneModelLoader -from modules.modelLoader.HunyuanVideoLoRAModelLoader import HunyuanVideoLoRAModelLoader -from modules.modelLoader.PixArtAlphaEmbeddingModelLoader import PixArtAlphaEmbeddingModelLoader -from modules.modelLoader.PixArtAlphaFineTuneModelLoader import PixArtAlphaFineTuneModelLoader -from modules.modelLoader.PixArtAlphaLoRAModelLoader import PixArtAlphaLoRAModelLoader -from modules.modelLoader.QwenFineTuneModelLoader import QwenFineTuneModelLoader -from modules.modelLoader.QwenLoRAModelLoader import QwenLoRAModelLoader -from modules.modelLoader.SanaEmbeddingModelLoader import SanaEmbeddingModelLoader -from modules.modelLoader.SanaFineTuneModelLoader import SanaFineTuneModelLoader -from modules.modelLoader.SanaLoRAModelLoader import SanaLoRAModelLoader -from modules.modelLoader.StableDiffusion3EmbeddingModelLoader import StableDiffusion3EmbeddingModelLoader -from modules.modelLoader.StableDiffusion3FineTuneModelLoader import StableDiffusion3FineTuneModelLoader -from modules.modelLoader.StableDiffusion3LoRAModelLoader import StableDiffusion3LoRAModelLoader -from modules.modelLoader.StableDiffusionEmbeddingModelLoader import StableDiffusionEmbeddingModelLoader -from modules.modelLoader.StableDiffusionFineTuneModelLoader import StableDiffusionFineTuneModelLoader -from modules.modelLoader.StableDiffusionLoRAModelLoader import StableDiffusionLoRAModelLoader -from modules.modelLoader.StableDiffusionXLEmbeddingModelLoader import StableDiffusionXLEmbeddingModelLoader -from modules.modelLoader.StableDiffusionXLFineTuneModelLoader import StableDiffusionXLFineTuneModelLoader -from modules.modelLoader.StableDiffusionXLLoRAModelLoader import StableDiffusionXLLoRAModelLoader -from modules.modelLoader.WuerstchenEmbeddingModelLoader import WuerstchenEmbeddingModelLoader -from modules.modelLoader.WuerstchenFineTuneModelLoader import WuerstchenFineTuneModelLoader -from modules.modelLoader.WuerstchenLoRAModelLoader import WuerstchenLoRAModelLoader -from modules.modelLoader.ZImageModelLoader import ZImageFineTuneModelLoader, ZImageLoRAModelLoader -from modules.modelSampler import BaseModelSampler -from modules.modelSampler.ChromaSampler import ChromaSampler -from modules.modelSampler.FluxSampler import FluxSampler -from modules.modelSampler.HiDreamSampler import HiDreamSampler -from modules.modelSampler.HunyuanVideoSampler import HunyuanVideoSampler -from modules.modelSampler.PixArtAlphaSampler import PixArtAlphaSampler -from modules.modelSampler.QwenSampler import QwenSampler -from modules.modelSampler.SanaSampler import SanaSampler -from modules.modelSampler.StableDiffusion3Sampler import 
StableDiffusion3Sampler -from modules.modelSampler.StableDiffusionSampler import StableDiffusionSampler -from modules.modelSampler.StableDiffusionVaeSampler import StableDiffusionVaeSampler -from modules.modelSampler.StableDiffusionXLSampler import StableDiffusionXLSampler -from modules.modelSampler.WuerstchenSampler import WuerstchenSampler -from modules.modelSampler.ZImageSampler import ZImageSampler +from modules.modelSampler.BaseModelSampler import BaseModelSampler from modules.modelSaver.BaseModelSaver import BaseModelSaver -from modules.modelSaver.ChromaEmbeddingModelSaver import ChromaEmbeddingModelSaver -from modules.modelSaver.ChromaFineTuneModelSaver import ChromaFineTuneModelSaver -from modules.modelSaver.ChromaLoRAModelSaver import ChromaLoRAModelSaver -from modules.modelSaver.FluxEmbeddingModelSaver import FluxEmbeddingModelSaver -from modules.modelSaver.FluxFineTuneModelSaver import FluxFineTuneModelSaver -from modules.modelSaver.FluxLoRAModelSaver import FluxLoRAModelSaver -from modules.modelSaver.HiDreamEmbeddingModelSaver import HiDreamEmbeddingModelSaver -from modules.modelSaver.HiDreamLoRAModelSaver import HiDreamLoRAModelSaver -from modules.modelSaver.HunyuanVideoEmbeddingModelSaver import HunyuanVideoEmbeddingModelSaver -from modules.modelSaver.HunyuanVideoFineTuneModelSaver import HunyuanVideoFineTuneModelSaver -from modules.modelSaver.HunyuanVideoLoRAModelSaver import HunyuanVideoLoRAModelSaver -from modules.modelSaver.PixArtAlphaEmbeddingModelSaver import PixArtAlphaEmbeddingModelSaver -from modules.modelSaver.PixArtAlphaFineTuneModelSaver import PixArtAlphaFineTuneModelSaver -from modules.modelSaver.PixArtAlphaLoRAModelSaver import PixArtAlphaLoRAModelSaver -from modules.modelSaver.QwenFineTuneModelSaver import QwenFineTuneModelSaver -from modules.modelSaver.QwenLoRAModelSaver import QwenLoRAModelSaver -from modules.modelSaver.SanaEmbeddingModelSaver import SanaEmbeddingModelSaver -from modules.modelSaver.SanaFineTuneModelSaver import SanaFineTuneModelSaver -from modules.modelSaver.SanaLoRAModelSaver import SanaLoRAModelSaver -from modules.modelSaver.StableDiffusion3EmbeddingModelSaver import StableDiffusion3EmbeddingModelSaver -from modules.modelSaver.StableDiffusion3FineTuneModelSaver import StableDiffusion3FineTuneModelSaver -from modules.modelSaver.StableDiffusion3LoRAModelSaver import StableDiffusion3LoRAModelSaver -from modules.modelSaver.StableDiffusionEmbeddingModelSaver import StableDiffusionEmbeddingModelSaver -from modules.modelSaver.StableDiffusionFineTuneModelSaver import StableDiffusionFineTuneModelSaver -from modules.modelSaver.StableDiffusionLoRAModelSaver import StableDiffusionLoRAModelSaver -from modules.modelSaver.StableDiffusionXLEmbeddingModelSaver import StableDiffusionXLEmbeddingModelSaver -from modules.modelSaver.StableDiffusionXLFineTuneModelSaver import StableDiffusionXLFineTuneModelSaver -from modules.modelSaver.StableDiffusionXLLoRAModelSaver import StableDiffusionXLLoRAModelSaver -from modules.modelSaver.WuerstchenEmbeddingModelSaver import WuerstchenEmbeddingModelSaver -from modules.modelSaver.WuerstchenFineTuneModelSaver import WuerstchenFineTuneModelSaver -from modules.modelSaver.WuerstchenLoRAModelSaver import WuerstchenLoRAModelSaver -from modules.modelSaver.ZImageFineTuneModelSaver import ZImageFineTuneModelSaver -from modules.modelSaver.ZImageLoRAModelSaver import ZImageLoRAModelSaver from modules.modelSetup.BaseModelSetup import BaseModelSetup -from modules.modelSetup.ChromaEmbeddingSetup import ChromaEmbeddingSetup -from 
modules.modelSetup.ChromaFineTuneSetup import ChromaFineTuneSetup -from modules.modelSetup.ChromaLoRASetup import ChromaLoRASetup -from modules.modelSetup.FluxEmbeddingSetup import FluxEmbeddingSetup -from modules.modelSetup.FluxFineTuneSetup import FluxFineTuneSetup -from modules.modelSetup.FluxLoRASetup import FluxLoRASetup -from modules.modelSetup.HiDreamEmbeddingSetup import HiDreamEmbeddingSetup -from modules.modelSetup.HiDreamFineTuneSetup import HiDreamFineTuneSetup -from modules.modelSetup.HiDreamLoRASetup import HiDreamLoRASetup -from modules.modelSetup.HunyuanVideoEmbeddingSetup import HunyuanVideoEmbeddingSetup -from modules.modelSetup.HunyuanVideoFineTuneSetup import HunyuanVideoFineTuneSetup -from modules.modelSetup.HunyuanVideoLoRASetup import HunyuanVideoLoRASetup -from modules.modelSetup.PixArtAlphaEmbeddingSetup import PixArtAlphaEmbeddingSetup -from modules.modelSetup.PixArtAlphaFineTuneSetup import PixArtAlphaFineTuneSetup -from modules.modelSetup.PixArtAlphaLoRASetup import PixArtAlphaLoRASetup -from modules.modelSetup.QwenFineTuneSetup import QwenFineTuneSetup -from modules.modelSetup.QwenLoRASetup import QwenLoRASetup -from modules.modelSetup.SanaEmbeddingSetup import SanaEmbeddingSetup -from modules.modelSetup.SanaFineTuneSetup import SanaFineTuneSetup -from modules.modelSetup.SanaLoRASetup import SanaLoRASetup -from modules.modelSetup.StableDiffusion3EmbeddingSetup import StableDiffusion3EmbeddingSetup -from modules.modelSetup.StableDiffusion3FineTuneSetup import StableDiffusion3FineTuneSetup -from modules.modelSetup.StableDiffusion3LoRASetup import StableDiffusion3LoRASetup -from modules.modelSetup.StableDiffusionEmbeddingSetup import StableDiffusionEmbeddingSetup -from modules.modelSetup.StableDiffusionFineTuneSetup import StableDiffusionFineTuneSetup -from modules.modelSetup.StableDiffusionFineTuneVaeSetup import StableDiffusionFineTuneVaeSetup -from modules.modelSetup.StableDiffusionLoRASetup import StableDiffusionLoRASetup -from modules.modelSetup.StableDiffusionXLEmbeddingSetup import StableDiffusionXLEmbeddingSetup -from modules.modelSetup.StableDiffusionXLFineTuneSetup import StableDiffusionXLFineTuneSetup -from modules.modelSetup.StableDiffusionXLLoRASetup import StableDiffusionXLLoRASetup -from modules.modelSetup.WuerstchenEmbeddingSetup import WuerstchenEmbeddingSetup -from modules.modelSetup.WuerstchenFineTuneSetup import WuerstchenFineTuneSetup -from modules.modelSetup.WuerstchenLoRASetup import WuerstchenLoRASetup -from modules.modelSetup.ZImageFineTuneSetup import ZImageFineTuneSetup -from modules.modelSetup.ZImageLoRASetup import ZImageLoRASetup from modules.module.EMAModule import EMAModuleWrapper +from modules.util import factory from modules.util.callbacks.TrainCallbacks import TrainCallbacks from modules.util.commands.TrainCommands import TrainCommands from modules.util.config.TrainConfig import TrainConfig @@ -176,170 +50,32 @@ UniPCMultistepScheduler, ) +factory.import_dir("modules/modelSampler", "modules.modelSampler") +factory.import_dir("modules/modelLoader", "modules.modelLoader") +factory.import_dir("modules/modelSaver", "modules.modelSaver") +factory.import_dir("modules/modelSetup", "modules.modelSetup") +factory.import_dir("modules/dataLoader", "modules.dataLoader") def create_model_loader( model_type: ModelType, training_method: TrainingMethod = TrainingMethod.FINE_TUNE, ) -> BaseModelLoader | None: - match training_method: #TODO simplify - case TrainingMethod.FINE_TUNE: - if model_type.is_stable_diffusion(): - return 
StableDiffusionFineTuneModelLoader() - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLFineTuneModelLoader() - if model_type.is_wuerstchen(): - return WuerstchenFineTuneModelLoader() - if model_type.is_pixart(): - return PixArtAlphaFineTuneModelLoader() - if model_type.is_stable_diffusion_3(): - return StableDiffusion3FineTuneModelLoader() - if model_type.is_flux(): - return FluxFineTuneModelLoader() - if model_type.is_chroma(): - return ChromaFineTuneModelLoader() - if model_type.is_qwen(): - return QwenFineTuneModelLoader() - if model_type.is_z_image(): - return ZImageFineTuneModelLoader() - if model_type.is_sana(): - return SanaFineTuneModelLoader() - if model_type.is_hunyuan_video(): - return HunyuanVideoFineTuneModelLoader() - if model_type.is_hi_dream(): - return HiDreamFineTuneModelLoader() - case TrainingMethod.FINE_TUNE_VAE: - if model_type.is_stable_diffusion(): - return StableDiffusionFineTuneModelLoader() - case TrainingMethod.LORA: - if model_type.is_stable_diffusion(): - return StableDiffusionLoRAModelLoader() - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLLoRAModelLoader() - if model_type.is_wuerstchen(): - return WuerstchenLoRAModelLoader() - if model_type.is_pixart(): - return PixArtAlphaLoRAModelLoader() - if model_type.is_stable_diffusion_3(): - return StableDiffusion3LoRAModelLoader() - if model_type.is_flux(): - return FluxLoRAModelLoader() - if model_type.is_chroma(): - return ChromaLoRAModelLoader() - if model_type.is_qwen(): - return QwenLoRAModelLoader() - if model_type.is_z_image(): - return ZImageLoRAModelLoader() - if model_type.is_sana(): - return SanaLoRAModelLoader() - if model_type.is_hunyuan_video(): - return HunyuanVideoLoRAModelLoader() - if model_type.is_hi_dream(): - return HiDreamLoRAModelLoader() - case TrainingMethod.EMBEDDING: - if model_type.is_stable_diffusion(): - return StableDiffusionEmbeddingModelLoader() - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLEmbeddingModelLoader() - if model_type.is_wuerstchen(): - return WuerstchenEmbeddingModelLoader() - if model_type.is_pixart(): - return PixArtAlphaEmbeddingModelLoader() - if model_type.is_stable_diffusion_3(): - return StableDiffusion3EmbeddingModelLoader() - if model_type.is_flux(): - return FluxEmbeddingModelLoader() - if model_type.is_chroma(): - return ChromaEmbeddingModelLoader() - if model_type.is_sana(): - return SanaEmbeddingModelLoader() - if model_type.is_hunyuan_video(): - return HunyuanVideoEmbeddingModelLoader() - if model_type.is_hi_dream(): - return HiDreamEmbeddingModelLoader() - - return None + cls = factory.get(BaseModelLoader, model_type, training_method) + return cls() if cls is not None else None def create_model_saver( model_type: ModelType, training_method: TrainingMethod = TrainingMethod.FINE_TUNE, ) -> BaseModelSaver | None: - match training_method: - case TrainingMethod.FINE_TUNE: - if model_type.is_stable_diffusion(): - return StableDiffusionFineTuneModelSaver() - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLFineTuneModelSaver() - if model_type.is_wuerstchen(): - return WuerstchenFineTuneModelSaver() - if model_type.is_pixart(): - return PixArtAlphaFineTuneModelSaver() - if model_type.is_stable_diffusion_3(): - return StableDiffusion3FineTuneModelSaver() - if model_type.is_flux(): - return FluxFineTuneModelSaver() - if model_type.is_chroma(): - return ChromaFineTuneModelSaver() - if model_type.is_qwen(): - return QwenFineTuneModelSaver() - if model_type.is_z_image(): - return 
ZImageFineTuneModelSaver() - if model_type.is_sana(): - return SanaFineTuneModelSaver() - if model_type.is_hunyuan_video(): - return HunyuanVideoFineTuneModelSaver() - case TrainingMethod.FINE_TUNE_VAE: - if model_type.is_stable_diffusion(): - return StableDiffusionFineTuneModelSaver() - case TrainingMethod.LORA: - if model_type.is_stable_diffusion(): - return StableDiffusionLoRAModelSaver() - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLLoRAModelSaver() - if model_type.is_wuerstchen(): - return WuerstchenLoRAModelSaver() - if model_type.is_pixart(): - return PixArtAlphaLoRAModelSaver() - if model_type.is_stable_diffusion_3(): - return StableDiffusion3LoRAModelSaver() - if model_type.is_flux(): - return FluxLoRAModelSaver() - if model_type.is_chroma(): - return ChromaLoRAModelSaver() - if model_type.is_qwen(): - return QwenLoRAModelSaver() - if model_type.is_z_image(): - return ZImageLoRAModelSaver() - if model_type.is_sana(): - return SanaLoRAModelSaver() - if model_type.is_hunyuan_video(): - return HunyuanVideoLoRAModelSaver() - if model_type.is_hi_dream(): - return HiDreamLoRAModelSaver() - case TrainingMethod.EMBEDDING: - if model_type.is_stable_diffusion(): - return StableDiffusionEmbeddingModelSaver() - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLEmbeddingModelSaver() - if model_type.is_wuerstchen(): - return WuerstchenEmbeddingModelSaver() - if model_type.is_pixart(): - return PixArtAlphaEmbeddingModelSaver() - if model_type.is_stable_diffusion_3(): - return StableDiffusion3EmbeddingModelSaver() - if model_type.is_flux(): - return FluxEmbeddingModelSaver() - if model_type.is_chroma(): - return ChromaEmbeddingModelSaver() - if model_type.is_sana(): - return SanaEmbeddingModelSaver() - if model_type.is_hunyuan_video(): - return HunyuanVideoEmbeddingModelSaver() - if model_type.is_hi_dream(): - return HiDreamEmbeddingModelSaver() - - return None + cls = factory.get(BaseModelSaver, model_type, training_method) + return cls() if cls is not None else None +def get_model_setup_class( + model_type: ModelType, + training_method: TrainingMethod = TrainingMethod.FINE_TUNE, +) -> type | None: + return factory.get(BaseModelSetup, model_type, training_method) def create_model_setup( model_type: ModelType, @@ -348,84 +84,8 @@ def create_model_setup( training_method: TrainingMethod = TrainingMethod.FINE_TUNE, debug_mode: bool = False, ) -> BaseModelSetup | None: - match training_method: - case TrainingMethod.FINE_TUNE: - if model_type.is_stable_diffusion(): - return StableDiffusionFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_wuerstchen(): - return WuerstchenFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_pixart(): - return PixArtAlphaFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_stable_diffusion_3(): - return StableDiffusion3FineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_flux(): - return FluxFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_chroma(): - return ChromaFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_qwen(): - return QwenFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_z_image(): - return ZImageFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_sana(): - return SanaFineTuneSetup(train_device, temp_device, debug_mode) - if 
model_type.is_hunyuan_video(): - return HunyuanVideoFineTuneSetup(train_device, temp_device, debug_mode) - if model_type.is_hi_dream(): - return HiDreamFineTuneSetup(train_device, temp_device, debug_mode) - case TrainingMethod.FINE_TUNE_VAE: - if model_type.is_stable_diffusion(): - return StableDiffusionFineTuneVaeSetup(train_device, temp_device, debug_mode) - case TrainingMethod.LORA: - if model_type.is_stable_diffusion(): - return StableDiffusionLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_wuerstchen(): - return WuerstchenLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_pixart(): - return PixArtAlphaLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_stable_diffusion_3(): - return StableDiffusion3LoRASetup(train_device, temp_device, debug_mode) - if model_type.is_flux(): - return FluxLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_chroma(): - return ChromaLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_qwen(): - return QwenLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_z_image(): - return ZImageLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_sana(): - return SanaLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_hunyuan_video(): - return HunyuanVideoLoRASetup(train_device, temp_device, debug_mode) - if model_type.is_hi_dream(): - return HiDreamLoRASetup(train_device, temp_device, debug_mode) - case TrainingMethod.EMBEDDING: - if model_type.is_stable_diffusion(): - return StableDiffusionEmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLEmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_wuerstchen(): - return WuerstchenEmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_pixart(): - return PixArtAlphaEmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_stable_diffusion_3(): - return StableDiffusion3EmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_flux(): - return FluxEmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_chroma(): - return ChromaEmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_sana(): - return SanaEmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_hunyuan_video(): - return HunyuanVideoEmbeddingSetup(train_device, temp_device, debug_mode) - if model_type.is_hi_dream(): - return HiDreamEmbeddingSetup(train_device, temp_device, debug_mode) - - return None - + cls = factory.get(BaseModelSetup, model_type, training_method) + return cls(train_device, temp_device, debug_mode) if cls is not None else None def create_model_sampler( train_device: torch.device, @@ -434,44 +94,17 @@ def create_model_sampler( model_type: ModelType, training_method: TrainingMethod = TrainingMethod.FINE_TUNE, ) -> BaseModelSampler: - match training_method: - case TrainingMethod.FINE_TUNE | TrainingMethod.LORA | TrainingMethod.EMBEDDING: - if model_type.is_stable_diffusion(): - return StableDiffusionSampler(train_device, temp_device, model, model_type) - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLSampler(train_device, temp_device, model, model_type) - if model_type.is_wuerstchen(): - return WuerstchenSampler(train_device, temp_device, model, model_type) - if model_type.is_pixart(): - return 
PixArtAlphaSampler(train_device, temp_device, model, model_type) - if model_type.is_stable_diffusion_3(): - return StableDiffusion3Sampler(train_device, temp_device, model, model_type) - if model_type.is_flux(): - return FluxSampler(train_device, temp_device, model, model_type) - if model_type.is_chroma(): - return ChromaSampler(train_device, temp_device, model, model_type) - if model_type.is_qwen(): - return QwenSampler(train_device, temp_device, model, model_type) - if model_type.is_z_image(): - return ZImageSampler(train_device, temp_device, model, model_type) - if model_type.is_sana(): - return SanaSampler(train_device, temp_device, model, model_type) - if model_type.is_hunyuan_video(): - return HunyuanVideoSampler(train_device, temp_device, model, model_type) - if model_type.is_hi_dream(): - return HiDreamSampler(train_device, temp_device, model, model_type) - case TrainingMethod.FINE_TUNE_VAE: - if model_type.is_stable_diffusion(): - return StableDiffusionVaeSampler(train_device, temp_device, model, model_type) - - return None - + cls = factory.get(BaseModelSampler, model_type, training_method) + if cls is None: + cls = factory.get(BaseModelSampler, model_type) + return cls(train_device, temp_device, model, model_type) if cls is not None else None def create_data_loader( train_device: torch.device, temp_device: torch.device, model: BaseModel, model_type: ModelType, + model_setup: BaseModelSetup, training_method: TrainingMethod = TrainingMethod.FINE_TUNE, config: TrainConfig = None, train_progress: TrainProgress | None = None, @@ -483,38 +116,10 @@ def create_data_loader( if train_progress is None: train_progress = TrainProgress() - match training_method: - case TrainingMethod.FINE_TUNE | TrainingMethod.LORA | TrainingMethod.EMBEDDING: - if model_type.is_stable_diffusion(): - return StableDiffusionBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_stable_diffusion_xl(): - return StableDiffusionXLBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_wuerstchen(): - return WuerstchenBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_pixart(): - return PixArtAlphaBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_stable_diffusion_3(): - return StableDiffusion3BaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_flux(): - return FluxBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_chroma(): - return ChromaBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_qwen(): - return QwenBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_z_image(): - return ZImageBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_sana(): - return SanaBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_hunyuan_video(): - return HunyuanVideoBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - if model_type.is_hi_dream(): - return HiDreamBaseDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - case TrainingMethod.FINE_TUNE_VAE: - if model_type.is_stable_diffusion(): - return 
StableDiffusionFineTuneVaeDataLoader(train_device, temp_device, config, model, train_progress, is_validation) - - return None - + cls = factory.get(BaseDataLoader, model_type, training_method) + if cls is None: + cls = factory.get(BaseDataLoader, model_type) + return cls(train_device, temp_device, config, model, model_setup, train_progress, is_validation) if cls is not None else None def create_optimizer( parameter_group_collection: NamedParameterGroupCollection, diff --git a/modules/util/enum/ModelFormat.py b/modules/util/enum/ModelFormat.py index 597ad4442..70193a61b 100644 --- a/modules/util/enum/ModelFormat.py +++ b/modules/util/enum/ModelFormat.py @@ -6,6 +6,7 @@ class ModelFormat(Enum): CKPT = 'CKPT' SAFETENSORS = 'SAFETENSORS' LEGACY_SAFETENSORS = 'LEGACY_SAFETENSORS' + COMFY_LORA = 'COMFY_LORA' INTERNAL = 'INTERNAL' # an internal format that stores all information to resume training @@ -23,6 +24,8 @@ def file_extension(self) -> str: return '.safetensors' case ModelFormat.LEGACY_SAFETENSORS: return '.safetensors' + case ModelFormat.COMFY_LORA: + return '.safetensors' case _: return '' diff --git a/modules/util/enum/ModelType.py b/modules/util/enum/ModelType.py index bb8740e97..8dfebd0ff 100644 --- a/modules/util/enum/ModelType.py +++ b/modules/util/enum/ModelType.py @@ -25,6 +25,7 @@ class ModelType(Enum): FLUX_DEV_1 = 'FLUX_DEV_1' FLUX_FILL_DEV_1 = 'FLUX_FILL_DEV_1' + FLUX_DEV_2 = 'FLUX_DEV_2' SANA = 'SANA' @@ -77,9 +78,17 @@ def is_pixart_sigma(self): return self == ModelType.PIXART_SIGMA def is_flux(self): + return self == ModelType.FLUX_DEV_1 \ + or self == ModelType.FLUX_FILL_DEV_1 \ + or self == ModelType.FLUX_DEV_2 + + def is_flux_1(self): return self == ModelType.FLUX_DEV_1 \ or self == ModelType.FLUX_FILL_DEV_1 + def is_flux_2(self): + return self == ModelType.FLUX_DEV_2 + def is_chroma(self): return self == ModelType.CHROMA_1 @@ -116,7 +125,7 @@ def has_depth_input(self): def has_multiple_text_encoders(self): return self.is_stable_diffusion_3() \ or self.is_stable_diffusion_xl() \ - or self.is_flux() \ + or self.is_flux_1() \ or self.is_hunyuan_video() \ or self.is_hi_dream() \ diff --git a/modules/util/factory.py b/modules/util/factory.py new file mode 100644 index 000000000..0480286e8 --- /dev/null +++ b/modules/util/factory.py @@ -0,0 +1,25 @@ +import importlib +import pkgutil + +__registry = {} + +def get(base_cls, *args, **kwargs): + entries = __registry.get(base_cls) + if entries is None: + return None + for entry in entries: + if entry[0] == args and entry[1] == kwargs: + return entry[2] + return None + +def register(base_cls, cls, *args, **kwargs): + if get(base_cls, *args, **kwargs) is not None: + raise RuntimeError(f"{cls} already registered as an implementation of {base_cls} with the same criteria {args} {kwargs}") + + if base_cls not in __registry: + __registry[base_cls] = [] + __registry[base_cls].append((args, kwargs, cls)) + +def import_dir(path: str, parent: str): + for _finder, name, _ispkg in pkgutil.walk_packages([path], parent+"."): + importlib.import_module(name) diff --git a/requirements-global.txt b/requirements-global.txt index afa429bf2..632c6d55a 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -11,6 +11,7 @@ matplotlib==3.10.3 av==14.4.0 yt-dlp #no pinned version, frequently updated for compatibility with sites scenedetect==0.6.6 +parse==1.20.2 # pytorch accelerate==1.7.0 @@ -20,7 +21,8 @@ pytorch-lightning==2.5.1.post0 # diffusion models #Note: check whether Qwen bugs in diffusers have been fixed before upgrading diffusers 
(see BaseQwenSetup): --e git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers +#-e git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers +-e git+https://github.com/dxqb/diffusers.git@flux2_tuples#egg=diffusers gguf==0.17.1 transformers==4.56.2 sentencepiece==0.2.1 # transitive dependency of transformers for tokenizer loading diff --git a/resources/sd_model_spec/flux_dev_2.0-lora.json b/resources/sd_model_spec/flux_dev_2.0-lora.json new file mode 100644 index 000000000..03c0aed03 --- /dev/null +++ b/resources/sd_model_spec/flux_dev_2.0-lora.json @@ -0,0 +1,6 @@ +{ + "modelspec.sai_model_spec": "1.0.0", + "modelspec.architecture": "Flux.2-dev/lora", + "modelspec.implementation": "https://github.com/huggingface/diffusers", + "modelspec.title": "FluxDev 2.0 LoRA" +} diff --git a/resources/sd_model_spec/flux_dev_2.0.json b/resources/sd_model_spec/flux_dev_2.0.json new file mode 100644 index 000000000..c743cc9e8 --- /dev/null +++ b/resources/sd_model_spec/flux_dev_2.0.json @@ -0,0 +1,6 @@ +{ + "modelspec.sai_model_spec": "1.0.0", + "modelspec.architecture": "Flux.2-dev", + "modelspec.implementation": "https://github.com/huggingface/diffusers", + "modelspec.title": "FluxDev 2.0" +} From 17056c5c84b3d9e2ff3042c1d3f037ec08bdb86b Mon Sep 17 00:00:00 2001 From: dxqb Date: Tue, 30 Dec 2025 18:20:01 +0100 Subject: [PATCH 03/11] mgds dependency --- requirements-global.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-global.txt b/requirements-global.txt index 632c6d55a..399946e30 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -34,7 +34,7 @@ pooch==1.8.2 open-clip-torch==2.32.0 # data loader --e git+https://github.com/Nerogar/mgds.git@385578f#egg=mgds +-e git+https://github.com/dxqb/mgds.git@flux2#egg=mgds # optimizers dadaptation==3.2 # dadaptation optimizers From b2ceed1d1bce15a58a356e5b747cd9a10b08de72 Mon Sep 17 00:00:00 2001 From: O-J1 <18110006+O-J1@users.noreply.github.com> Date: Mon, 5 Jan 2026 15:43:40 +1100 Subject: [PATCH 04/11] Bump requests ver (#1245) --- requirements-global.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-global.txt b/requirements-global.txt index 97c2b9add..89b2ca348 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -56,5 +56,5 @@ fabric==3.2.2 # debug psutil==7.0.0 -requests==2.32.3 +requests==2.32.5 deepdiff==8.6.1 # output easy to read diff for troublshooting From 04146d4dcf395bcc16979f0ea1a3804e44c725aa Mon Sep 17 00:00:00 2001 From: dxqb <183307934+dxqb@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:48:28 +0100 Subject: [PATCH 05/11] Advanced Optimizers 2.0: torch.compile support, enhancements, etc. (#1250) by Koratahiu --------- Co-authored-by: Koratahiu~ --- modules/ui/OptimizerParamsWindow.py | 1 + modules/util/config/TrainConfig.py | 2 ++ modules/util/create.py | 8 ++++++++ modules/util/optimizer_util.py | 8 ++++++++ requirements-global.txt | 2 +- 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/modules/ui/OptimizerParamsWindow.py b/modules/ui/OptimizerParamsWindow.py index 8c3e038c2..ed50e9f57 100644 --- a/modules/ui/OptimizerParamsWindow.py +++ b/modules/ui/OptimizerParamsWindow.py @@ -198,6 +198,7 @@ def create_dynamic_ui( 'approx_mars': {'title': 'Approx MARS-M', 'tooltip': 'Enables Approximated MARS-M, a variance reduction technique. It uses the previous step\'s gradient to correct the current update, leading to lower losses and improved convergence stability. 
This requires additional state to store the previous gradient.', 'type': 'bool'}, 'kappa_p': {'title': 'Lion-K P-value', 'tooltip': 'Controls the Lp-norm geometry for the Lion update. 1.0 = Standard Lion (Sign update, coordinate-wise), best for Transformers. 2.0 = Spherical Lion (Normalized update, rotational invariant), best for Conv2d layers (in unet models). Values between 1.0 and 2.0 interpolate behavior between the two.', 'type': 'float'}, 'auto_kappa_p': {'title': 'Auto Lion-K', 'tooltip': 'Automatically determines the optimal P-value based on layer dimensions. Uses p=2.0 (Spherical) for 4D (Conv) tensors for stability and rotational invariance, and p=1.0 (Sign) for 2D (Linear) tensors for sparsity. Overrides the manual P-value. Recommend for unet models.', 'type': 'bool'}, + 'compile': {'title': 'Compiled Optimizer', 'tooltip': 'Enables PyTorch compilation for the optimizer internal step logic. This is intended to improve performance by allowing PyTorch to fuse operations and optimize the computational graph.', 'type': 'bool'}, } # @formatter:on diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index ddaee4b89..b8184e1c2 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -143,6 +143,7 @@ class TrainOptimizerConfig(BaseConfig): approx_mars: False kappa_p: float auto_kappa_p: False + compile: False def __init__(self, data: list[(str, Any, type, bool)]): super().__init__(data) @@ -261,6 +262,7 @@ def default_values(): data.append(("approx_mars", False, bool, False)) data.append(("kappa_p", None, float, True)) data.append(("auto_kappa_p", False, bool, False)) + data.append(("compile", False, bool, False)) return TrainOptimizerConfig(data) diff --git a/modules/util/create.py b/modules/util/create.py index 4f0cd6e13..2a3cfe17d 100644 --- a/modules/util/create.py +++ b/modules/util/create.py @@ -1080,6 +1080,7 @@ def create_optimizer( alpha=optimizer_config.alpha if optimizer_config.alpha is not None else 5, kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False, k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # ADOPT_ADV Optimizer @@ -1106,6 +1107,7 @@ def create_optimizer( alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100, kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False, k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # PRODIGY_ADV Optimizer @@ -1139,6 +1141,7 @@ def create_optimizer( alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100, kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False, k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # SIMPLIFIED_AdEMAMix Optimizer @@ -1161,6 +1164,7 @@ def create_optimizer( orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is 
not None else False, k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # LION_ADV Optimizer @@ -1180,6 +1184,7 @@ def create_optimizer( orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, kappa_p=optimizer_config.kappa_p if optimizer_config.kappa_p is not None else 1.0, auto_kappa_p=optimizer_config.auto_kappa_p if optimizer_config.auto_kappa_p is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # LION_PRODIGY_ADV Optimizer @@ -1206,6 +1211,7 @@ def create_optimizer( orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, kappa_p=optimizer_config.kappa_p if optimizer_config.kappa_p is not None else 1.0, auto_kappa_p=optimizer_config.auto_kappa_p if optimizer_config.auto_kappa_p is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # MUON_ADV Optimizer @@ -1254,6 +1260,7 @@ def create_optimizer( accelerated_ns=optimizer_config.accelerated_ns if optimizer_config.accelerated_ns is not None else False, orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, approx_mars=optimizer_config.approx_mars if optimizer_config.approx_mars is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, **adam_kwargs ) @@ -1307,6 +1314,7 @@ def create_optimizer( accelerated_ns=optimizer_config.accelerated_ns if optimizer_config.accelerated_ns is not None else False, orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, approx_mars=optimizer_config.approx_mars if optimizer_config.approx_mars is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, **adam_kwargs ) diff --git a/modules/util/optimizer_util.py b/modules/util/optimizer_util.py index add2bec24..d20204996 100644 --- a/modules/util/optimizer_util.py +++ b/modules/util/optimizer_util.py @@ -457,6 +457,7 @@ def init_model_parameters( "use_bias_correction": True, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "use_atan2": False, "cautious_mask": False, @@ -476,6 +477,7 @@ def init_model_parameters( "weight_decay": 0.0, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "use_atan2": False, "cautious_mask": False, @@ -498,6 +500,7 @@ def init_model_parameters( "weight_decay": 0.0, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "d0": 1e-6, "d_coef": 1.0, @@ -529,6 +532,7 @@ def init_model_parameters( "use_bias_correction": True, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "orthogonal_gradient": False, "kourkoutas_beta": False, @@ -542,6 +546,7 @@ def init_model_parameters( "clip_threshold": None, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "cautious_mask": False, "orthogonal_gradient": False, @@ -557,6 +562,7 @@ def init_model_parameters( "clip_threshold": None, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, 
"fused_back_pass": False, "d0": 1e-6, "d_coef": 1.0, @@ -580,6 +586,7 @@ def init_model_parameters( "rms_rescaling": True, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "MuonWithAuxAdam": True, "muon_hidden_layers": None, @@ -610,6 +617,7 @@ def init_model_parameters( "rms_rescaling": True, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "MuonWithAuxAdam": True, "muon_hidden_layers": None, diff --git a/requirements-global.txt b/requirements-global.txt index 89b2ca348..2b7847046 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -41,7 +41,7 @@ prodigyopt==1.1.2 # prodigy optimizer schedulefree==1.4.1 # schedule-free optimizers pytorch_optimizer==3.6.0 # pytorch optimizers prodigy-plus-schedule-free==2.0.1 # Prodigy plus optimizer -adv_optm==1.4.1 # advanced optimizers +adv_optm==2.0.1 # advanced optimizers -e git+https://github.com/KellerJordan/Muon.git@f90a42b#egg=muon-optimizer # Profiling From 563be6484cc97b80ba530e4b6a498a31756c3b52 Mon Sep 17 00:00:00 2001 From: dxqb <183307934+dxqb@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:50:19 +0100 Subject: [PATCH 06/11] SignSGD with Momentum (Signum) Optimizer (#1251) by Koratahiu --------- Co-authored-by: Koratahiu~ --- modules/util/create.py | 17 +++++++++++++++++ modules/util/enum/Optimizer.py | 2 ++ modules/util/optimizer_util.py | 12 ++++++++++++ requirements-global.txt | 2 +- 4 files changed, 32 insertions(+), 1 deletion(-) diff --git a/modules/util/create.py b/modules/util/create.py index 2a3cfe17d..a03ce3067 100644 --- a/modules/util/create.py +++ b/modules/util/create.py @@ -1167,6 +1167,23 @@ def create_optimizer( compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) + # SignSGD_ADV Optimizer + case Optimizer.SIGNSGD_ADV: + from adv_optm import SignSGD_adv + optimizer = SignSGD_adv( + params=parameters, + lr=config.learning_rate, + momentum=optimizer_config.momentum if optimizer_config.momentum is not None else 0, + weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0.0, + nnmf_factor=optimizer_config.nnmf_factor if optimizer_config.nnmf_factor is not None else False, + cautious_wd=optimizer_config.cautious_wd if optimizer_config.cautious_wd is not None else False, + stochastic_rounding=optimizer_config.stochastic_rounding, + orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, + Simplified_AdEMAMix=optimizer_config.Simplified_AdEMAMix if optimizer_config.Simplified_AdEMAMix is not None else False, + alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100, + ) + # LION_ADV Optimizer case Optimizer.LION_ADV: from adv_optm import Lion_adv diff --git a/modules/util/enum/Optimizer.py b/modules/util/enum/Optimizer.py index c3edfb837..58edf419c 100644 --- a/modules/util/enum/Optimizer.py +++ b/modules/util/enum/Optimizer.py @@ -42,6 +42,7 @@ class Optimizer(Enum): # 32 bit is torch and not bnb SGD = 'SGD' SGD_8BIT = 'SGD_8BIT' + SIGNSGD_ADV = 'SIGNSGD_ADV' # Schedule-free optimizers SCHEDULE_FREE_ADAMW = 'SCHEDULE_FREE_ADAMW' @@ -116,6 +117,7 @@ def supports_fused_back_pass(self): Optimizer.LION_PRODIGY_ADV, Optimizer.MUON_ADV, Optimizer.ADAMUON_ADV, + Optimizer.SIGNSGD_ADV, ] # Small helper for adjusting learning rates to adaptive 
optimizers. diff --git a/modules/util/optimizer_util.py b/modules/util/optimizer_util.py index d20204996..f879bb5d9 100644 --- a/modules/util/optimizer_util.py +++ b/modules/util/optimizer_util.py @@ -538,6 +538,18 @@ def init_model_parameters( "kourkoutas_beta": False, "k_warmup_steps": None, }, + Optimizer.SIGNSGD_ADV: { + "momentum": 0.99, + "cautious_wd": False, + "weight_decay": 0.0, + "nnmf_factor": False, + "stochastic_rounding": True, + "compiled_optimizer": False, + "fused_back_pass": False, + "orthogonal_gradient": False, + "Simplified_AdEMAMix": False, + "alpha_grad": 100.0, + }, Optimizer.LION_ADV: { "beta1": 0.9, "beta2": 0.99, diff --git a/requirements-global.txt b/requirements-global.txt index 2b7847046..a6f3d0cc4 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -41,7 +41,7 @@ prodigyopt==1.1.2 # prodigy optimizer schedulefree==1.4.1 # schedule-free optimizers pytorch_optimizer==3.6.0 # pytorch optimizers prodigy-plus-schedule-free==2.0.1 # Prodigy plus optimizer -adv_optm==2.0.1 # advanced optimizers +adv_optm==2.1.0 # advanced optimizers -e git+https://github.com/KellerJordan/Muon.git@f90a42b#egg=muon-optimizer # Profiling From ef5b00c42697e38419014cc5cf45b5dac963a3e7 Mon Sep 17 00:00:00 2001 From: O-J1 <18110006+O-J1@users.noreply.github.com> Date: Tue, 13 Jan 2026 15:18:16 +1100 Subject: [PATCH 07/11] Bump av ver to 16.1.0 --- requirements-global.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-global.txt b/requirements-global.txt index a6f3d0cc4..3438ce9ef 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -8,7 +8,7 @@ PyYAML==6.0.2 huggingface-hub==0.34.4 scipy==1.15.3 matplotlib==3.10.3 -av==14.4.0 +av==16.1.0 yt-dlp #no pinned version, frequently updated for compatibility with sites scenedetect==0.6.6 From aaf4d12cffceb26d9fde210d38874e1d756bc40c Mon Sep 17 00:00:00 2001 From: O-J1 <18110006+O-J1@users.noreply.github.com> Date: Wed, 14 Jan 2026 02:18:23 +1100 Subject: [PATCH 08/11] Update README.md --- README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 997f639f5..67a7c3e1a 100644 --- a/README.md +++ b/README.md @@ -45,15 +45,15 @@ OneTrainer is a one-stop solution for all your Diffusion training needs. - Windows: Double click or execute `install.bat` - Linux and Mac: Execute `install.sh` - #### Manual installation - - 1. Clone the repository `git clone https://github.com/Nerogar/OneTrainer.git` - 2. Navigate into the cloned directory `cd OneTrainer` - 3. Set up a virtual environment `python -m venv venv` - 4. Activate the new venv: - - Windows: `venv\scripts\activate` - - Linux and Mac: Depends on your shell, activate the venv accordingly - 5. Install the requirements `pip install -r requirements.txt` +#### Manual installation + +1. Clone the repository `git clone https://github.com/Nerogar/OneTrainer.git` +2. Navigate into the cloned directory `cd OneTrainer` +3. Set up a virtual environment `python -m venv venv` +4. Activate the new venv: + - Windows: `venv\scripts\activate` + - Linux and Mac: Depends on your shell, activate the venv accordingly +5. 
Install the requirements `pip install -r requirements.txt` > [!Tip] > Some Linux distributions are missing required packages for instance: On Ubuntu you must install `libGL`: From f1f93206fb44cab18d5fb881317848634497e281 Mon Sep 17 00:00:00 2001 From: dxqb Date: Sat, 17 Jan 2026 01:02:21 +0100 Subject: [PATCH 09/11] fix arguments to validation data loader --- modules/trainer/GenericTrainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index 0607fe349..52f812b61 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -158,7 +158,7 @@ def start(self): if self.config.validation: self.validation_data_loader = self.create_data_loader( - self.model, self.model.train_progress, is_validation=True + self.model, self.model_setup, self.model.train_progress, is_validation=True ) def __save_config_to_workspace(self): From a3be5ee356b09c492036c0c32ce78c15c823796e Mon Sep 17 00:00:00 2001 From: dxqb Date: Mon, 19 Jan 2026 21:46:30 +0100 Subject: [PATCH 10/11] unpatchify, to match the shape of masks --- modules/modelSetup/BaseFlux2Setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/modelSetup/BaseFlux2Setup.py b/modules/modelSetup/BaseFlux2Setup.py index ecd64efe7..df09457e4 100644 --- a/modules/modelSetup/BaseFlux2Setup.py +++ b/modules/modelSetup/BaseFlux2Setup.py @@ -159,8 +159,9 @@ def predict( model_output_data = { 'loss_type': 'target', 'timestep': timestep, - 'predicted': predicted_flow, - 'target': flow, + #unpatchify, to make the shape match the mask shape of masked training: + 'predicted': model.unpatchify_latents(predicted_flow), + 'target': model.unpatchify_latents(flow), } if config.debug_mode: From 792510a4a5faca61a4a3e28b1ee6c530eb4d2ff5 Mon Sep 17 00:00:00 2001 From: dxqb Date: Mon, 19 Jan 2026 21:55:50 +0100 Subject: [PATCH 11/11] rename Comfy and remove filter, because UI is not updated when you change Peft type --- modules/ui/ModelTab.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 7d52d1748..e7d765b1d 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -5,7 +5,6 @@ from modules.util.enum.ConfigPart import ConfigPart from modules.util.enum.DataType import DataType from modules.util.enum.ModelFormat import ModelFormat -from modules.util.enum.ModelType import PeftType from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ui import components from modules.util.ui.UIState import UIState @@ -151,7 +150,7 @@ def __setup_flux_2_ui(self, frame): allow_safetensors=True, allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, - allow_comfy=self.train_config.training_method == TrainingMethod.LORA and self.train_config.peft_type == PeftType.LORA, + allow_comfy=self.train_config.training_method == TrainingMethod.LORA, ) def __setup_z_image_ui(self, frame): @@ -642,7 +641,7 @@ def __create_output_components( # if allow_legacy_safetensors: # formats.append(("Legacy Safetensors", ModelFormat.LEGACY_SAFETENSORS)) if allow_comfy: - formats.append(("Comfy", ModelFormat.COMFY_LORA)) + formats.append(("Comfy LoRA", ModelFormat.COMFY_LORA)) components.label(frame, row, 0, "Output Format", tooltip="Format to use when saving the output model")
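
The factory registry introduced in modules/util/factory.py replaces the long match statements that patch 02 removes from modules/util/create.py: implementations register themselves against their base class when their module is imported, and factory.import_dir() walks each package so those imports happen. A minimal sketch of the registration side, using a hypothetical ExampleLoRAModelLoader and example (ModelType, TrainingMethod) keys; how each real loader registers itself is not shown in this patch series and may differ:

from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.util import factory
from modules.util.enum.ModelType import ModelType
from modules.util.enum.TrainingMethod import TrainingMethod


class ExampleLoRAModelLoader(BaseModelLoader):  # hypothetical implementation, for illustration only
    ...


# registration runs as a side effect of importing this module, which is what
# factory.import_dir("modules/modelLoader", "modules.modelLoader") triggers
factory.register(BaseModelLoader, ExampleLoRAModelLoader, ModelType.FLUX_DEV_2, TrainingMethod.LORA)

# create_model_loader() later resolves the same key and instantiates the returned class
cls = factory.get(BaseModelLoader, ModelType.FLUX_DEV_2, TrainingMethod.LORA)
assert cls is ExampleLoRAModelLoader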
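
The SignSGD_adv optimizer added in patch 06 is SignSGD with momentum (Signum). The core update is the textbook Signum rule sketched below; adv_optm's actual implementation is not shown here and additionally supports stochastic rounding, cautious weight decay, NNMF factoring, orthogonal gradients, and the Simplified AdEMAMix variant. The decoupled weight-decay style in this sketch is an assumption:

import torch

@torch.no_grad()
def signum_step(param: torch.Tensor, grad: torch.Tensor, momentum_buf: torch.Tensor,
                lr: float, beta: float = 0.99, weight_decay: float = 0.0) -> None:
    if weight_decay != 0.0:
        param.mul_(1.0 - lr * weight_decay)               # decoupled weight decay (assumed style)
    momentum_buf.mul_(beta).add_(grad, alpha=1.0 - beta)  # exponential moving average of gradients
    param.add_(momentum_buf.sign(), alpha=-lr)            # only the sign is applied: fixed step of lr per coordinate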
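
Patch 10 unpatchifies the predicted and target flows in BaseFlux2Setup so their spatial layout matches the masks used for masked training: the loss masks are per-pixel latent tensors, while the transformer works on packed patch tokens. The sketch below shows what such an unpatchify typically does for 2x2 latent patches; the exact signature and layout of Flux2Model.unpatchify_latents are assumptions here:

import torch

def unpatchify_latents_sketch(latents: torch.Tensor, height: int, width: int, patch_size: int = 2) -> torch.Tensor:
    # (batch, num_patches, channels * p * p) -> (batch, channels, height, width)
    batch_size, _, packed_channels = latents.shape
    channels = packed_channels // (patch_size * patch_size)
    latents = latents.reshape(batch_size, height // patch_size, width // patch_size,
                              channels, patch_size, patch_size)
    latents = latents.permute(0, 3, 1, 4, 2, 5)  # b, c, h/p, p, w/p, p
    return latents.reshape(batch_size, channels, height, width)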