From aa243b271689068d474589173d24fc3e95c42484 Mon Sep 17 00:00:00 2001 From: gushiqiao <975033167> Date: Fri, 23 Jan 2026 06:17:33 +0000 Subject: [PATCH 1/4] [Feat] Support LTX2 sequence parallel, tensor parallel, and VAE decode tiling --- configs/ltx2/ltx2.json | 4 +- configs/ltx2/ltx2_distill_fp8.json | 2 +- configs/ltx2/ltx2_fp8.json | 1 - configs/ltx2/ltx2_tp.json | 20 + configs/ltx2/ltx2_ulysses.json | 21 + configs/wan/wan_i2v.json | 6 +- examples/ltx2/ltxt_i2av.py | 12 +- examples/ltx2/ltxt_i2av_distilled_fp8.py | 17 +- examples/ltx2/ltxt_t2av_distilled_fp8.py | 19 +- lightx2v/common/ops/mm/mm_weight.py | 96 ++++ lightx2v/common/ops/norm/rms_norm_weight.py | 60 +++ lightx2v/infer.py | 14 +- .../models/networks/ltx2/infer/pre_infer.py | 2 +- .../networks/ltx2/infer/transformer_infer.py | 214 +++++++-- lightx2v/models/networks/ltx2/infer/utils.py | 16 +- lightx2v/models/networks/ltx2/model.py | 438 ++++++++++++++++-- .../ltx2/weights/transformer_weights.py | 292 +++++++++++- .../networks/wan/infer/transformer_infer.py | 4 +- lightx2v/models/networks/wan/model.py | 2 + lightx2v/models/runners/ltx2/ltx2_runner.py | 69 ++- .../models/runners/wan/wan_distill_runner.py | 2 +- .../models/video_encoders/hf/ltx2/model.py | 27 ++ lightx2v/pipeline.py | 12 +- lightx2v/utils/input_info.py | 5 +- lightx2v/utils/set_config.py | 31 +- lightx2v/utils/utils.py | 49 ++ scripts/ltx2/run_ltx2_i2av.sh | 8 +- scripts/ltx2/run_ltx2_i2av_tp.sh | 21 + scripts/ltx2/run_ltx2_i2av_ulysses.sh | 21 + scripts/ltx2/run_ltx2_t2av.sh | 2 +- scripts/ltx2/run_ltx2_t2av_cfg_parallel.sh | 0 scripts/ltx2/run_ltx2_t2av_tp.sh | 19 + scripts/ltx2/run_ltx2_t2av_ulysses.sh | 19 + scripts/wan/run_wan_i2v.sh | 8 +- 34 files changed, 1357 insertions(+), 176 deletions(-) create mode 100755 configs/ltx2/ltx2_tp.json create mode 100755 configs/ltx2/ltx2_ulysses.json mode change 100644 => 100755 scripts/ltx2/run_ltx2_i2av.sh create mode 100755 scripts/ltx2/run_ltx2_i2av_tp.sh create mode 100755 scripts/ltx2/run_ltx2_i2av_ulysses.sh mode change 100644 => 100755 scripts/ltx2/run_ltx2_t2av_cfg_parallel.sh create mode 100755 scripts/ltx2/run_ltx2_t2av_tp.sh create mode 100755 scripts/ltx2/run_ltx2_t2av_ulysses.sh diff --git a/configs/ltx2/ltx2.json b/configs/ltx2/ltx2.json index 7c0d6e3b0..dabc0bc72 100755 --- a/configs/ltx2/ltx2.json +++ b/configs/ltx2/ltx2.json @@ -1,7 +1,6 @@ { "infer_steps": 40, "target_video_length": 121, - "text_len": 512, "target_height": 512, "target_width": 768, "attn_type": "sage_attn2", @@ -14,5 +13,6 @@ "audio_fps": 24000, "audio_mel_bins":16, "double_precision_rope": true, - "dit_original_ckpt": "/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-dev.safetensors" + "use_tiling_vae": false, + "dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors" } diff --git a/configs/ltx2/ltx2_distill_fp8.json b/configs/ltx2/ltx2_distill_fp8.json index 6c16342d0..2ee31c3f3 100755 --- a/configs/ltx2/ltx2_distill_fp8.json +++ b/configs/ltx2/ltx2_distill_fp8.json @@ -13,7 +13,7 @@ "audio_fps": 24000, "audio_mel_bins":16, "double_precision_rope": true, - "dit_quantized_ckpt": "/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors", + "dit_quantized_ckpt": "Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors", "dit_quantized": true, "dit_quant_scheme": "fp8-pertensor", "skip_fp8_block_index" : [0, 43, 44, 45, 46, 47] diff --git a/configs/ltx2/ltx2_fp8.json b/configs/ltx2/ltx2_fp8.json index c728de5a6..37e4c873e 100755 --- a/configs/ltx2/ltx2_fp8.json +++ 
b/configs/ltx2/ltx2_fp8.json @@ -1,7 +1,6 @@ { "infer_steps": 40, "target_video_length": 121, - "text_len": 512, "target_height": 512, "target_width": 768, "attn_type": "sage_attn2", diff --git a/configs/ltx2/ltx2_tp.json b/configs/ltx2/ltx2_tp.json new file mode 100755 index 000000000..93b007c4d --- /dev/null +++ b/configs/ltx2/ltx2_tp.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 40, + "target_video_length": 121, + "target_height": 512, + "target_width": 768, + "attn_type": "sage_attn2", + "sample_guide_scale": 4, + "sample_shift": [2.05, 0.95], + "enable_cfg": true, + "cpu_offload": false, + "num_channels_latents": 128, + "fps": 24, + "audio_fps": 24000, + "audio_mel_bins":16, + "double_precision_rope": true, + "dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors", + "parallel": { + "tensor_p_size": 2 + } +} diff --git a/configs/ltx2/ltx2_ulysses.json b/configs/ltx2/ltx2_ulysses.json new file mode 100755 index 000000000..092e6674e --- /dev/null +++ b/configs/ltx2/ltx2_ulysses.json @@ -0,0 +1,21 @@ +{ + "infer_steps": 40, + "target_video_length": 121, + "target_height": 512, + "target_width": 768, + "attn_type": "sage_attn2", + "sample_guide_scale": 4, + "sample_shift": [2.05, 0.95], + "enable_cfg": true, + "cpu_offload": false, + "num_channels_latents": 128, + "fps": 24, + "audio_fps": 24000, + "audio_mel_bins":16, + "double_precision_rope": true, + "dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors", + "parallel": { + "seq_p_size": 2, + "seq_p_attn_type": "ulysses" + } +} diff --git a/configs/wan/wan_i2v.json b/configs/wan/wan_i2v.json index 6c2107083..ff4b83b1e 100755 --- a/configs/wan/wan_i2v.json +++ b/configs/wan/wan_i2v.json @@ -9,5 +9,9 @@ "sample_guide_scale": 5, "sample_shift": 3, "enable_cfg": true, - "cpu_offload": false + "cpu_offload": false, + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ulysses" + } } diff --git a/examples/ltx2/ltxt_i2av.py b/examples/ltx2/ltxt_i2av.py index a79745677..c66b8402d 100755 --- a/examples/ltx2/ltxt_i2av.py +++ b/examples/ltx2/ltxt_i2av.py @@ -24,15 +24,21 @@ ) seed = 42 -images = "/path/to/woman.jpeg:0:1.0" -prompt = "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood." +image_path = "/path/to/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg" +image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths +prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." 
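For reference, the image conditioning convention used in this example (comma-separated paths in image_path, plus a scalar or per-image list in image_strength) can be normalized into explicit (path, strength) pairs. The helper below is an illustrative sketch, not part of the LightX2V API; its name and validation are assumptions.

from typing import List, Tuple, Union

def normalize_image_conditioning(image_path: str, image_strength: Union[float, List[float]] = 1.0) -> List[Tuple[str, float]]:
    # "path1.jpg,path2.jpg" -> ["path1.jpg", "path2.jpg"]
    paths = [p.strip() for p in image_path.split(",") if p.strip()]
    if isinstance(image_strength, (int, float)):
        # Scalar: every image gets the same conditioning strength.
        strengths = [float(image_strength)] * len(paths)
    else:
        # List: one strength per image; lengths must match.
        strengths = [float(s) for s in image_strength]
        assert len(strengths) == len(paths), f"got {len(strengths)} strengths for {len(paths)} images"
    return list(zip(paths, strengths))

print(normalize_image_conditioning("cat.jpg,dog.jpg", [1.0, 0.8]))  # [('cat.jpg', 1.0), ('dog.jpg', 0.8)]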
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." save_result_path = "/path/to/save_results/output.mp4" +# Note: image_strength can also be set in config_json +# For scalar: image_strength = 1.0 (all images use same strength) +# For list: image_strength = [1.0, 0.8] (must match number of images) + pipe.generate( seed=seed, prompt=prompt, - images=images, + image_path=image_path, + image_strength=image_strength, negative_prompt=negative_prompt, save_result_path=save_result_path, ) diff --git a/examples/ltx2/ltxt_i2av_distilled_fp8.py b/examples/ltx2/ltxt_i2av_distilled_fp8.py index faf8ccf9d..904f725b7 100755 --- a/examples/ltx2/ltxt_i2av_distilled_fp8.py +++ b/examples/ltx2/ltxt_i2av_distilled_fp8.py @@ -1,14 +1,14 @@ from lightx2v import LightX2VPipeline pipe = LightX2VPipeline( - model_path="Lightricks/LTX-2/", + model_path="/data/nvme0/gushiqiao/models/official_models/LTX-2/", model_cls="ltx2", task="i2av", ) pipe.enable_quantize( dit_quantized=True, - dit_quantized_ckpt="Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors", + dit_quantized_ckpt="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors", quant_scheme="fp8-pertensor", skip_fp8_block_index=[0, 43, 44, 45, 46, 47], ) @@ -35,15 +35,20 @@ ) seed = 42 -images = "/path/to/woman.jpeg:0:1.0" -prompt = "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood." +image_path = "/data/nvme0/gushiqiao/models/code/LightX2V/assets/inputs/imgs/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg" +image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths +prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." 
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -save_result_path = "/path/to/save_results/output.mp4" +save_result_path = "/data/nvme0/gushiqiao/models/code/LightX2V/save_results/output_lightx2v_ltx2_i2av_distilled_fp8.mp4" +# Note: image_strength can also be set in config_json +# For scalar: image_strength = 1.0 (all images use same strength) +# For list: image_strength = [1.0, 0.8] (must match number of images) pipe.generate( seed=seed, prompt=prompt, - images=images, + image_path=image_path, + image_strength=image_strength, negative_prompt=negative_prompt, save_result_path=save_result_path, ) diff --git a/examples/ltx2/ltxt_t2av_distilled_fp8.py b/examples/ltx2/ltxt_t2av_distilled_fp8.py index 2bcfcb75e..0e81115cc 100755 --- a/examples/ltx2/ltxt_t2av_distilled_fp8.py +++ b/examples/ltx2/ltxt_t2av_distilled_fp8.py @@ -1,10 +1,10 @@ from lightx2v import LightX2VPipeline -pipe = LightX2VPipeline(model_path="Lightricks/LTX-2", model_cls="ltx2", task="t2av") +pipe = LightX2VPipeline(model_path="/data/nvme0/gushiqiao/models/official_models/LTX-2", model_cls="ltx2", task="t2av") pipe.enable_quantize( dit_quantized=True, - dit_quantized_ckpt="Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors", + dit_quantized_ckpt="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors", quant_scheme="fp8-pertensor", skip_fp8_block_index=[0, 43, 44, 45, 46, 47], ) @@ -33,11 +33,12 @@ seed = 42 prompt = "A beautiful sunset over the ocean" negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, 
inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -save_result_path = "/path/to/save_results/output.mp4" +save_result_path = "/data/nvme0/gushiqiao/models/code/LightX2V/save_results/output_lightx2v_ltx2_t2av_distilled_fp8.mp4" -pipe.generate( - seed=seed, - prompt=prompt, - negative_prompt=negative_prompt, - save_result_path=save_result_path, -) +for seed in range(10): + pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, + ) diff --git a/lightx2v/common/ops/mm/mm_weight.py b/lightx2v/common/ops/mm/mm_weight.py index c27e9cef6..1b723dab1 100755 --- a/lightx2v/common/ops/mm/mm_weight.py +++ b/lightx2v/common/ops/mm/mm_weight.py @@ -2,6 +2,7 @@ from abc import ABCMeta, abstractmethod import torch +import torch.distributed as dist from loguru import logger from safetensors import safe_open @@ -92,6 +93,8 @@ except ImportError: marlin_cuda_quant = None +import torch.distributed as dist + class MMWeightTemplate(metaclass=ABCMeta): def __init__( @@ -2143,3 +2146,96 @@ def apply(self, input_tensor): if self.has_lora_branch: return output_tensor + self.apply_lora(input_tensor) return output_tensor + + +@MM_WEIGHT_REGISTER("TensorParallel") +class MMWeightTP(MMWeightTemplate): + """ + Tensor Parallel wrapper for any MMWeight type. + + This is a generic wrapper that can wrap any MMWeight implementation (Default, fp8, int8, etc.) + and add tensor parallelism support by: + 1. Handling weight splitting in load() method + 2. Adding all-reduce for row-wise split in apply() method + + Supports column-wise and row-wise weight splitting: + - Column split: weight [in_dim, out_dim] -> [in_dim, out_dim/tp_size] per rank + - Row split: weight [in_dim, out_dim] -> [in_dim/tp_size, out_dim] per rank + """ + + def __init__( + self, + weight_name, + bias_name, + mm_type="Default", + tp_group=None, + tp_rank=0, + tp_size=1, + split_dim="col", + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + is_post_adapter=False, + lora_prefix="diffusion_model.blocks", + lora_path="", + ): + super().__init__( + weight_name, + bias_name, + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + is_post_adapter, + lora_prefix, + lora_path, + ) + self.tp_group = tp_group + self.tp_rank = tp_rank + self.tp_size = tp_size + self.split_dim = split_dim # "col" for column split, "row" for row split + assert split_dim in ["col", "row"], f"split_dim must be 'col' or 'row', got {split_dim}" + + self._mm = MM_WEIGHT_REGISTER.get(mm_type, MMWeight)( + weight_name=weight_name, + bias_name=bias_name, + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=lazy_load, + lazy_load_file=lazy_load_file, + is_post_adapter=is_post_adapter, + lora_prefix=lora_prefix, + lora_path=lora_path, + ) + self._row_split_bias = None + + def load(self, weight_dict): + """Load weights using internal MMWeight's load method. + + Note: Weights in weight_dict are already split by _load_weights_from_rank0. + The format is [out_dim/tp_size, in_dim] for column split or [out_dim, in_dim/tp_size] for row split. + MMWeight.load will handle the transposition via create_default_tensors. + + For row split, bias is not split and should be added after all-reduce. + We temporarily remove bias from _mm to prevent it from being added before all-reduce. 
+ """ + self._mm.load(weight_dict) + if self.split_dim == "row" and self.bias_name is not None and self.bias_name in weight_dict: + self._row_split_bias = self._mm.bias.clone() + self._mm.bias = None + + def apply(self, input_tensor): + """Apply matrix multiplication with tensor parallel support.""" + # Use internal MMWeight's apply method (handles fp8, int8, etc.) + # For row split, _mm.bias is None, so bias won't be added here + output = self._mm.apply(input_tensor) + + # For row split, need all-reduce to combine results from all ranks + if self.split_dim == "row" and self.tp_size > 1 and self.tp_group is not None: + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.tp_group) + # Add bias after all-reduce (bias is not split for row split) + if self._row_split_bias is not None: + output = output + self._row_split_bias + + return output diff --git a/lightx2v/common/ops/norm/rms_norm_weight.py b/lightx2v/common/ops/norm/rms_norm_weight.py index 077a40e9e..148e59960 100755 --- a/lightx2v/common/ops/norm/rms_norm_weight.py +++ b/lightx2v/common/ops/norm/rms_norm_weight.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod import torch +import torch.distributed as dist from loguru import logger from safetensors import safe_open @@ -176,6 +177,65 @@ def apply(self, input_tensor): return input_tensor +@RMS_WEIGHT_REGISTER("TensorParallel") +class RMSWeightTP(RMSWeightTemplate): + """ + RMSNorm weight module with tensor parallelism support. + + The weight is split along the hidden dimension to match the split QKV outputs. + """ + + def __init__( + self, + weight_name, + tp_group=None, + tp_rank=0, + tp_size=1, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + is_post_adapter=False, + eps=1e-6, + lora_prefix="diffusion_model.blocks", + lora_path="", + ): + super().__init__( + weight_name, + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + is_post_adapter, + eps, + lora_prefix, + lora_path, + ) + self.tp_group = tp_group + self.tp_rank = tp_rank + self.tp_size = tp_size + + def apply(self, input_tensor): + local_sum = input_tensor.pow(2).sum(-1, keepdim=True) + + # All-reduce to get global sum + if self.tp_size > 1 and self.tp_group is not None: + dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=self.tp_group) + + # Compute global mean: global_sum / hidden_dim + hidden_dim = input_tensor.shape[-1] * self.tp_size + global_mean = local_sum / hidden_dim + + # Apply normalization with global mean + if self.sensitive_layer_dtype != self.infer_dtype: + input_tensor = input_tensor * torch.rsqrt(global_mean.float() + self.eps).to(self.infer_dtype) + input_tensor = (input_tensor * self._get_actual_weight()).to(self.infer_dtype) + else: + input_tensor = input_tensor * torch.rsqrt(global_mean + self.eps) + input_tensor = input_tensor * self._get_actual_weight() + return input_tensor + + @RMS_WEIGHT_REGISTER("sgl-kernel") class RMSWeightSgl(RMSWeight): def __init__( diff --git a/lightx2v/infer.py b/lightx2v/infer.py index a8dedda2a..84c62d79f 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -24,7 +24,7 @@ from lightx2v.utils.profiler import * from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.set_config import print_config, set_config, set_parallel_config -from lightx2v.utils.utils import seed_all, validate_task_arguments +from lightx2v.utils.utils import seed_all, validate_config_paths, validate_task_arguments from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER 
@@ -75,15 +75,15 @@ def main(): parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation") parser.add_argument("--negative_prompt", type=str, default="") - parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task") - parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task") - parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task") parser.add_argument( - "--images", + "--image_path", type=str, default="", - help="Image conditioning for I2AV task. Format: 'path1:frame_idx1:strength1,path2:frame_idx2:strength2'. Example: 'cat.jpg:0:1.0,dog.jpg:60:0.8'", + help="The path to input image file(s) for image-to-video (i2v) or image-to-audio-video (i2av) task. Multiple paths should be comma-separated. Example: 'path1.jpg,path2.jpg'", ) + parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task") + parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task") + parser.add_argument("--image_strength", type=float, default=1.0, help="The strength of the image-to-audio-video (i2av) task") # [Warning] For vace task, need refactor. parser.add_argument( "--src_ref_images", @@ -149,6 +149,8 @@ def main(): print_config(config) + validate_config_paths(config) + with ProfilingContext4DebugL1("Total Cost"): # init runner runner = init_runner(config) diff --git a/lightx2v/models/networks/ltx2/infer/pre_infer.py b/lightx2v/models/networks/ltx2/infer/pre_infer.py index dd59a405e..abb274b6b 100755 --- a/lightx2v/models/networks/ltx2/infer/pre_infer.py +++ b/lightx2v/models/networks/ltx2/infer/pre_infer.py @@ -53,7 +53,7 @@ def _prepare_positional_embeddings( max_pos=max_pos, use_middle_indices_grid=use_middle_indices_grid, num_attention_heads=num_attention_heads, - rope_type=LTXRopeType.SPLIT, + rope_type="split", freq_grid_generator=freq_grid_generator, ) return pe diff --git a/lightx2v/models/networks/ltx2/infer/transformer_infer.py b/lightx2v/models/networks/ltx2/infer/transformer_infer.py index c3af96e57..c3da867c3 100755 --- a/lightx2v/models/networks/ltx2/infer/transformer_infer.py +++ b/lightx2v/models/networks/ltx2/infer/transformer_infer.py @@ -9,10 +9,11 @@ """ import torch +import torch.distributed as dist from lightx2v.models.networks.ltx2.infer.module_io import LTX2PreInferModuleOutput from lightx2v.models.networks.ltx2.infer.triton_ops import fuse_scale_shift_kernel, fused_rmsnorm_modulate -from lightx2v.models.networks.ltx2.infer.utils import LTXRopeType, apply_rotary_emb, modulate_torch_naive, modulate_with_rmsnorm_torch_naive, rmsnorm_torch_naive +from lightx2v.models.networks.ltx2.infer.utils import apply_rotary_emb, modulate_torch_naive, modulate_with_rmsnorm_torch_naive, rmsnorm_torch_naive from lightx2v.models.networks.wan.infer.triton_ops import norm_infer @@ -31,13 +32,31 @@ def __init__(self, config): config: Model configuration dictionary """ self.config = config - self.rope_type = LTXRopeType(config["rope_type"]) + self.rope_type = config["rope_type"] self.blocks_num = config.get("num_layers", 48) self.v_num_heads = config.get("num_attention_heads", 32) self.v_head_dim = config.get("attention_head_dim", 128) self.a_num_heads = 
config.get("audio_num_attention_heads", 32) self.a_head_dim = config.get("audio_attention_head_dim", 64) self.clean_cuda_cache = config.get("clean_cuda_cache", False) + + if config.get("seq_parallel", False): + self.seq_p_group = config.get("device_mesh").get_group(mesh_dim="seq_p") + self.seq_p_fp8_comm = config.get("parallel", {}).get("seq_p_fp8_comm", False) + else: + self.seq_p_group = None + self.seq_p_fp8_comm = False + + # Initialize tensor parallel group + if config.get("tensor_parallel", False): + self.tp_group = config.get("device_mesh").get_group(mesh_dim="tensor_p") + self.tp_rank = dist.get_rank(self.tp_group) + self.tp_size = dist.get_world_size(self.tp_group) + else: + self.tp_group = None + self.tp_rank = 0 + self.tp_size = 1 + if config.get("norm_modulate_backend", "triton") == "triton": self.norm_infer_func = norm_infer self.modulate_func = fuse_scale_shift_kernel @@ -46,11 +65,68 @@ def __init__(self, config): self.norm_infer_func = rmsnorm_torch_naive self.modulate_func = modulate_torch_naive self.modulate_with_rmsnorm_func = modulate_with_rmsnorm_torch_naive + self.reset_infer_states() def set_scheduler(self, scheduler): """Set the scheduler for inference.""" self.scheduler = scheduler + def reset_infer_states(self): + """Reset inference states for cumulative sequence lengths.""" + # Only cache cu_seqlens_qkv for self-attention (q, k, v have same length) + # For cross-attention, cu_seqlens_kv varies by context type, create dynamically + self.v_attn_cu_seqlens_qkv = None # For video self-attention + self.a_attn_cu_seqlens_qkv = None # For audio self-attention + + def _create_cu_seqlens(self, seq_len: int, device: torch.device) -> torch.Tensor: + """ + Create cumulative sequence lengths tensor for attention. + + Args: + seq_len: Sequence length + device: Device to place the tensor on + + Returns: + Cumulative sequence lengths tensor [0, seq_len] + """ + if self.config["attn_type"] in ["flash_attn2", "flash_attn3"]: + return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32).to(device, non_blocking=True) + else: + return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32) + + def _gather_cross_attn_context(self, context: torch.Tensor, k_pe=None): + """ + Gather context and k_pe from all ranks for cross-attention in sequence parallel mode. 
+ + Args: + context: Local context tensor to gather + k_pe: Optional tuple of (cos_freqs, sin_freqs) for key positional embeddings + + Returns: + Tuple of (gathered_context, gathered_k_pe) + """ + world_size = dist.get_world_size(self.seq_p_group) + + # Gather context + context_gathered = [torch.zeros_like(context) for _ in range(world_size)] + dist.all_gather(context_gathered, context, group=self.seq_p_group) + gathered_context = torch.cat(context_gathered, dim=0) + + # Gather k_pe if provided + gathered_k_pe = k_pe + if k_pe is not None: + cos_freqs, sin_freqs = k_pe + # Determine sequence dimension: for 4D tensors [B, H, T, D], seq_dim is 2 + seq_dim = 2 if cos_freqs.dim() == 4 else (0 if cos_freqs.dim() == 2 else 1) + + cos_gathered = [torch.zeros_like(cos_freqs) for _ in range(world_size)] + sin_gathered = [torch.zeros_like(sin_freqs) for _ in range(world_size)] + dist.all_gather(cos_gathered, cos_freqs.contiguous(), group=self.seq_p_group) + dist.all_gather(sin_gathered, sin_freqs.contiguous(), group=self.seq_p_group) + gathered_k_pe = (torch.cat(cos_gathered, dim=seq_dim), torch.cat(sin_gathered, dim=seq_dim)) + + return gathered_context, gathered_k_pe + def _infer_attn( self, attn_phase, @@ -59,38 +135,109 @@ def _infer_attn( pe=None, k_pe=None, is_audio=False, + need_gather_video_context=False, # Only True for video-to-audio cross-attention ) -> torch.Tensor: + """ + Unified attention inference method supporting both TP and non-TP modes. + + Args: + attn_phase: LTX2Attention or LTX2AttentionTP instance + x: Input tensor [seq_len, hidden_dim] + context: Context tensor for cross-attention (None for self-attention) + pe: Positional embeddings for query + k_pe: Positional embeddings for key + is_audio: Whether this is audio attention + need_gather_video_context: Whether to gather video context for cross-attention (only for SP) + + Returns: + Attention output tensor [seq_len, hidden_dim] + """ + use_tp = self.tp_size > 1 + is_self_attn = context is None + context = x if is_self_attn else context + q = attn_phase.to_q.apply(x) - context = x if context is None else context + # For sequence parallel (non-TP), gather context if needed + if need_gather_video_context and self.config.get("seq_parallel", False) and not use_tp: + context, k_pe = self._gather_cross_attn_context(context, k_pe) k = attn_phase.to_k.apply(context) v = attn_phase.to_v.apply(context) q = attn_phase.q_norm.apply(q) k = attn_phase.k_norm.apply(k) + if pe is not None: q = apply_rotary_emb(q, pe, self.rope_type).squeeze() k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type).squeeze() - out = attn_phase.attn_func.apply( - q=q.view( - -1, - self.v_num_heads if not is_audio else self.a_num_heads, - self.v_head_dim if not is_audio else self.a_head_dim, - ), - k=k.view( - -1, - self.v_num_heads if not is_audio else self.a_num_heads, - self.v_head_dim if not is_audio else self.a_head_dim, - ), - v=v.view( - -1, - self.v_num_heads if not is_audio else self.a_num_heads, - self.v_head_dim if not is_audio else self.a_head_dim, - ), - max_seqlen_q=q.size(0), - ) + + num_heads = self.v_num_heads if not is_audio else self.a_num_heads + head_dim = self.v_head_dim if not is_audio else self.a_head_dim + seq_len = q.size(0) + + # For TP, heads are split across ranks + num_heads_effective = num_heads // self.tp_size if use_tp else num_heads + + q = q.view(-1, num_heads_effective, head_dim) + k = k.view(-1, num_heads_effective, head_dim) + v = v.view(-1, num_heads_effective, head_dim) + + # For video self-attention with 
sequence parallel (non-TP only) + if is_self_attn and not is_audio and self.config.get("seq_parallel", False) and not use_tp: + # Cache cu_seqlens_qkv for self-attention (q, k, v have same length) + if self.v_attn_cu_seqlens_qkv is None: + self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + + out = attn_phase.attn_func_parallel.apply( + q=q, + k=k, + v=v, + slice_qkv_len=seq_len, + cu_seqlens_qkv=self.v_attn_cu_seqlens_qkv, + attention_module=attn_phase.attn_func, + attention_type=self.config["attn_type"], + seq_p_group=self.seq_p_group, + use_fp8_comm=self.seq_p_fp8_comm, + use_tensor_fusion=False, + enable_head_parallel=False, + ) + else: + # For all other attention types (cross-attn, audio self-attn, TP, non-parallel) + # Cache cu_seqlens_qkv for self-attention only + if is_self_attn: + if not is_audio and self.v_attn_cu_seqlens_qkv is None: + self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + elif is_audio and self.a_attn_cu_seqlens_qkv is None: + self.a_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + + cu_seqlens_q = self.v_attn_cu_seqlens_qkv if not is_audio else self.a_attn_cu_seqlens_qkv + cu_seqlens_kv = cu_seqlens_q # For self-attn, q and k have same length + else: + # For cross-attention, always create cu_seqlens dynamically + # because k length varies by context type (text, audio, video) + cu_seqlens_q = self._create_cu_seqlens(q.shape[0], q.device) + cu_seqlens_kv = self._create_cu_seqlens(k.shape[0], k.device) + + out = attn_phase.attn_func.apply( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=q.size(0), + max_seqlen_kv=k.size(0), + ) return attn_phase.to_out.apply(out) def _infer_ffn(self, ffn_phase, x: torch.Tensor) -> torch.Tensor: - """Apply feed-forward network.""" + """ + Unified feed-forward network inference method supporting both TP and non-TP modes. 
+ + Args: + ffn_phase: LTX2FFN or LTX2FFNTP instance + x: Input tensor [seq_len, hidden_dim] + + Returns: + FFN output tensor [seq_len, hidden_dim] + """ x = ffn_phase.net_0_proj.apply(x) x = torch.nn.functional.gelu(x, approximate="tanh") return ffn_phase.net_2.apply(x) @@ -117,6 +264,8 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) norm_vx = self.modulate_with_rmsnorm_func(vx, vscale_msa, vshift_msa, weight=None, bias=None, eps=1e-6) + + # Video self-attention vx = ( vx + self._infer_attn( @@ -127,7 +276,7 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) * vgate_msa ) - + # Video cross-attention vx = vx + self._infer_attn( attn_phase=block.compute_phases[1], x=self.norm_infer_func(vx, weight=None, bias=None, eps=1e-6), @@ -145,6 +294,8 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) norm_ax = self.modulate_with_rmsnorm_func(ax, ascale_msa, ashift_msa, weight=None, bias=None, eps=1e-6) + + # Audio self-attention ax = ( ax + self._infer_attn( @@ -155,7 +306,7 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) * agate_msa ) - + # Audio cross-attention ax = ax + self._infer_attn( attn_phase=block.compute_phases[3], x=self.norm_infer_func(ax, weight=None, bias=None, eps=1e-6), @@ -196,12 +347,12 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) # Audio-to-video cross-attention - # vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v - # ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v - + # Video queries attend to audio context + # Audio is global (not split), so no need to gather vx_scaled = self.modulate_func(vx_norm3, scale_ca_video_hidden_states_a2v, shift_ca_video_hidden_states_a2v) ax_scaled = self.modulate_func(ax_norm3, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v) + # Audio-to-video cross-attention vx = ( vx + self._infer_attn( @@ -211,13 +362,14 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): pe=pre_infer_out.video_args.cross_positional_embeddings, k_pe=pre_infer_out.audio_args.cross_positional_embeddings, is_audio=True, + need_gather_video_context=False, # Audio is global, no gather needed ) * gate_out_a2v ) # Video-to-audio cross-attention - # ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a - # vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + # Audio queries need to attend to full video context + # In TP, video is NOT split (unlike SP), so no gather needed ax_scaled = self.modulate_func(ax_norm3, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a) vx_scaled = self.modulate_func(vx_norm3, scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a) @@ -230,6 +382,7 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): pe=pre_infer_out.audio_args.cross_positional_embeddings, k_pe=pre_infer_out.video_args.cross_positional_embeddings, is_audio=True, + need_gather_video_context=not (self.tp_size > 1), # Need gather for SP, not for TP ) * gate_out_v2a ) @@ -282,6 +435,9 @@ def infer(self, weights, pre_infer_out: LTX2PreInferModuleOutput): Returns: Tuple of (video_x, audio_x, video_timestep, audio_timestep) after transformer blocks """ + # Reset inference states at the beginning of each inference + self.reset_infer_states() 
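infer_block above only sets need_gather_video_context=True for the video-to-audio cross-attention: audio stays global, but under sequence parallelism each rank holds just a shard of the video tokens, so the shards must be all-gathered before audio queries can attend to the full video context. Below is a single-process sketch of that requirement, with torch.cat emulating the all-gather; the tensor shapes are made up for illustration.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_p_size, video_len, audio_len, dim = 2, 12, 6, 16

video = torch.randn(video_len, dim)  # full video token sequence
audio = torch.randn(audio_len, dim)  # audio remains global (not sharded)

# Reference: audio queries attend to the FULL video context.
ref = F.scaled_dot_product_attention(audio.unsqueeze(0), video.unsqueeze(0), video.unsqueeze(0))

# Sequence parallel: video tokens are sharded along the sequence dimension.
video_shards = list(video.chunk(seq_p_size, dim=0))

# Attending only to the local shard gives a different (wrong) result on every rank.
local = F.scaled_dot_product_attention(audio.unsqueeze(0), video_shards[0].unsqueeze(0), video_shards[0].unsqueeze(0))
print(torch.allclose(local, ref))  # False

# Gathering the shards first (emulated by concatenation) recovers the reference output.
gathered = torch.cat(video_shards, dim=0)
out = F.scaled_dot_product_attention(audio.unsqueeze(0), gathered.unsqueeze(0), gathered.unsqueeze(0))
print(torch.allclose(out, ref))  # True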
+ vx = pre_infer_out.video_args.x ax = pre_infer_out.audio_args.x diff --git a/lightx2v/models/networks/ltx2/infer/utils.py b/lightx2v/models/networks/ltx2/infer/utils.py index f059466cc..cfc35bf7a 100755 --- a/lightx2v/models/networks/ltx2/infer/utils.py +++ b/lightx2v/models/networks/ltx2/infer/utils.py @@ -1,6 +1,5 @@ import functools import math -from enum import Enum from typing import Callable, Tuple import numpy as np @@ -71,19 +70,14 @@ def get_timestep_embedding( return emb -class LTXRopeType(Enum): - INTERLEAVED = "interleaved" - SPLIT = "split" - - def apply_rotary_emb( input_tensor: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], - rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + rope_type: str = "split", ) -> torch.Tensor: - if rope_type == LTXRopeType.INTERLEAVED: + if rope_type == "interleaved": return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) - elif rope_type == LTXRopeType.SPLIT: + elif rope_type == "split": return apply_split_rotary_emb(input_tensor, *freqs_cis) else: raise ValueError(f"Invalid rope type: {rope_type}") @@ -234,7 +228,7 @@ def precompute_freqs_cis( max_pos: list[int] | None = None, use_middle_indices_grid: bool = False, num_attention_heads: int = 32, - rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + rope_type: str = "split", freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, ) -> tuple[torch.Tensor, torch.Tensor]: if max_pos is None: @@ -243,7 +237,7 @@ def precompute_freqs_cis( indices = freq_grid_generator(theta, indices_grid.shape[1], dim) freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) - if rope_type == LTXRopeType.SPLIT: + if rope_type == "split": expected_freqs = dim // 2 current_freqs = freqs.shape[-1] pad_size = expected_freqs - current_freqs diff --git a/lightx2v/models/networks/ltx2/model.py b/lightx2v/models/networks/ltx2/model.py index 17e572d63..22890ec31 100755 --- a/lightx2v/models/networks/ltx2/model.py +++ b/lightx2v/models/networks/ltx2/model.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from loguru import logger from safetensors import safe_open @@ -40,6 +41,28 @@ def __init__(self, model_path, config, device, lora_path=None, lora_strength=1.0 self.cpu_offload = self.config.get("cpu_offload", False) self.offload_granularity = self.config.get("offload_granularity", "block") + # Initialize sequence parallel group + if self.config.get("seq_parallel", False): + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + else: + self.seq_p_group = None + + if self.config.get("tensor_parallel", False): + self.use_tp = True + self.tp_group = self.config.get("device_mesh").get_group(mesh_dim="tensor_p") + self.tp_rank = dist.get_rank(self.tp_group) + self.tp_size = dist.get_world_size(self.tp_group) + else: + self.tp_group = None + self.use_tp = False + self.tp_rank = 0 + self.tp_size = 1 + + self.padding_multiple = self.config.get("padding_multiple", 1) + + # Track original video sequence length before padding (for sequence parallel) + self.original_video_seq_len = None + # self.model_type = model_type self.remove_keys = ["text_embedding_projection", "audio_vae", "vae", "vocoder", "model.diffusion_model.audio_embeddings_connector", "model.diffusion_model.video_embeddings_connector"] self.lazy_load = self.config.get("lazy_load", False) @@ -96,7 +119,7 @@ def _should_load_weights(self): # Single GPU mode return True elif dist.is_initialized(): - if 
self.config.get("load_from_rank0", False): + if self.use_tp: # Multi-GPU mode, only rank 0 loads if dist.get_rank() == 0: logger.info(f"Loading weights from {self.model_path}") @@ -236,12 +259,9 @@ def _init_weights(self, weight_dict=None): # Load quantized weights weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) - if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False): + if self.config.get("device_mesh") is not None and self.config.get("tensor_parallel", False): weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader) - if hasattr(self, "adapter_weights_dict"): - weight_dict.update(self.adapter_weights_dict) - self.original_weight_dict = weight_dict else: self.original_weight_dict = weight_dict @@ -273,64 +293,199 @@ def _apply_weights(self, weight_dict=None): gc.collect() def _load_weights_from_rank0(self, weight_dict, is_weight_loader): - logger.info("Loading distributed weights") + """ + Load and distribute weights from rank 0 to all ranks. + + Only supports tensor parallel mode with CUDA device. + CPU offload is not supported. + """ + # CPU offload is not supported + if self.cpu_offload: + raise NotImplementedError("_load_weights_from_rank0 does not support CPU offload. Please set cpu_offload=False.") + + logger.info("Loading distributed weights with tensor parallel (CUDA only)") global_src_rank = 0 - target_device = "cpu" if self.cpu_offload else "cuda" if is_weight_loader: + # Rank 0: prepare weights (split for TP) + processed_weight_dict = {} meta_dict = {} + for key, tensor in weight_dict.items(): - meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype} + # Only process .weight keys for TP splitting (bias is handled separately) + if key.endswith(".weight") and self._is_tp_weight(key): + # Split weights for TP: create one entry per rank with original key name + split_weights = self._split_weight_for_tp(key, tensor, self.tp_size) + # Store all split weights temporarily (will be filtered by rank later) + for rank_idx in range(self.tp_size): + rank_key = f"{key}__tp_rank_{rank_idx}" + processed_weight_dict[rank_key] = split_weights[rank_idx] + if rank_idx == 0: # Use rank 0's shape for meta (all ranks have same shape after split) + meta_dict[key] = {"shape": split_weights[rank_idx].shape, "dtype": split_weights[rank_idx].dtype, "is_tp": True} + + # Also handle bias if it exists (for column split weights) + bias_key = key.replace(".weight", ".bias") + if bias_key in weight_dict and self._get_split_type(key) == "col": + # Column split: bias also needs to be split + bias_tensor = weight_dict[bias_key] + assert bias_tensor.shape[0] % self.tp_size == 0, f"bias dimension ({bias_tensor.shape[0]}) must be divisible by tp_size ({self.tp_size}) for {bias_key}" + chunk_size = bias_tensor.shape[0] // self.tp_size + for rank_idx in range(self.tp_size): + rank_bias_key = f"{bias_key}__tp_rank_{rank_idx}" + start_idx = rank_idx * chunk_size + end_idx = start_idx + chunk_size + processed_weight_dict[rank_bias_key] = bias_tensor[start_idx:end_idx] + if rank_idx == 0: + meta_dict[bias_key] = {"shape": bias_tensor[start_idx:end_idx].shape, "dtype": bias_tensor.dtype, "is_tp": True} + # For row split weights, bias is not split (added after all-reduce) + else: + # Non-TP weights or bias (bias is handled above if it's a TP weight's bias) + # Skip bias keys that are already processed above + if not (key.endswith(".bias") and key.replace(".bias", ".weight") in processed_weight_dict): + processed_weight_dict[key] = tensor + 
meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype, "is_tp": False} obj_list = [meta_dict] dist.broadcast_object_list(obj_list, src=global_src_rank) synced_meta_dict = obj_list[0] + weight_dict = processed_weight_dict # Use processed weights else: obj_list = [None] dist.broadcast_object_list(obj_list, src=global_src_rank) synced_meta_dict = obj_list[0] + # Allocate tensors on CUDA distributed_weight_dict = {} for key, meta in synced_meta_dict.items(): - distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) + is_tp = meta.get("is_tp", False) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + if is_tp: + # TP weight: each rank gets its own slice + distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=device) + else: + # Non-TP weight: all ranks get full weight + distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=device) - if target_device == "cuda": - dist.barrier(device_ids=[torch.cuda.current_device()]) + dist.barrier() + # Distribute weights for key in sorted(synced_meta_dict.keys()): - if is_weight_loader: - distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True) + meta = synced_meta_dict[key] + is_tp = meta.get("is_tp", False) + + if is_tp: + # TP weight: rank 0 sends different slices to each rank + # Use send/recv to ensure each rank gets its own slice + dist.barrier(group=self.tp_group) - if target_device == "cpu": if is_weight_loader: - gpu_tensor = distributed_weight_dict[key].cuda() - dist.broadcast(gpu_tensor, src=global_src_rank) - distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True) - del gpu_tensor - torch.cuda.empty_cache() + # Rank 0: send each rank's slice in order + for rank_idx in range(self.tp_size): + rank_key = f"{key}__tp_rank_{rank_idx}" + if rank_key in weight_dict: + if rank_idx == self.tp_rank: + # Copy to my own buffer + distributed_weight_dict[key].copy_(weight_dict[rank_key], non_blocking=True) + else: + # Send to other ranks + dist.send(weight_dict[rank_key].contiguous(), dst=rank_idx, group=self.tp_group) else: - gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda") - dist.broadcast(gpu_tensor, src=global_src_rank) - distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True) - del gpu_tensor - torch.cuda.empty_cache() - - if distributed_weight_dict[key].is_pinned(): - distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True) + # Other ranks: receive from rank 0 + dist.recv(distributed_weight_dict[key], src=global_src_rank, group=self.tp_group) else: + # Non-TP weight: broadcast to all ranks + if is_weight_loader: + distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True) + dist.broadcast(distributed_weight_dict[key], src=global_src_rank) - if target_device == "cuda": - torch.cuda.synchronize() - else: - for tensor in distributed_weight_dict.values(): - if tensor.is_pinned(): - tensor.copy_(tensor, non_blocking=False) + torch.cuda.synchronize() - logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}") + logger.info(f"Weights distributed across {dist.get_world_size()} devices on CUDA") return distributed_weight_dict + def _is_tp_weight(self, key): + """Check if a weight key needs TP splitting. 
+ + TP weights include: + - Attention layers: to_q, to_k, to_v, to_out.0, q_norm, k_norm + - FFN layers: net.0.proj, net.2 + """ + # Generic patterns that apply to all attention and FFN layers + tp_patterns = [ + ".to_q.", + ".to_k.", + ".to_v.", + ".to_out.0.", + ".q_norm.", + ".k_norm.", + ".net.0.proj.", + ".net.2.", + ] + return any(pattern in key for pattern in tp_patterns) + + def _get_split_type(self, key): + """Determine the split type for a weight key. + + Returns: + "col": Column split (to_q, to_k, to_v, net.0.proj) + "row": Row split (to_out.0, net.2) + "norm": Norm split (q_norm, k_norm) + None: No split needed + """ + if ".q_norm." in key or ".k_norm." in key: + return "norm" + elif ".to_q." in key or ".to_k." in key or ".to_v." in key or ".net.0.proj." in key: + return "col" + elif ".to_out.0." in key or ".net.2." in key: + return "row" + return None + + def _split_weight_for_tp(self, key, weight, tp_size): + """ + Split a weight tensor for tensor parallel. + Returns a list of split weights, one for each rank. + """ + split_type = self._get_split_type(key) + if split_type is None: + # Unknown pattern, don't split + return [weight] * tp_size + + if split_type == "norm": + # 1D weights (norm weights): [hidden_dim] -> [hidden_dim/tp_size] per rank + assert weight.dim() == 1, f"Norm weight should be 1D, got {weight.dim()}D for {key}" + assert weight.shape[0] % tp_size == 0, f"hidden_dim ({weight.shape[0]}) must be divisible by tp_size ({tp_size}) for {key}" + chunk_size = weight.shape[0] // tp_size + return [weight[rank_idx * chunk_size : (rank_idx + 1) * chunk_size] for rank_idx in range(tp_size)] + + # 2D weights (linear layer weights) + assert weight.dim() == 2, f"Linear weight should be 2D, got {weight.dim()}D for {key}" + + # Transpose to [in_dim, out_dim] format for easier splitting + weight_t = weight.t() # [in_dim, out_dim] + + if split_type == "col": + # Column split: [out_dim, in_dim] -> [out_dim/tp_size, in_dim] per rank + # Split along out_dim dimension + assert weight_t.shape[1] % tp_size == 0, f"out_dim ({weight_t.shape[1]}) must be divisible by tp_size ({tp_size}) for {key}" + chunk_size = weight_t.shape[1] // tp_size + split_weights = [] + for rank_idx in range(tp_size): + split_weight = weight_t[:, rank_idx * chunk_size : (rank_idx + 1) * chunk_size].t() # Back to [out_dim/tp_size, in_dim] + split_weights.append(split_weight) + else: # split_type == "row" + # Row split: [out_dim, in_dim] -> [out_dim, in_dim/tp_size] per rank + # Split along in_dim dimension + assert weight_t.shape[0] % tp_size == 0, f"in_dim ({weight_t.shape[0]}) must be divisible by tp_size ({tp_size}) for {key}" + chunk_size = weight_t.shape[0] // tp_size + split_weights = [] + for rank_idx in range(tp_size): + split_weight = weight_t[rank_idx * chunk_size : (rank_idx + 1) * chunk_size, :].t() # Back to [out_dim, in_dim/tp_size] + split_weights.append(split_weight) + + return split_weights + def _init_infer(self): self.pre_infer = self.pre_infer_class(self.config) self.post_infer = self.post_infer_class(self.config) @@ -432,6 +587,221 @@ def infer(self, inputs): def _infer_cond_uncond(self, inputs, infer_condition=True): self.scheduler.infer_condition = infer_condition pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs) + + # Apply sequence parallel pre-processing (only for video) + if self.config.get("seq_parallel", False): + pre_infer_out = self._seq_parallel_pre_process(pre_infer_out) + + # Apply tensor parallel pre-processing (split positional embeddings) + if 
self.config.get("tensor_parallel", False): + pre_infer_out = self._tensor_parallel_pre_process(pre_infer_out) + vx, ax, video_embedded_timestep, audio_embedded_timestep = self.transformer_infer.infer(self.transformer_weights, pre_infer_out) + + # Apply sequence parallel post-processing (only for video) + if self.config.get("seq_parallel", False): + vx = self._seq_parallel_post_process(vx, self.original_video_seq_len) + video_embedded_timestep = self._seq_parallel_post_process(video_embedded_timestep, self.original_video_seq_len) + # Audio remains global, no gather needed + vx, ax = self.post_infer.infer(self.post_weight, vx, ax, video_embedded_timestep, audio_embedded_timestep) return vx, ax + + @torch.no_grad() + def _seq_parallel_pre_process(self, pre_infer_out): + """ + Pre-process for sequence parallel: only split video sequences across ranks. + Audio remains global (not split) as it has fewer tokens. + """ + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + + # Only process video args for sequence parallel + if pre_infer_out.video_args is not None: + # Split x (latent) + vx = pre_infer_out.video_args.x + self.original_video_seq_len = vx.shape[0] # Record original length before padding + multiple = world_size * self.padding_multiple + padding_size = (multiple - (vx.shape[0] % multiple)) % multiple + if padding_size > 0: + vx = F.pad(vx, (0, 0, 0, padding_size)) + pre_infer_out.video_args.x = torch.chunk(vx, world_size, dim=0)[cur_rank] + + # Split positional embeddings (cos_freqs, sin_freqs) for video self-attention + if pre_infer_out.video_args.positional_embeddings is not None: + v_cos, v_sin = pre_infer_out.video_args.positional_embeddings + # For 4D tensors: [batch, num_heads, seq_len, head_dim], seq_len is at dim=2 + if v_cos.dim() == 4: + seq_dim = 2 + elif v_cos.dim() == 2: + seq_dim = 0 + else: + seq_dim = 1 + + seq_len = v_cos.shape[seq_dim] + padding_size = (multiple - (seq_len % multiple)) % multiple + if padding_size > 0: + pad_spec = [0, 0] * (v_cos.dim() - seq_dim - 1) + [0, padding_size] + [0, 0] * seq_dim + v_cos = F.pad(v_cos, pad_spec) + v_sin = F.pad(v_sin, pad_spec) + + pre_infer_out.video_args.positional_embeddings = (torch.chunk(v_cos, world_size, dim=seq_dim)[cur_rank], torch.chunk(v_sin, world_size, dim=seq_dim)[cur_rank]) + + # Split cross-attention positional embeddings for cross-modal attention + if pre_infer_out.video_args.cross_positional_embeddings is not None: + v_cross_cos, v_cross_sin = pre_infer_out.video_args.cross_positional_embeddings + if v_cross_cos.dim() == 4: + seq_dim = 2 + elif v_cross_cos.dim() == 2: + seq_dim = 0 + else: + seq_dim = 1 + + seq_len = v_cross_cos.shape[seq_dim] + padding_size = (multiple - (seq_len % multiple)) % multiple + if padding_size > 0: + pad_spec = [0, 0] * (v_cross_cos.dim() - seq_dim - 1) + [0, padding_size] + [0, 0] * seq_dim + v_cross_cos = F.pad(v_cross_cos, pad_spec) + v_cross_sin = F.pad(v_cross_sin, pad_spec) + + pre_infer_out.video_args.cross_positional_embeddings = (torch.chunk(v_cross_cos, world_size, dim=seq_dim)[cur_rank], torch.chunk(v_cross_sin, world_size, dim=seq_dim)[cur_rank]) + + # Split timestep embeddings (sequence-length dependent) + if pre_infer_out.video_args.timesteps is not None: + v_timesteps = pre_infer_out.video_args.timesteps + padding_size = (multiple - (v_timesteps.shape[0] % multiple)) % multiple + if padding_size > 0: + v_timesteps = F.pad(v_timesteps, (0, 0, 0, padding_size)) + pre_infer_out.video_args.timesteps = 
torch.chunk(v_timesteps, world_size, dim=0)[cur_rank] + + if pre_infer_out.video_args.embedded_timestep is not None: + v_embedded_timestep = pre_infer_out.video_args.embedded_timestep + padding_size = (multiple - (v_embedded_timestep.shape[0] % multiple)) % multiple + if padding_size > 0: + v_embedded_timestep = F.pad(v_embedded_timestep, (0, 0, 0, padding_size)) + pre_infer_out.video_args.embedded_timestep = torch.chunk(v_embedded_timestep, world_size, dim=0)[cur_rank] + + if pre_infer_out.video_args.cross_scale_shift_timestep is not None: + v_cross_ss = pre_infer_out.video_args.cross_scale_shift_timestep + padding_size = (multiple - (v_cross_ss.shape[0] % multiple)) % multiple + if padding_size > 0: + v_cross_ss = F.pad(v_cross_ss, (0, 0, 0, padding_size)) + pre_infer_out.video_args.cross_scale_shift_timestep = torch.chunk(v_cross_ss, world_size, dim=0)[cur_rank] + + if pre_infer_out.video_args.cross_gate_timestep is not None: + v_cross_gate = pre_infer_out.video_args.cross_gate_timestep + padding_size = (multiple - (v_cross_gate.shape[0] % multiple)) % multiple + if padding_size > 0: + v_cross_gate = F.pad(v_cross_gate, (0, 0, 0, padding_size)) + pre_infer_out.video_args.cross_gate_timestep = torch.chunk(v_cross_gate, world_size, dim=0)[cur_rank] + + # Audio remains global - no splitting needed + # Audio has fewer tokens, so we keep it on all ranks + + return pre_infer_out + + @torch.no_grad() + def _tensor_parallel_pre_process(self, pre_infer_out): + """ + Pre-process for tensor parallel: split positional embeddings along head dimension. + + In tensor parallel, QKV projections are split along hidden_dim (which equals num_heads * head_dim), + so positional embeddings need to be split along the head dimension to match. + """ + if not self.config.get("tensor_parallel", False): + return pre_infer_out + + tp_group = self.config.get("device_mesh").get_group(mesh_dim="tensor_p") + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + + # Get num_heads and head_dim from config + v_num_heads = self.config.get("num_attention_heads", 32) + v_head_dim = self.config.get("attention_head_dim", 128) + a_num_heads = self.config.get("audio_num_attention_heads", 32) + a_head_dim = self.config.get("audio_attention_head_dim", 64) + + v_num_heads_per_rank = v_num_heads // tp_size + a_num_heads_per_rank = a_num_heads // tp_size + + def split_pe(pe, num_heads, num_heads_per_rank, tp_rank, tp_size): + """Split positional embeddings along head dimension.""" + if pe is None: + return None + + if isinstance(pe, tuple): + cos_freqs, sin_freqs = pe + + if cos_freqs.dim() == 2 and cos_freqs.shape[0] == num_heads: + # Shape: [num_heads, head_dim] - split along num_heads dimension + cos_freqs_split = cos_freqs[tp_rank * num_heads_per_rank : (tp_rank + 1) * num_heads_per_rank, :] + sin_freqs_split = sin_freqs[tp_rank * num_heads_per_rank : (tp_rank + 1) * num_heads_per_rank, :] + return (cos_freqs_split, sin_freqs_split) + elif cos_freqs.dim() == 4: + # Shape: [B, H, T, D] where H=num_heads, D=head_dim//2 (for SPLIT rope type) + # In apply_split_rotary_emb, if input is 2D [seq_len, num_heads_per_rank * head_dim] + # and PE is 4D [B, H, T, D], it will reshape input to [B, T, H, head_dim] then swapaxes to [B, H, T, head_dim] + # So H in PE should match num_heads_per_rank in the input + # Therefore, we need to split along H dimension: [B, H, T, D] -> [B, H/tp_size, T, D] + assert cos_freqs.shape[1] == num_heads, f"PE head dimension mismatch: cos_freqs.shape[1]={cos_freqs.shape[1]}, 
num_heads={num_heads}" + cos_freqs_split = cos_freqs[:, tp_rank * num_heads_per_rank : (tp_rank + 1) * num_heads_per_rank, :, :] + sin_freqs_split = sin_freqs[:, tp_rank * num_heads_per_rank : (tp_rank + 1) * num_heads_per_rank, :, :] + return (cos_freqs_split, sin_freqs_split) + else: + # For other shapes, split along last dimension (hidden_dim) + head_dim = cos_freqs.shape[-1] // num_heads if cos_freqs.dim() > 1 and cos_freqs.shape[-1] % num_heads == 0 else cos_freqs.shape[-1] + hidden_dim_per_rank = num_heads_per_rank * head_dim + start_idx = tp_rank * hidden_dim_per_rank + end_idx = start_idx + hidden_dim_per_rank + cos_freqs_split = cos_freqs[..., start_idx:end_idx] + sin_freqs_split = sin_freqs[..., start_idx:end_idx] + return (cos_freqs_split, sin_freqs_split) + else: + # pe is a single tensor, split along last dimension + head_dim = pe.shape[-1] // num_heads if pe.dim() > 1 and pe.shape[-1] % num_heads == 0 else pe.shape[-1] + hidden_dim_per_rank = num_heads_per_rank * head_dim + start_idx = tp_rank * hidden_dim_per_rank + end_idx = start_idx + hidden_dim_per_rank + return pe[..., start_idx:end_idx] + + # Process video args + if pre_infer_out.video_args is not None: + # Split positional embeddings for video self-attention + if pre_infer_out.video_args.positional_embeddings is not None: + pre_infer_out.video_args.positional_embeddings = split_pe(pre_infer_out.video_args.positional_embeddings, v_num_heads, v_num_heads_per_rank, tp_rank, tp_size) + + # Split cross-attention positional embeddings + if pre_infer_out.video_args.cross_positional_embeddings is not None: + pre_infer_out.video_args.cross_positional_embeddings = split_pe(pre_infer_out.video_args.cross_positional_embeddings, v_num_heads, v_num_heads_per_rank, tp_rank, tp_size) + + # Process audio args + if pre_infer_out.audio_args is not None: + # Split positional embeddings for audio self-attention + if pre_infer_out.audio_args.positional_embeddings is not None: + pre_infer_out.audio_args.positional_embeddings = split_pe(pre_infer_out.audio_args.positional_embeddings, a_num_heads, a_num_heads_per_rank, tp_rank, tp_size) + + # Split cross-attention positional embeddings + if pre_infer_out.audio_args.cross_positional_embeddings is not None: + pre_infer_out.audio_args.cross_positional_embeddings = split_pe(pre_infer_out.audio_args.cross_positional_embeddings, a_num_heads, a_num_heads_per_rank, tp_rank, tp_size) + + return pre_infer_out + + @torch.no_grad() + def _seq_parallel_post_process(self, x, original_length=None): + """ + Post-process for sequence parallel: gather results from all ranks and remove padding. + + Args: + x: Tensor to gather + original_length: Original sequence length before padding. If provided, truncate to this length. 
+ """ + world_size = dist.get_world_size(self.seq_p_group) + gathered_x = [torch.empty_like(x) for _ in range(world_size)] + dist.all_gather(gathered_x, x, group=self.seq_p_group) + combined_output = torch.cat(gathered_x, dim=0) + + # Remove padding to restore original length + if original_length is not None and combined_output.shape[0] > original_length: + combined_output = combined_output[:original_length] + + return combined_output diff --git a/lightx2v/models/networks/ltx2/weights/transformer_weights.py b/lightx2v/models/networks/ltx2/weights/transformer_weights.py index b87015cd2..369b57c2b 100755 --- a/lightx2v/models/networks/ltx2/weights/transformer_weights.py +++ b/lightx2v/models/networks/ltx2/weights/transformer_weights.py @@ -1,3 +1,5 @@ +import torch.distributed as dist + from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList from lightx2v.utils.registry_factory import ( ATTN_WEIGHT_REGISTER, @@ -138,9 +140,34 @@ def __init__( self.scale_shift_table_a2v_ca_video, ) + # Check if tensor parallel is enabled + use_tp = config.get("tensor_parallel", False) + if use_tp: + tp_group = config.get("device_mesh").get_group(mesh_dim="tensor_p") + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + else: + tp_group = None + tp_rank = 0 + tp_size = 1 + + # Create attention and FFN modules based on tensor parallel config + if use_tp: + AttentionClass = LTX2AttentionTP + FFNClass = LTX2FFNTP + tp_kwargs = { + "tp_group": tp_group, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + else: + AttentionClass = LTX2Attention + FFNClass = LTX2FFN + tp_kwargs = {} + self.compute_phases = WeightModuleList( [ - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="attn1", block_prefix=block_prefix, @@ -152,8 +179,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="attn2", block_prefix=block_prefix, @@ -165,8 +193,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="audio_attn1", block_prefix=block_prefix, @@ -178,8 +207,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="audio_attn2", block_prefix=block_prefix, @@ -191,8 +221,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="audio_to_video_attn", block_prefix=block_prefix, @@ -204,8 +235,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="video_to_audio_attn", block_prefix=block_prefix, @@ -217,8 +249,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2FFN( + FFNClass( block_index=block_index, ffn_prefix="ff", block_prefix=block_prefix, @@ -230,8 +263,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2FFN( + FFNClass( block_index=block_index, ffn_prefix="audio_ff", block_prefix=block_prefix, @@ -243,6 +277,7 @@ def __init__( 
lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), ] ) @@ -283,6 +318,14 @@ def __init__( model_prefix = "model.diffusion_model" self.add_module("attn_func", ATTN_WEIGHT_REGISTER[self.config["attn_type"]]()) + + # Add parallel attention module for sequence parallelism + if self.config.get("seq_parallel", False): + self.add_module( + "attn_func_parallel", + ATTN_WEIGHT_REGISTER[self.config.get("parallel", {}).get("seq_p_attn_type", "ulysses")](), + ) + self.add_module( f"q_norm", RMS_WEIGHT_REGISTER[self.attn_rms_type]( @@ -413,3 +456,236 @@ def __init__( lora_path=lora_path, ), ) + + +class LTX2AttentionTP(WeightModule): + """ + Tensor Parallel version of LTX2Attention. + + QKV projections are split column-wise, output projection is split row-wise. + """ + + def __init__( + self, + block_index, + attn_prefix, + block_prefix, + task, + mm_type, + config, + tp_group=None, + tp_rank=0, + tp_size=1, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + lora_path=None, + ): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + self.tp_group = tp_group + self.tp_rank = tp_rank + self.tp_size = tp_size + + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.attn_rms_type = self.config.get("rms_type", "sgl-kernel") + + block_lora_prefix = "model.diffusion_model.blocks" + model_prefix = "model.diffusion_model" + + self.add_module("attn_func", ATTN_WEIGHT_REGISTER[self.config["attn_type"]]()) + + # Use TP version of norm if tensor parallel is enabled + # Note: In TP, QKV outputs are split, so norm weights must also be split + norm_class = RMS_WEIGHT_REGISTER["TensorParallel"] + norm_kwargs = { + "tp_group": tp_group, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + + self.add_module( + f"q_norm", + norm_class( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.q_norm.weight", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + **norm_kwargs, + ), + ) + self.add_module( + f"k_norm", + norm_class( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.k_norm.weight", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + **norm_kwargs, + ), + ) + # QKV projections: column split + self.add_module( + f"to_q", + MM_WEIGHT_REGISTER["TensorParallel"]( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_q.weight", + bias_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_q.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="col", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + self.add_module( + f"to_k", + MM_WEIGHT_REGISTER["TensorParallel"]( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_k.weight", + bias_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_k.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + 
tp_size=tp_size, + split_dim="col", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + self.add_module( + f"to_v", + MM_WEIGHT_REGISTER["TensorParallel"]( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_v.weight", + bias_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_v.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="col", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + # Output projection: row split (needs all-reduce) + self.add_module( + f"to_out", + MM_WEIGHT_REGISTER["TensorParallel"]( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_out.0.weight", + bias_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_out.0.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="row", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + + +class LTX2FFNTP(WeightModule): + """ + Tensor Parallel version of LTX2FFN. + + First layer (net_0_proj) is split column-wise, second layer (net_2) is split row-wise. + """ + + def __init__( + self, + block_index, + block_prefix, + ffn_prefix, + task, + mm_type, + config, + tp_group=None, + tp_rank=0, + tp_size=1, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + lora_path=None, + ): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + self.tp_group = tp_group + self.tp_rank = tp_rank + self.tp_size = tp_size + + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + block_lora_prefix = "model.diffusion_model.blocks" + model_prefix = "model.diffusion_model" + + # First layer: column split + self.add_module( + f"net_0_proj", + MM_WEIGHT_REGISTER["TensorParallel"]( + f"{model_prefix}.{block_prefix}.{block_index}.{ffn_prefix}.net.0.proj.weight", + f"{model_prefix}.{block_prefix}.{block_index}.{ffn_prefix}.net.0.proj.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="col", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + # Second layer: row split (needs all-reduce) + self.add_module( + f"net_2", + MM_WEIGHT_REGISTER["TensorParallel"]( + f"{model_prefix}.{block_prefix}.{block_index}.{ffn_prefix}.net.2.weight", + f"{model_prefix}.{block_prefix}.{block_index}.{ffn_prefix}.net.2.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="row", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py index 
705e1bb29..74133c181 100755 --- a/lightx2v/models/networks/wan/infer/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/transformer_infer.py @@ -1,6 +1,7 @@ from functools import partial import torch +import torch.distributed as dist from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.utils.envs import * @@ -192,7 +193,7 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d) v = phase.self_attn_v.apply(norm1_out).view(s, n, d) - + print(dist.get_rank(), "qkv", q.shape, k.shape, v.shape) q, k = self.apply_rope_func(q, k, cos_sin) img_qkv_len = q.shape[0] @@ -226,6 +227,7 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): enable_head_parallel=self.enable_head_parallel, **attn_running_args, ) + print(dist.get_rank(), "attn_out", attn_out.shape) else: attn_out = phase.self_attn_1.apply( q=q, diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py index 2ed324792..e7825f32a 100755 --- a/lightx2v/models/networks/wan/model.py +++ b/lightx2v/models/networks/wan/model.py @@ -483,7 +483,9 @@ def _infer_cond_uncond(self, inputs, infer_condition=True): pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs) if self.config["seq_parallel"]: + print(dist.get_rank(), pre_infer_out.x.shape) pre_infer_out = self._seq_parallel_pre_process(pre_infer_out) + print(dist.get_rank(), pre_infer_out.x.shape) x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out) diff --git a/lightx2v/models/runners/ltx2/ltx2_runner.py b/lightx2v/models/runners/ltx2/ltx2_runner.py index aaf686dfd..10ef01c8d 100755 --- a/lightx2v/models/runners/ltx2/ltx2_runner.py +++ b/lightx2v/models/runners/ltx2/ltx2_runner.py @@ -1,6 +1,8 @@ import gc +import os import torch +import torch.distributed as dist from lightx2v.models.input_encoders.hf.ltx2.model import LTX2TextEncoder from lightx2v.models.networks.ltx2.model import LTX2Model @@ -18,31 +20,6 @@ torch_device_module = getattr(torch, AI_DEVICE) -def parse_images_arg(images_str: str) -> list: - if not images_str or images_str.strip() == "": - return [] - - result = [] - for item in images_str.split(","): - parts = item.strip().split(":") - if len(parts) != 3: - raise ValueError(f"Invalid image conditioning format: '{item}'. Expected format: 'image_path:frame_idx:strength'") - - image_path = parts[0].strip() - try: - frame_idx = int(parts[1].strip()) - strength = float(parts[2].strip()) - except ValueError as e: - raise ValueError(f"Invalid image conditioning format: '{item}'. frame_idx must be int, strength must be float. Error: {e}") - - if strength < 0.0 or strength > 1.0: - raise ValueError(f"Invalid strength value {strength} for image '{image_path}'. 
Strength must be between 0.0 and 1.0.") - - result.append((image_path, frame_idx, strength)) - - return result - - @RUNNER_REGISTER("ltx2") class LTX2Runner(DefaultRunner): def __init__(self, config): @@ -100,6 +77,13 @@ def load_text_encoder(self): text_encoders = [text_encoder] return text_encoders + def get_vae_parallel(self): + if isinstance(self.config.get("parallel", False), bool): + return self.config.get("parallel", False) + if isinstance(self.config.get("parallel", False), dict): + return self.config.get("parallel", {}).get("vae_parallel", True) + return False + def load_vae(self): """Load video and audio VAE decoders.""" # offload config @@ -117,7 +101,15 @@ def load_vae(self): ckpt_path = os.path.join(self.config["model_path"], "transformer") # Video VAE - video_vae = LTX2VideoVAE(checkpoint_path=ckpt_path, device=vae_device, dtype=GET_DTYPE(), load_encoder=self.config["task"] == "i2av", cpu_offload=vae_offload) + video_vae = LTX2VideoVAE( + checkpoint_path=ckpt_path, + device=vae_device, + dtype=GET_DTYPE(), + load_encoder=self.config["task"] == "i2av", + use_tiling=self.config.get("use_tiling_vae", False), + cpu_offload=vae_offload, + parallel=self.get_vae_parallel(), + ) # Audio VAE audio_vae = LTX2AudioVAE(checkpoint_path=ckpt_path, device=vae_device, dtype=GET_DTYPE(), cpu_offload=vae_offload) @@ -163,9 +155,7 @@ def _run_input_encoder_local_t2av(self): def _run_input_encoder_local_i2av(self): self.input_info.video_latent_shape, self.input_info.audio_latent_shape = self.get_latent_shape_with_target_hw() text_encoder_output = self.run_text_encoder(self.input_info) - # Prepare image conditioning if provided - logger.info(f"🖼️ I2AV mode: processing {len(self.input_info.images)} image conditioning(s)") - self.video_denoise_mask, self.initial_video_latent = self._prepare_image_conditioning() + self.video_denoise_mask, self.initial_video_latent = self.run_vae_encoder() torch_device_module.empty_cache() gc.collect() @@ -173,7 +163,13 @@ def _run_input_encoder_local_i2av(self): "text_encoder_output": text_encoder_output, } - def _prepare_image_conditioning(self): + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["LTX2Runner"], + ) + def run_vae_encoder(self): """ Prepare image conditioning by loading images and encoding them to latents. 
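For reference, the hunks above replace the old `image_path:frame_idx:strength` triplet format (previously parsed by `parse_images_arg`) with a comma-separated `--image_path` plus an `--image_strength` that is either a scalar or a per-image list; the frame index is now simply the position in the list, as in the parsing loop in the following hunk. The helper below is only an illustrative sketch of that mapping (the function name is hypothetical, not code from this patch):

def resolve_image_conditioning(image_path: str, image_strength):
    """Expand "a.jpg,b.jpg" plus a scalar or per-image strength list into
    (path, frame_idx, strength) triples, where frame_idx is the list position."""
    paths = [p.strip() for p in image_path.split(",") if p.strip()]
    triples = []
    for frame_idx, path in enumerate(paths):
        # A scalar strength applies to every image; a list is indexed per image.
        strength = image_strength[frame_idx] if isinstance(image_strength, list) else image_strength
        if not 0.0 <= float(strength) <= 1.0:
            raise ValueError(f"strength {strength} for '{path}' must be within [0.0, 1.0]")
        triples.append((path, frame_idx, float(strength)))
    return triples

resolve_image_conditioning("first.jpg,second.jpg", 1.0)
# -> [("first.jpg", 0, 1.0), ("second.jpg", 1, 1.0)]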
@@ -182,8 +178,6 @@ def _prepare_image_conditioning(self): - video_denoise_mask: Mask indicating which frames to denoise (unpatchified, shape [1, F, H, W]) - initial_video_latent: Initial latent with conditioned frames (unpatchified, shape [C, F, H, W]) """ - logger.info(f"🖼️ Preparing {len(self.input_info.images)} image conditioning(s)") - # Get latent shape C, F, H, W = self.input_info.video_latent_shape target_height = self.input_info.target_shape[0] if self.input_info.target_shape and len(self.input_info.target_shape) == 2 else self.config["target_height"] @@ -211,8 +205,12 @@ def _prepare_image_conditioning(self): ) # Process each image conditioning - images = parse_images_arg(self.input_info.images) - for image_path, frame_idx, strength in images: + image_paths = self.input_info.image_path.split(",") # image_path1,image_path2,image_path3 + for frame_idx, image_path in enumerate(image_paths): + if not isinstance(self.input_info.image_strength, list): + strength = self.input_info.image_strength + else: + strength = self.input_info.image_strength[frame_idx] logger.info(f" 📷 Loading image: {image_path} for frame {frame_idx} with strength {strength}") # Load and preprocess image @@ -224,12 +222,9 @@ def _prepare_image_conditioning(self): device=AI_DEVICE, ) - # Encode image to latent space - # image shape: [1, C, 1, H, W] with torch.no_grad(): encoded_latent = self.video_vae.encode(image) - # Remove batch dimension: [1, C, 1, H_latent, W_latent] -> [C, 1, H_latent, W_latent] encoded_latent = encoded_latent.squeeze(0) # Verify frame index is valid diff --git a/lightx2v/models/runners/wan/wan_distill_runner.py b/lightx2v/models/runners/wan/wan_distill_runner.py index a7b10284e..8abf606a4 100755 --- a/lightx2v/models/runners/wan/wan_distill_runner.py +++ b/lightx2v/models/runners/wan/wan_distill_runner.py @@ -21,7 +21,7 @@ def load_transformer(self): if not lora_configs: model = WanDistillModel(**wan_model_kwargs) else: - model = build_wan_model_with_lora(WanModel, self.config, wan_model_kwargs, lora_configs, model_typ="wan2.1") + model = build_wan_model_with_lora(WanModel, self.config, wan_model_kwargs, lora_configs, model_type="wan2.1") return model def init_scheduler(self): diff --git a/lightx2v/models/video_encoders/hf/ltx2/model.py b/lightx2v/models/video_encoders/hf/ltx2/model.py index 43c3ffb37..a6d1a2a0c 100755 --- a/lightx2v/models/video_encoders/hf/ltx2/model.py +++ b/lightx2v/models/video_encoders/hf/ltx2/model.py @@ -34,16 +34,23 @@ def __init__( device: torch.device, dtype: torch.dtype = torch.bfloat16, load_encoder: bool = True, + use_tiling: bool = False, cpu_offload: bool = False, + parallel: bool = False, + use_2d_split: bool = True, ): self.checkpoint_path = checkpoint_path self.device = device self.dtype = dtype self.load_encoder_flag = load_encoder + self.use_tiling = use_tiling + self.parallel = parallel + self.use_2d_split = use_2d_split self.loader = SafetensorsModelStateDictLoader() self.encoder = None self.decoder = None self.cpu_offload = cpu_offload + self.grid_table = {} # Cache for 2D grid calculations self.load() def load(self) -> tuple[VideoEncoder | None, VideoDecoder | None]: @@ -75,11 +82,27 @@ def load(self) -> tuple[VideoEncoder | None, VideoDecoder | None]: self.decoder = decoder.to(self.device).eval() def encode(self, video_frames: torch.Tensor) -> torch.Tensor: + """ + Encode video frames to latent space. 
+ Args: + video_frames: Input video tensor [1, C, T, H, W] or [C, T, H, W] + Returns: + Encoded latent tensor [C, F, H_latent, W_latent] + """ + # Ensure video has batch dimension + if video_frames.dim() == 4: + video_frames = video_frames.unsqueeze(0) + if self.cpu_offload: self.encoder = self.encoder.to(AI_DEVICE) + out = self.encoder(video_frames) + if out.dim() == 5: + out = out.squeeze(0) + if self.cpu_offload: self.encoder = self.encoder.to("cpu") + return out def decode( @@ -88,6 +111,10 @@ def decode( tiling_config: TilingConfig | None = None, generator: torch.Generator | None = None, ) -> Iterator[torch.Tensor]: + # If tiling is enabled but no tiling_config is provided, use the default config + if self.use_tiling and tiling_config is None: + tiling_config = TilingConfig.default() + if self.cpu_offload: self.decoder = self.decoder.to(AI_DEVICE) try: diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index f7bee3b8a..9a67893f3 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -28,7 +28,7 @@ from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.set_config import set_config, set_parallel_config -from lightx2v.utils.utils import seed_all +from lightx2v.utils.utils import seed_all, validate_config_paths def dict_like(cls): @@ -170,8 +170,9 @@ def create_generator( ) config = set_config(self) - print(config) + validate_config_paths(config) self.runner = self._init_runner(config) + print(self.runner.config) logger.info(f"Initializing {self.model_cls} runner for {self.task} task...") logger.info(f"Model path: {self.model_path}") logger.info("LightGenerator initialized successfully!") @@ -382,7 +383,7 @@ def generate( negative_prompt, save_result_path, image_path=None, - images=None, + image_strength=None, last_frame_path=None, audio_path=None, src_ref_images=None, @@ -392,9 +393,10 @@ def generate( target_shape=[], ): # Run inference (following LightX2V pattern) + # Note: image_path supports comma-separated paths for multiple images + # image_strength can be a scalar (float/int) or a list matching the number of images self.seed = seed self.image_path = image_path - self.images = images self.last_frame_path = last_frame_path self.audio_path = audio_path self.src_ref_images = src_ref_images @@ -405,6 +407,8 @@ def generate( self.save_result_path = save_result_path self.return_result_tensor = return_result_tensor self.target_shape = target_shape + self.image_strength = image_strength + input_info = init_empty_input_info(self.task) seed_all(self.seed) update_input_info_from_dict(input_info, self) diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py index c483bffad..b796dfb37 100755 --- a/lightx2v/utils/input_info.py +++ b/lightx2v/utils/input_info.py @@ -176,12 +176,9 @@ class I2AVInputInfo: prompt_enhanced: str = field(default_factory=str) negative_prompt: str = field(default_factory=str) image_path: str = field(default_factory=str) + image_strength: float = field(default_factory=float) save_result_path: str = field(default_factory=str) return_result_tensor: bool = field(default_factory=lambda: False) - # Image conditioning: list of (image_path, frame_idx, strength) tuples - # frame_idx: which frame to replace with the image (0-indexed) - # strength: conditioning strength (0.0-1.0, typically 1.0 for full replacement) - images: list = field(default_factory=list) # list[tuple[str, int, float]] # shape related resize_mode: str = field(default_factory=str) original_shape: list =
field(default_factory=list) diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index 69ab2e621..fe81dcb04 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -130,16 +130,29 @@ def set_config(args): def set_parallel_config(config): if config["parallel"]: - cfg_p_size = config["parallel"].get("cfg_p_size", 1) - seq_p_size = config["parallel"].get("seq_p_size", 1) - assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size" - config["device_mesh"] = init_device_mesh(AI_DEVICE, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p")) + tensor_p_size = config["parallel"].get("tensor_p_size", 1) + + if tensor_p_size > 1: + # Tensor parallel only: 1D mesh + assert tensor_p_size == dist.get_world_size(), f"tensor_p_size ({tensor_p_size}) must be equal to world_size ({dist.get_world_size()})" + config["device_mesh"] = init_device_mesh(AI_DEVICE, (tensor_p_size,), mesh_dim_names=("tensor_p",)) + config["tensor_parallel"] = True + config["seq_parallel"] = False + config["cfg_parallel"] = False + else: + # Original 2D mesh for cfg_p and seq_p + cfg_p_size = config["parallel"].get("cfg_p_size", 1) + seq_p_size = config["parallel"].get("seq_p_size", 1) + assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size ({cfg_p_size}) * seq_p_size ({seq_p_size}) must be equal to world_size ({dist.get_world_size()})" + config["device_mesh"] = init_device_mesh(AI_DEVICE, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p")) + config["tensor_parallel"] = False + + if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1: + config["seq_parallel"] = True + + if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1: + config["cfg_parallel"] = True - if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1: - config["seq_parallel"] = True - - if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1: - config["cfg_parallel"] = True # warmup dist _a = torch.zeros([1]).to(f"{AI_DEVICE}:{dist.get_rank()}") dist.all_reduce(_a) diff --git a/lightx2v/utils/utils.py b/lightx2v/utils/utils.py index 3560ad327..2ed9d0bdc 100755 --- a/lightx2v/utils/utils.py +++ b/lightx2v/utils/utils.py @@ -598,6 +598,11 @@ def validate_task_arguments(args: "argparse.Namespace") -> None: Raises: AssertionError: If required arguments are missing or invalid for the task """ + # Check model_path exists + model_path = getattr(args, "model_path", None) + if model_path: + check_path_exists(model_path) + task = args.task # Define required file paths for each task @@ -610,6 +615,7 @@ def validate_task_arguments(args: "argparse.Namespace") -> None: "animate": {"required_paths": ["image_path"], "description": "Animate task requires --image_path"}, "t2v": {"required_paths": [], "description": "Text-to-Video task"}, "t2i": {"required_paths": [], "description": "Text-to-Image task"}, + "i2av": {"required_paths": ["image_path"], "description": "Image-to-Audio-Video task requires --image_path"}, } if task not in task_requirements: @@ -635,3 +641,46 @@ def validate_task_arguments(args: "argparse.Namespace") -> None: check_path_exists(path_value) logger.info(f"✓ Task '{task}' arguments validated successfully") + + +def validate_config_paths(config: dict) 
-> None: + """ + Validate checkpoint paths in config dictionary. + + Args: + config: Configuration dictionary + + Raises: + FileNotFoundError: If any checkpoint path in config does not exist + """ + # Check dit_quantized_ckpt or dit_original_ckpt + if "dit_quantized_ckpt" in config and config["dit_quantized_ckpt"] is not None: + check_path_exists(config["dit_quantized_ckpt"]) + logger.debug(f"✓ Verified dit_quantized_ckpt: {config['dit_quantized_ckpt']}") + + if "dit_original_ckpt" in config and config["dit_original_ckpt"] is not None: + check_path_exists(config["dit_original_ckpt"]) + logger.debug(f"✓ Verified dit_original_ckpt: {config['dit_original_ckpt']}") + + # For wan2.2, check high and low noise checkpoints + model_cls = config.get("model_cls", "") + if model_cls and "wan2.2" in model_cls: + # Check high noise checkpoints + if "high_noise_original_ckpt" in config and config["high_noise_original_ckpt"] is not None: + check_path_exists(config["high_noise_original_ckpt"]) + logger.debug(f"✓ Verified high_noise_original_ckpt: {config['high_noise_original_ckpt']}") + + if "high_noise_quantized_ckpt" in config and config["high_noise_quantized_ckpt"] is not None: + check_path_exists(config["high_noise_quantized_ckpt"]) + logger.debug(f"✓ Verified high_noise_quantized_ckpt: {config['high_noise_quantized_ckpt']}") + + # Check low noise checkpoints + if "low_noise_original_ckpt" in config and config["low_noise_original_ckpt"] is not None: + check_path_exists(config["low_noise_original_ckpt"]) + logger.debug(f"✓ Verified low_noise_original_ckpt: {config['low_noise_original_ckpt']}") + + if "low_noise_quantized_ckpt" in config and config["low_noise_quantized_ckpt"] is not None: + check_path_exists(config["low_noise_quantized_ckpt"]) + logger.debug(f"✓ Verified low_noise_quantized_ckpt: {config['low_noise_quantized_ckpt']}") + + logger.info("✓ Config checkpoint paths validated successfully") diff --git a/scripts/ltx2/run_ltx2_i2av.sh b/scripts/ltx2/run_ltx2_i2av.sh old mode 100644 new mode 100755 index b384f8c22..811d7cad0 --- a/scripts/ltx2/run_ltx2_i2av.sh +++ b/scripts/ltx2/run_ltx2_i2av.sh @@ -4,6 +4,7 @@ lightx2v_path=/path/to/LightX2V model_path=Lightricks/LTX-2 + export CUDA_VISIBLE_DEVICES=0 # set environment variables @@ -12,9 +13,10 @@ source ${lightx2v_path}/scripts/base/base.sh python -m lightx2v.infer \ --model_cls ltx2 \ --task i2av \ ---images "${lightx2v_path}/assets/inputs/imgs/woman.jpeg:0:1.0" \ +--image_path "${lightx2v_path}/assets/inputs/imgs/woman.jpeg" \ --model_path $model_path \ --config_json ${lightx2v_path}/configs/ltx2/ltx2.json \ ---prompt "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood." \ +--prompt "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." 
\ --negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." \ ---save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av.mp4 +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av.mp4 \ +--image_strength 1.0 diff --git a/scripts/ltx2/run_ltx2_i2av_tp.sh b/scripts/ltx2/run_ltx2_i2av_tp.sh new file mode 100755 index 000000000..546a4402f --- /dev/null +++ b/scripts/ltx2/run_ltx2_i2av_tp.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=Lightricks/LTX-2 + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls ltx2 \ +--task i2av \ +--image_path "${lightx2v_path}/assets/inputs/imgs/woman.jpeg" \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2_tp.json \ +--prompt "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." \ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
\ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av_tp.mp4 \ +--image_strength 1.0 diff --git a/scripts/ltx2/run_ltx2_i2av_ulysses.sh b/scripts/ltx2/run_ltx2_i2av_ulysses.sh new file mode 100755 index 000000000..867726771 --- /dev/null +++ b/scripts/ltx2/run_ltx2_i2av_ulysses.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=Lightricks/LTX-2 + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls ltx2 \ +--task i2av \ +--image_path "${lightx2v_path}/assets/inputs/imgs/woman.jpeg" \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2_ulysses.json \ +--prompt "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." \ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
\ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av_ulysses.mp4 \ +--image_strength 1.0 diff --git a/scripts/ltx2/run_ltx2_t2av.sh b/scripts/ltx2/run_ltx2_t2av.sh index f91931f42..a30d639c0 100755 --- a/scripts/ltx2/run_ltx2_t2av.sh +++ b/scripts/ltx2/run_ltx2_t2av.sh @@ -14,7 +14,7 @@ python -m lightx2v.infer \ --model_cls ltx2 \ --task t2av \ --model_path $model_path \ ---config_json ${lightx2v_path}/configs/ltx2/ltx2_distill_fp8.json \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2.json \ --prompt "A beautiful sunset over the ocean" \ --negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." \ --save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_t2av.mp4 diff --git a/scripts/ltx2/run_ltx2_t2av_cfg_parallel.sh b/scripts/ltx2/run_ltx2_t2av_cfg_parallel.sh old mode 100644 new mode 100755 diff --git a/scripts/ltx2/run_ltx2_t2av_tp.sh b/scripts/ltx2/run_ltx2_t2av_tp.sh new file mode 100755 index 000000000..073df2f6e --- /dev/null +++ b/scripts/ltx2/run_ltx2_t2av_tp.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=Lightricks/LTX-2 + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls ltx2 \ +--task t2av \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2_tp.json \ +--prompt "A beautiful sunset over the ocean" \ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent 
tone, cinematic oversaturation, stylized filters, or AI artifacts." \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_t2av_tp.mp4 diff --git a/scripts/ltx2/run_ltx2_t2av_ulysses.sh b/scripts/ltx2/run_ltx2_t2av_ulysses.sh new file mode 100755 index 000000000..284d70198 --- /dev/null +++ b/scripts/ltx2/run_ltx2_t2av_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=Lightricks/LTX-2 + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls ltx2 \ +--task t2av \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2_ulysses.json \ +--prompt "A beautiful sunset over the ocean" \ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
\ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_t2av_ulysses.mp4 diff --git a/scripts/wan/run_wan_i2v.sh b/scripts/wan/run_wan_i2v.sh index ea50e4e20..3ba45598e 100755 --- a/scripts/wan/run_wan_i2v.sh +++ b/scripts/wan/run_wan_i2v.sh @@ -1,15 +1,15 @@ #!/bin/bash # set path firstly -lightx2v_path= -model_path= +lightx2v_path=/data/nvme0/gushiqiao/models/code/LightX2V +model_path=/data/nvme0/gushiqiao/models/official_models/Wan2.1-I2V-14B-720P -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0,1,2,3 # set environment variables source ${lightx2v_path}/scripts/base/base.sh -python -m lightx2v.infer \ +torchrun --nproc_per_node=4 -m lightx2v.infer \ --model_cls wan2.1 \ --task i2v \ --model_path $model_path \ From d52ce31c7d764238313fed61b8939f85d0ee0247 Mon Sep 17 00:00:00 2001 From: gushiqiao <975033167> Date: Fri, 23 Jan 2026 06:21:38 +0000 Subject: [PATCH 2/4] fix --- lightx2v/models/runners/ltx2/ltx2_runner.py | 8 -------- lightx2v/models/video_encoders/hf/ltx2/model.py | 4 ---- 2 files changed, 12 deletions(-) diff --git a/lightx2v/models/runners/ltx2/ltx2_runner.py b/lightx2v/models/runners/ltx2/ltx2_runner.py index 10ef01c8d..794adb697 100755 --- a/lightx2v/models/runners/ltx2/ltx2_runner.py +++ b/lightx2v/models/runners/ltx2/ltx2_runner.py @@ -77,13 +77,6 @@ def load_text_encoder(self): text_encoders = [text_encoder] return text_encoders - def get_vae_parallel(self): - if isinstance(self.config.get("parallel", False), bool): - return self.config.get("parallel", False) - if isinstance(self.config.get("parallel", False), dict): - return self.config.get("parallel", {}).get("vae_parallel", True) - return False - def load_vae(self): """Load video and audio VAE decoders.""" # offload config @@ -108,7 +101,6 @@ def load_vae(self): load_encoder=self.config["task"] == "i2av", use_tiling=self.config.get("use_tiling_vae", False), cpu_offload=vae_offload, - parallel=self.get_vae_parallel(), ) # Audio VAE diff --git a/lightx2v/models/video_encoders/hf/ltx2/model.py b/lightx2v/models/video_encoders/hf/ltx2/model.py index a6d1a2a0c..31fb0f874 100755 --- a/lightx2v/models/video_encoders/hf/ltx2/model.py +++ b/lightx2v/models/video_encoders/hf/ltx2/model.py @@ -36,16 +36,12 @@ def __init__( load_encoder: bool = True, use_tiling: bool = False, cpu_offload: bool = False, - parallel: bool = False, - use_2d_split: bool = True, ): self.checkpoint_path = checkpoint_path self.device = device self.dtype = dtype self.load_encoder_flag = load_encoder self.use_tiling = use_tiling - self.parallel = parallel - self.use_2d_split = use_2d_split self.loader = SafetensorsModelStateDictLoader() self.encoder = None self.decoder = None From cb3ae89634ad8d5d53c3e5c4ff1db26190e957e2 Mon Sep 17 00:00:00 2001 From: gushiqiao <975033167> Date: Fri, 23 Jan 2026 06:25:40 +0000 Subject: [PATCH 3/4] fix --- configs/wan/wan_i2v.json | 6 +----- scripts/wan/run_wan_i2v.sh | 8 ++++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/configs/wan/wan_i2v.json b/configs/wan/wan_i2v.json index ff4b83b1e..6c2107083 100755 --- a/configs/wan/wan_i2v.json +++ b/configs/wan/wan_i2v.json @@ -9,9 +9,5 @@ "sample_guide_scale": 5, "sample_shift": 3, "enable_cfg": true, - "cpu_offload": false, - "parallel": { - "seq_p_size": 4, - "seq_p_attn_type": "ulysses" - } + "cpu_offload": false } diff --git a/scripts/wan/run_wan_i2v.sh b/scripts/wan/run_wan_i2v.sh index 3ba45598e..ea50e4e20 100755 --- a/scripts/wan/run_wan_i2v.sh +++ b/scripts/wan/run_wan_i2v.sh @@ -1,15 +1,15 @@ #!/bin/bash # set path 
firstly -lightx2v_path=/data/nvme0/gushiqiao/models/code/LightX2V -model_path=/data/nvme0/gushiqiao/models/official_models/Wan2.1-I2V-14B-720P +lightx2v_path= +model_path= -export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_VISIBLE_DEVICES=0 # set environment variables source ${lightx2v_path}/scripts/base/base.sh -torchrun --nproc_per_node=4 -m lightx2v.infer \ +python -m lightx2v.infer \ --model_cls wan2.1 \ --task i2v \ --model_path $model_path \ From dd2b5fd7ab237f73504ef7f3101b300653ceabae Mon Sep 17 00:00:00 2001 From: gushiqiao <975033167> Date: Fri, 23 Jan 2026 06:28:07 +0000 Subject: [PATCH 4/4] fix --- lightx2v/models/networks/wan/infer/transformer_infer.py | 4 +--- lightx2v/models/networks/wan/model.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py index 74133c181..705e1bb29 100755 --- a/lightx2v/models/networks/wan/infer/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/transformer_infer.py @@ -1,7 +1,6 @@ from functools import partial import torch -import torch.distributed as dist from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.utils.envs import * @@ -193,7 +192,7 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d) v = phase.self_attn_v.apply(norm1_out).view(s, n, d) - print(dist.get_rank(), "qkv", q.shape, k.shape, v.shape) + q, k = self.apply_rope_func(q, k, cos_sin) img_qkv_len = q.shape[0] @@ -227,7 +226,6 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): enable_head_parallel=self.enable_head_parallel, **attn_running_args, ) - print(dist.get_rank(), "attn_out", attn_out.shape) else: attn_out = phase.self_attn_1.apply( q=q, diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py index e7825f32a..2ed324792 100755 --- a/lightx2v/models/networks/wan/model.py +++ b/lightx2v/models/networks/wan/model.py @@ -483,9 +483,7 @@ def _infer_cond_uncond(self, inputs, infer_condition=True): pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs) if self.config["seq_parallel"]: - print(dist.get_rank(), pre_infer_out.x.shape) pre_infer_out = self._seq_parallel_pre_process(pre_infer_out) - print(dist.get_rank(), pre_infer_out.x.shape) x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
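Taken together, the LTX2 sequence-parallel changes in this series follow one pad, chunk, all-gather, truncate pattern: the pre-process pads each per-token tensor to a multiple of the sequence-parallel world size and keeps only the current rank's chunk, and _seq_parallel_post_process gathers the chunks back and drops the padding. A condensed sketch of that pattern follows; the function names are illustrative, only the tensor handling mirrors the patch.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def seq_parallel_split(x: torch.Tensor, group) -> tuple[torch.Tensor, int]:
    """Pad dim 0 of a [tokens, dim] tensor to a multiple of world_size, keep this rank's chunk."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    original_length = x.shape[0]
    padding = (world_size - original_length % world_size) % world_size
    if padding > 0:
        x = F.pad(x, (0, 0, 0, padding))  # (0, 0) pads the last dim, (0, padding) pads dim 0
    return torch.chunk(x, world_size, dim=0)[rank], original_length

def seq_parallel_gather(x: torch.Tensor, group, original_length: int) -> torch.Tensor:
    """Inverse of the split: collect every rank's chunk and drop the padding."""
    gathered = [torch.empty_like(x) for _ in range(dist.get_world_size(group))]
    dist.all_gather(gathered, x, group=group)
    return torch.cat(gathered, dim=0)[:original_length]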
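The tensor-parallel weight classes (LTX2AttentionTP, LTX2FFNTP) use the standard column/row sharding scheme: to_q/to_k/to_v and the first FFN projection are split along the output dimension, while to_out and the second FFN layer are split along the input dimension and need an all-reduce to recombine partial results; rotary embeddings are likewise sliced along the head dimension in _tensor_parallel_pre_process so each rank only sees its own heads. The sketch below is a minimal, generic illustration of that scheme, not the MM_WEIGHT_REGISTER["TensorParallel"] implementation itself.

import torch
import torch.distributed as dist

def shard_linear_weight(weight: torch.Tensor, tp_rank: int, tp_size: int, split_dim: str) -> torch.Tensor:
    """weight is [out_features, in_features]; "col" shards out_features, "row" shards in_features."""
    dim = 0 if split_dim == "col" else 1
    return torch.chunk(weight, tp_size, dim=dim)[tp_rank]

def tp_linear(x: torch.Tensor, weight_shard: torch.Tensor, split_dim: str, group=None) -> torch.Tensor:
    """Apply a sharded linear layer. For "row" splits, x is expected to already be
    feature-sharded (e.g. the output of a preceding "col"-split layer), so each rank
    produces a partial sum that the all-reduce completes."""
    y = x @ weight_shard.t()
    if split_dim == "row":
        dist.all_reduce(y, group=group)
    return y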