diff --git a/configs/ltx2/ltx2.json b/configs/ltx2/ltx2.json index 7c0d6e3b..dabc0bc7 100755 --- a/configs/ltx2/ltx2.json +++ b/configs/ltx2/ltx2.json @@ -1,7 +1,6 @@ { "infer_steps": 40, "target_video_length": 121, - "text_len": 512, "target_height": 512, "target_width": 768, "attn_type": "sage_attn2", @@ -14,5 +13,6 @@ "audio_fps": 24000, "audio_mel_bins":16, "double_precision_rope": true, - "dit_original_ckpt": "/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-dev.safetensors" + "use_tiling_vae": false, + "dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors" } diff --git a/configs/ltx2/ltx2_distill_fp8.json b/configs/ltx2/ltx2_distill_fp8.json index 6c16342d..2ee31c3f 100755 --- a/configs/ltx2/ltx2_distill_fp8.json +++ b/configs/ltx2/ltx2_distill_fp8.json @@ -13,7 +13,7 @@ "audio_fps": 24000, "audio_mel_bins":16, "double_precision_rope": true, - "dit_quantized_ckpt": "/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors", + "dit_quantized_ckpt": "Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors", "dit_quantized": true, "dit_quant_scheme": "fp8-pertensor", "skip_fp8_block_index" : [0, 43, 44, 45, 46, 47] diff --git a/configs/ltx2/ltx2_fp8.json b/configs/ltx2/ltx2_fp8.json index c728de5a..37e4c873 100755 --- a/configs/ltx2/ltx2_fp8.json +++ b/configs/ltx2/ltx2_fp8.json @@ -1,7 +1,6 @@ { "infer_steps": 40, "target_video_length": 121, - "text_len": 512, "target_height": 512, "target_width": 768, "attn_type": "sage_attn2", diff --git a/configs/ltx2/ltx2_tp.json b/configs/ltx2/ltx2_tp.json new file mode 100755 index 00000000..93b007c4 --- /dev/null +++ b/configs/ltx2/ltx2_tp.json @@ -0,0 +1,20 @@ +{ + "infer_steps": 40, + "target_video_length": 121, + "target_height": 512, + "target_width": 768, + "attn_type": "sage_attn2", + "sample_guide_scale": 4, + "sample_shift": [2.05, 0.95], + "enable_cfg": true, + "cpu_offload": false, + "num_channels_latents": 128, + "fps": 24, + "audio_fps": 24000, + "audio_mel_bins":16, + "double_precision_rope": true, + "dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors", + "parallel": { + "tensor_p_size": 2 + } +} diff --git a/configs/ltx2/ltx2_ulysses.json b/configs/ltx2/ltx2_ulysses.json new file mode 100755 index 00000000..092e6674 --- /dev/null +++ b/configs/ltx2/ltx2_ulysses.json @@ -0,0 +1,21 @@ +{ + "infer_steps": 40, + "target_video_length": 121, + "target_height": 512, + "target_width": 768, + "attn_type": "sage_attn2", + "sample_guide_scale": 4, + "sample_shift": [2.05, 0.95], + "enable_cfg": true, + "cpu_offload": false, + "num_channels_latents": 128, + "fps": 24, + "audio_fps": 24000, + "audio_mel_bins":16, + "double_precision_rope": true, + "dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors", + "parallel": { + "seq_p_size": 2, + "seq_p_attn_type": "ulysses" + } +} diff --git a/examples/ltx2/ltxt_i2av.py b/examples/ltx2/ltxt_i2av.py index a7974567..c66b8402 100755 --- a/examples/ltx2/ltxt_i2av.py +++ b/examples/ltx2/ltxt_i2av.py @@ -24,15 +24,21 @@ ) seed = 42 -images = "/path/to/woman.jpeg:0:1.0" -prompt = "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood." 
+image_path = "/path/to/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg" +image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths +prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." save_result_path = "/path/to/save_results/output.mp4" +# Note: image_strength can also be set in config_json +# For scalar: image_strength = 1.0 (all images use same strength) +# For list: image_strength = [1.0, 0.8] (must match number of images) + pipe.generate( seed=seed, prompt=prompt, - images=images, + image_path=image_path, + image_strength=image_strength, negative_prompt=negative_prompt, save_result_path=save_result_path, ) diff --git a/examples/ltx2/ltxt_i2av_distilled_fp8.py b/examples/ltx2/ltxt_i2av_distilled_fp8.py index faf8ccf9..904f725b 100755 --- a/examples/ltx2/ltxt_i2av_distilled_fp8.py +++ b/examples/ltx2/ltxt_i2av_distilled_fp8.py @@ -1,14 +1,14 @@ from lightx2v import LightX2VPipeline pipe = LightX2VPipeline( - model_path="Lightricks/LTX-2/", + model_path="/data/nvme0/gushiqiao/models/official_models/LTX-2/", model_cls="ltx2", task="i2av", ) pipe.enable_quantize( dit_quantized=True, - dit_quantized_ckpt="Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors", + dit_quantized_ckpt="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors", quant_scheme="fp8-pertensor", skip_fp8_block_index=[0, 43, 44, 45, 46, 47], ) @@ -35,15 +35,20 @@ ) seed = 42 -images = "/path/to/woman.jpeg:0:1.0" -prompt = "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood." 
+image_path = "/data/nvme0/gushiqiao/models/code/LightX2V/assets/inputs/imgs/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg" +image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths +prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
-save_result_path = "/path/to/save_results/output.mp4" +save_result_path = "/data/nvme0/gushiqiao/models/code/LightX2V/save_results/output_lightx2v_ltx2_i2av_distilled_fp8.mp4" +# Note: image_strength can also be set in config_json +# For scalar: image_strength = 1.0 (all images use same strength) +# For list: image_strength = [1.0, 0.8] (must match number of images) pipe.generate( seed=seed, prompt=prompt, - images=images, + image_path=image_path, + image_strength=image_strength, negative_prompt=negative_prompt, save_result_path=save_result_path, ) diff --git a/examples/ltx2/ltxt_t2av_distilled_fp8.py b/examples/ltx2/ltxt_t2av_distilled_fp8.py index 2bcfcb75..0e81115c 100755 --- a/examples/ltx2/ltxt_t2av_distilled_fp8.py +++ b/examples/ltx2/ltxt_t2av_distilled_fp8.py @@ -1,10 +1,10 @@ from lightx2v import LightX2VPipeline -pipe = LightX2VPipeline(model_path="Lightricks/LTX-2", model_cls="ltx2", task="t2av") +pipe = LightX2VPipeline(model_path="/data/nvme0/gushiqiao/models/official_models/LTX-2", model_cls="ltx2", task="t2av") pipe.enable_quantize( dit_quantized=True, - dit_quantized_ckpt="Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors", + dit_quantized_ckpt="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors", quant_scheme="fp8-pertensor", skip_fp8_block_index=[0, 43, 44, 45, 46, 47], ) @@ -33,11 +33,12 @@ seed = 42 prompt = "A beautiful sunset over the ocean" negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
-save_result_path = "/path/to/save_results/output.mp4" +save_result_path = "/data/nvme0/gushiqiao/models/code/LightX2V/save_results/output_lightx2v_ltx2_t2av_distilled_fp8.mp4" -pipe.generate( - seed=seed, - prompt=prompt, - negative_prompt=negative_prompt, - save_result_path=save_result_path, -) +for seed in range(10): + pipe.generate( + seed=seed, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, + ) diff --git a/lightx2v/common/ops/mm/mm_weight.py b/lightx2v/common/ops/mm/mm_weight.py index c27e9cef..1b723dab 100755 --- a/lightx2v/common/ops/mm/mm_weight.py +++ b/lightx2v/common/ops/mm/mm_weight.py @@ -2,6 +2,7 @@ from abc import ABCMeta, abstractmethod import torch +import torch.distributed as dist from loguru import logger from safetensors import safe_open @@ -92,6 +93,8 @@ except ImportError: marlin_cuda_quant = None +import torch.distributed as dist + class MMWeightTemplate(metaclass=ABCMeta): def __init__( @@ -2143,3 +2146,96 @@ def apply(self, input_tensor): if self.has_lora_branch: return output_tensor + self.apply_lora(input_tensor) return output_tensor + + +@MM_WEIGHT_REGISTER("TensorParallel") +class MMWeightTP(MMWeightTemplate): + """ + Tensor Parallel wrapper for any MMWeight type. + + This is a generic wrapper that can wrap any MMWeight implementation (Default, fp8, int8, etc.) + and add tensor parallelism support by: + 1. Handling weight splitting in load() method + 2. Adding all-reduce for row-wise split in apply() method + + Supports column-wise and row-wise weight splitting: + - Column split: weight [in_dim, out_dim] -> [in_dim, out_dim/tp_size] per rank + - Row split: weight [in_dim, out_dim] -> [in_dim/tp_size, out_dim] per rank + """ + + def __init__( + self, + weight_name, + bias_name, + mm_type="Default", + tp_group=None, + tp_rank=0, + tp_size=1, + split_dim="col", + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + is_post_adapter=False, + lora_prefix="diffusion_model.blocks", + lora_path="", + ): + super().__init__( + weight_name, + bias_name, + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + is_post_adapter, + lora_prefix, + lora_path, + ) + self.tp_group = tp_group + self.tp_rank = tp_rank + self.tp_size = tp_size + self.split_dim = split_dim # "col" for column split, "row" for row split + assert split_dim in ["col", "row"], f"split_dim must be 'col' or 'row', got {split_dim}" + + self._mm = MM_WEIGHT_REGISTER.get(mm_type, MMWeight)( + weight_name=weight_name, + bias_name=bias_name, + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=lazy_load, + lazy_load_file=lazy_load_file, + is_post_adapter=is_post_adapter, + lora_prefix=lora_prefix, + lora_path=lora_path, + ) + self._row_split_bias = None + + def load(self, weight_dict): + """Load weights using internal MMWeight's load method. + + Note: Weights in weight_dict are already split by _load_weights_from_rank0. + The format is [out_dim/tp_size, in_dim] for column split or [out_dim, in_dim/tp_size] for row split. + MMWeight.load will handle the transposition via create_default_tensors. + + For row split, bias is not split and should be added after all-reduce. + We temporarily remove bias from _mm to prevent it from being added before all-reduce. 
+ """ + self._mm.load(weight_dict) + if self.split_dim == "row" and self.bias_name is not None and self.bias_name in weight_dict: + self._row_split_bias = self._mm.bias.clone() + self._mm.bias = None + + def apply(self, input_tensor): + """Apply matrix multiplication with tensor parallel support.""" + # Use internal MMWeight's apply method (handles fp8, int8, etc.) + # For row split, _mm.bias is None, so bias won't be added here + output = self._mm.apply(input_tensor) + + # For row split, need all-reduce to combine results from all ranks + if self.split_dim == "row" and self.tp_size > 1 and self.tp_group is not None: + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.tp_group) + # Add bias after all-reduce (bias is not split for row split) + if self._row_split_bias is not None: + output = output + self._row_split_bias + + return output diff --git a/lightx2v/common/ops/norm/rms_norm_weight.py b/lightx2v/common/ops/norm/rms_norm_weight.py index 077a40e9..148e5996 100755 --- a/lightx2v/common/ops/norm/rms_norm_weight.py +++ b/lightx2v/common/ops/norm/rms_norm_weight.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod import torch +import torch.distributed as dist from loguru import logger from safetensors import safe_open @@ -176,6 +177,65 @@ def apply(self, input_tensor): return input_tensor +@RMS_WEIGHT_REGISTER("TensorParallel") +class RMSWeightTP(RMSWeightTemplate): + """ + RMSNorm weight module with tensor parallelism support. + + The weight is split along the hidden dimension to match the split QKV outputs. + """ + + def __init__( + self, + weight_name, + tp_group=None, + tp_rank=0, + tp_size=1, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + is_post_adapter=False, + eps=1e-6, + lora_prefix="diffusion_model.blocks", + lora_path="", + ): + super().__init__( + weight_name, + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + is_post_adapter, + eps, + lora_prefix, + lora_path, + ) + self.tp_group = tp_group + self.tp_rank = tp_rank + self.tp_size = tp_size + + def apply(self, input_tensor): + local_sum = input_tensor.pow(2).sum(-1, keepdim=True) + + # All-reduce to get global sum + if self.tp_size > 1 and self.tp_group is not None: + dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=self.tp_group) + + # Compute global mean: global_sum / hidden_dim + hidden_dim = input_tensor.shape[-1] * self.tp_size + global_mean = local_sum / hidden_dim + + # Apply normalization with global mean + if self.sensitive_layer_dtype != self.infer_dtype: + input_tensor = input_tensor * torch.rsqrt(global_mean.float() + self.eps).to(self.infer_dtype) + input_tensor = (input_tensor * self._get_actual_weight()).to(self.infer_dtype) + else: + input_tensor = input_tensor * torch.rsqrt(global_mean + self.eps) + input_tensor = input_tensor * self._get_actual_weight() + return input_tensor + + @RMS_WEIGHT_REGISTER("sgl-kernel") class RMSWeightSgl(RMSWeight): def __init__( diff --git a/lightx2v/infer.py b/lightx2v/infer.py index a8dedda2..84c62d79 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -24,7 +24,7 @@ from lightx2v.utils.profiler import * from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.set_config import print_config, set_config, set_parallel_config -from lightx2v.utils.utils import seed_all, validate_task_arguments +from lightx2v.utils.utils import seed_all, validate_config_paths, validate_task_arguments from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER @@ 
-75,15 +75,15 @@ def main(): parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation") parser.add_argument("--negative_prompt", type=str, default="") - parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task") - parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task") - parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task") parser.add_argument( - "--images", + "--image_path", type=str, default="", - help="Image conditioning for I2AV task. Format: 'path1:frame_idx1:strength1,path2:frame_idx2:strength2'. Example: 'cat.jpg:0:1.0,dog.jpg:60:0.8'", + help="The path to input image file(s) for image-to-video (i2v) or image-to-audio-video (i2av) task. Multiple paths should be comma-separated. Example: 'path1.jpg,path2.jpg'", ) + parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task") + parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task") + parser.add_argument("--image_strength", type=float, default=1.0, help="The conditioning strength of the input image(s) for the image-to-audio-video (i2av) task") # [Warning] For vace task, need refactor. parser.add_argument( "--src_ref_images", @@ -149,6 +149,8 @@ def main(): print_config(config) + validate_config_paths(config) + with ProfilingContext4DebugL1("Total Cost"): # init runner runner = init_runner(config) diff --git a/lightx2v/models/networks/ltx2/infer/pre_infer.py b/lightx2v/models/networks/ltx2/infer/pre_infer.py index dd59a405..abb274b6 100755 --- a/lightx2v/models/networks/ltx2/infer/pre_infer.py +++ b/lightx2v/models/networks/ltx2/infer/pre_infer.py @@ -53,7 +53,7 @@ def _prepare_positional_embeddings( max_pos=max_pos, use_middle_indices_grid=use_middle_indices_grid, num_attention_heads=num_attention_heads, - rope_type=LTXRopeType.SPLIT, + rope_type="split", freq_grid_generator=freq_grid_generator, ) return pe diff --git a/lightx2v/models/networks/ltx2/infer/transformer_infer.py b/lightx2v/models/networks/ltx2/infer/transformer_infer.py index c3af96e5..c3da867c 100755 --- a/lightx2v/models/networks/ltx2/infer/transformer_infer.py +++ b/lightx2v/models/networks/ltx2/infer/transformer_infer.py @@ -9,10 +9,11 @@ """ import torch +import torch.distributed as dist from lightx2v.models.networks.ltx2.infer.module_io import LTX2PreInferModuleOutput from lightx2v.models.networks.ltx2.infer.triton_ops import fuse_scale_shift_kernel, fused_rmsnorm_modulate -from lightx2v.models.networks.ltx2.infer.utils import LTXRopeType, apply_rotary_emb, modulate_torch_naive, modulate_with_rmsnorm_torch_naive, rmsnorm_torch_naive +from lightx2v.models.networks.ltx2.infer.utils import apply_rotary_emb, modulate_torch_naive, modulate_with_rmsnorm_torch_naive, rmsnorm_torch_naive from lightx2v.models.networks.wan.infer.triton_ops import norm_infer @@ -31,13 +32,31 @@ def __init__(self, config): config: Model configuration dictionary """ self.config = config - self.rope_type = LTXRopeType(config["rope_type"]) + self.rope_type = config["rope_type"] self.blocks_num = config.get("num_layers", 48) self.v_num_heads = config.get("num_attention_heads", 32) self.v_head_dim = config.get("attention_head_dim", 128) self.a_num_heads = config.get("audio_num_attention_heads", 
32) self.a_head_dim = config.get("audio_attention_head_dim", 64) self.clean_cuda_cache = config.get("clean_cuda_cache", False) + + if config.get("seq_parallel", False): + self.seq_p_group = config.get("device_mesh").get_group(mesh_dim="seq_p") + self.seq_p_fp8_comm = config.get("parallel", {}).get("seq_p_fp8_comm", False) + else: + self.seq_p_group = None + self.seq_p_fp8_comm = False + + # Initialize tensor parallel group + if config.get("tensor_parallel", False): + self.tp_group = config.get("device_mesh").get_group(mesh_dim="tensor_p") + self.tp_rank = dist.get_rank(self.tp_group) + self.tp_size = dist.get_world_size(self.tp_group) + else: + self.tp_group = None + self.tp_rank = 0 + self.tp_size = 1 + if config.get("norm_modulate_backend", "triton") == "triton": self.norm_infer_func = norm_infer self.modulate_func = fuse_scale_shift_kernel @@ -46,11 +65,68 @@ def __init__(self, config): self.norm_infer_func = rmsnorm_torch_naive self.modulate_func = modulate_torch_naive self.modulate_with_rmsnorm_func = modulate_with_rmsnorm_torch_naive + self.reset_infer_states() def set_scheduler(self, scheduler): """Set the scheduler for inference.""" self.scheduler = scheduler + def reset_infer_states(self): + """Reset inference states for cumulative sequence lengths.""" + # Only cache cu_seqlens_qkv for self-attention (q, k, v have same length) + # For cross-attention, cu_seqlens_kv varies by context type, create dynamically + self.v_attn_cu_seqlens_qkv = None # For video self-attention + self.a_attn_cu_seqlens_qkv = None # For audio self-attention + + def _create_cu_seqlens(self, seq_len: int, device: torch.device) -> torch.Tensor: + """ + Create cumulative sequence lengths tensor for attention. + + Args: + seq_len: Sequence length + device: Device to place the tensor on + + Returns: + Cumulative sequence lengths tensor [0, seq_len] + """ + if self.config["attn_type"] in ["flash_attn2", "flash_attn3"]: + return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32).to(device, non_blocking=True) + else: + return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32) + + def _gather_cross_attn_context(self, context: torch.Tensor, k_pe=None): + """ + Gather context and k_pe from all ranks for cross-attention in sequence parallel mode. 
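+ + For example, with seq_p_size=2 and a hypothetical local video context of shape [S/2, D], each rank contributes its chunk via all_gather and the chunks are concatenated along dim 0 into the full [S, D] context, so the (global) audio queries can attend to every video token.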
+ + Args: + context: Local context tensor to gather + k_pe: Optional tuple of (cos_freqs, sin_freqs) for key positional embeddings + + Returns: + Tuple of (gathered_context, gathered_k_pe) + """ + world_size = dist.get_world_size(self.seq_p_group) + + # Gather context + context_gathered = [torch.zeros_like(context) for _ in range(world_size)] + dist.all_gather(context_gathered, context, group=self.seq_p_group) + gathered_context = torch.cat(context_gathered, dim=0) + + # Gather k_pe if provided + gathered_k_pe = k_pe + if k_pe is not None: + cos_freqs, sin_freqs = k_pe + # Determine sequence dimension: for 4D tensors [B, H, T, D], seq_dim is 2 + seq_dim = 2 if cos_freqs.dim() == 4 else (0 if cos_freqs.dim() == 2 else 1) + + cos_gathered = [torch.zeros_like(cos_freqs) for _ in range(world_size)] + sin_gathered = [torch.zeros_like(sin_freqs) for _ in range(world_size)] + dist.all_gather(cos_gathered, cos_freqs.contiguous(), group=self.seq_p_group) + dist.all_gather(sin_gathered, sin_freqs.contiguous(), group=self.seq_p_group) + gathered_k_pe = (torch.cat(cos_gathered, dim=seq_dim), torch.cat(sin_gathered, dim=seq_dim)) + + return gathered_context, gathered_k_pe + def _infer_attn( self, attn_phase, @@ -59,38 +135,109 @@ def _infer_attn( pe=None, k_pe=None, is_audio=False, + need_gather_video_context=False, # Only True for video-to-audio cross-attention ) -> torch.Tensor: + """ + Unified attention inference method supporting both TP and non-TP modes. + + Args: + attn_phase: LTX2Attention or LTX2AttentionTP instance + x: Input tensor [seq_len, hidden_dim] + context: Context tensor for cross-attention (None for self-attention) + pe: Positional embeddings for query + k_pe: Positional embeddings for key + is_audio: Whether this is audio attention + need_gather_video_context: Whether to gather video context for cross-attention (only for SP) + + Returns: + Attention output tensor [seq_len, hidden_dim] + """ + use_tp = self.tp_size > 1 + is_self_attn = context is None + context = x if is_self_attn else context + q = attn_phase.to_q.apply(x) - context = x if context is None else context + # For sequence parallel (non-TP), gather context if needed + if need_gather_video_context and self.config.get("seq_parallel", False) and not use_tp: + context, k_pe = self._gather_cross_attn_context(context, k_pe) k = attn_phase.to_k.apply(context) v = attn_phase.to_v.apply(context) q = attn_phase.q_norm.apply(q) k = attn_phase.k_norm.apply(k) + if pe is not None: q = apply_rotary_emb(q, pe, self.rope_type).squeeze() k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type).squeeze() - out = attn_phase.attn_func.apply( - q=q.view( - -1, - self.v_num_heads if not is_audio else self.a_num_heads, - self.v_head_dim if not is_audio else self.a_head_dim, - ), - k=k.view( - -1, - self.v_num_heads if not is_audio else self.a_num_heads, - self.v_head_dim if not is_audio else self.a_head_dim, - ), - v=v.view( - -1, - self.v_num_heads if not is_audio else self.a_num_heads, - self.v_head_dim if not is_audio else self.a_head_dim, - ), - max_seqlen_q=q.size(0), - ) + + num_heads = self.v_num_heads if not is_audio else self.a_num_heads + head_dim = self.v_head_dim if not is_audio else self.a_head_dim + seq_len = q.size(0) + + # For TP, heads are split across ranks + num_heads_effective = num_heads // self.tp_size if use_tp else num_heads + + q = q.view(-1, num_heads_effective, head_dim) + k = k.view(-1, num_heads_effective, head_dim) + v = v.view(-1, num_heads_effective, head_dim) + + # For video self-attention with 
sequence parallel (non-TP only) + if is_self_attn and not is_audio and self.config.get("seq_parallel", False) and not use_tp: + # Cache cu_seqlens_qkv for self-attention (q, k, v have same length) + if self.v_attn_cu_seqlens_qkv is None: + self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + + out = attn_phase.attn_func_parallel.apply( + q=q, + k=k, + v=v, + slice_qkv_len=seq_len, + cu_seqlens_qkv=self.v_attn_cu_seqlens_qkv, + attention_module=attn_phase.attn_func, + attention_type=self.config["attn_type"], + seq_p_group=self.seq_p_group, + use_fp8_comm=self.seq_p_fp8_comm, + use_tensor_fusion=False, + enable_head_parallel=False, + ) + else: + # For all other attention types (cross-attn, audio self-attn, TP, non-parallel) + # Cache cu_seqlens_qkv for self-attention only + if is_self_attn: + if not is_audio and self.v_attn_cu_seqlens_qkv is None: + self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + elif is_audio and self.a_attn_cu_seqlens_qkv is None: + self.a_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + + cu_seqlens_q = self.v_attn_cu_seqlens_qkv if not is_audio else self.a_attn_cu_seqlens_qkv + cu_seqlens_kv = cu_seqlens_q # For self-attn, q and k have same length + else: + # For cross-attention, always create cu_seqlens dynamically + # because k length varies by context type (text, audio, video) + cu_seqlens_q = self._create_cu_seqlens(q.shape[0], q.device) + cu_seqlens_kv = self._create_cu_seqlens(k.shape[0], k.device) + + out = attn_phase.attn_func.apply( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=q.size(0), + max_seqlen_kv=k.size(0), + ) return attn_phase.to_out.apply(out) def _infer_ffn(self, ffn_phase, x: torch.Tensor) -> torch.Tensor: - """Apply feed-forward network.""" + """ + Unified feed-forward network inference method supporting both TP and non-TP modes. 
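+ + No explicit communication appears here: in TP mode net_0_proj is a column-split MMWeightTP (no comm needed) and net_2 is row-split, so the all-reduce happens inside net_2.apply(); in non-TP mode both projections are purely local.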
+ + Args: + ffn_phase: LTX2FFN or LTX2FFNTP instance + x: Input tensor [seq_len, hidden_dim] + + Returns: + FFN output tensor [seq_len, hidden_dim] + """ x = ffn_phase.net_0_proj.apply(x) x = torch.nn.functional.gelu(x, approximate="tanh") return ffn_phase.net_2.apply(x) @@ -117,6 +264,8 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) norm_vx = self.modulate_with_rmsnorm_func(vx, vscale_msa, vshift_msa, weight=None, bias=None, eps=1e-6) + + # Video self-attention vx = ( vx + self._infer_attn( @@ -127,7 +276,7 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) * vgate_msa ) - + # Video cross-attention vx = vx + self._infer_attn( attn_phase=block.compute_phases[1], x=self.norm_infer_func(vx, weight=None, bias=None, eps=1e-6), @@ -145,6 +294,8 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) norm_ax = self.modulate_with_rmsnorm_func(ax, ascale_msa, ashift_msa, weight=None, bias=None, eps=1e-6) + + # Audio self-attention ax = ( ax + self._infer_attn( @@ -155,7 +306,7 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) * agate_msa ) - + # Audio cross-attention ax = ax + self._infer_attn( attn_phase=block.compute_phases[3], x=self.norm_infer_func(ax, weight=None, bias=None, eps=1e-6), @@ -196,12 +347,12 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): ) # Audio-to-video cross-attention - # vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v - # ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v - + # Video queries attend to audio context + # Audio is global (not split), so no need to gather vx_scaled = self.modulate_func(vx_norm3, scale_ca_video_hidden_states_a2v, shift_ca_video_hidden_states_a2v) ax_scaled = self.modulate_func(ax_norm3, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v) + # Audio-to-video cross-attention vx = ( vx + self._infer_attn( @@ -211,13 +362,14 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): pe=pre_infer_out.video_args.cross_positional_embeddings, k_pe=pre_infer_out.audio_args.cross_positional_embeddings, is_audio=True, + need_gather_video_context=False, # Audio is global, no gather needed ) * gate_out_a2v ) # Video-to-audio cross-attention - # ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a - # vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + # Audio queries need to attend to full video context + # In TP, video is NOT split (unlike SP), so no gather needed ax_scaled = self.modulate_func(ax_norm3, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a) vx_scaled = self.modulate_func(vx_norm3, scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a) @@ -230,6 +382,7 @@ def infer_block(self, block, vx, ax, pre_infer_out: LTX2PreInferModuleOutput): pe=pre_infer_out.audio_args.cross_positional_embeddings, k_pe=pre_infer_out.video_args.cross_positional_embeddings, is_audio=True, + need_gather_video_context=not (self.tp_size > 1), # Need gather for SP, not for TP ) * gate_out_v2a ) @@ -282,6 +435,9 @@ def infer(self, weights, pre_infer_out: LTX2PreInferModuleOutput): Returns: Tuple of (video_x, audio_x, video_timestep, audio_timestep) after transformer blocks """ + # Reset inference states at the beginning of each inference + self.reset_infer_states() 
+ vx = pre_infer_out.video_args.x ax = pre_infer_out.audio_args.x diff --git a/lightx2v/models/networks/ltx2/infer/utils.py b/lightx2v/models/networks/ltx2/infer/utils.py index f059466c..cfc35bf7 100755 --- a/lightx2v/models/networks/ltx2/infer/utils.py +++ b/lightx2v/models/networks/ltx2/infer/utils.py @@ -1,6 +1,5 @@ import functools import math -from enum import Enum from typing import Callable, Tuple import numpy as np @@ -71,19 +70,14 @@ def get_timestep_embedding( return emb -class LTXRopeType(Enum): - INTERLEAVED = "interleaved" - SPLIT = "split" - - def apply_rotary_emb( input_tensor: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], - rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + rope_type: str = "split", ) -> torch.Tensor: - if rope_type == LTXRopeType.INTERLEAVED: + if rope_type == "interleaved": return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) - elif rope_type == LTXRopeType.SPLIT: + elif rope_type == "split": return apply_split_rotary_emb(input_tensor, *freqs_cis) else: raise ValueError(f"Invalid rope type: {rope_type}") @@ -234,7 +228,7 @@ def precompute_freqs_cis( max_pos: list[int] | None = None, use_middle_indices_grid: bool = False, num_attention_heads: int = 32, - rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + rope_type: str = "split", freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, ) -> tuple[torch.Tensor, torch.Tensor]: if max_pos is None: @@ -243,7 +237,7 @@ def precompute_freqs_cis( indices = freq_grid_generator(theta, indices_grid.shape[1], dim) freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) - if rope_type == LTXRopeType.SPLIT: + if rope_type == "split": expected_freqs = dim // 2 current_freqs = freqs.shape[-1] pad_size = expected_freqs - current_freqs diff --git a/lightx2v/models/networks/ltx2/model.py b/lightx2v/models/networks/ltx2/model.py index 17e572d6..22890ec3 100755 --- a/lightx2v/models/networks/ltx2/model.py +++ b/lightx2v/models/networks/ltx2/model.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from loguru import logger from safetensors import safe_open @@ -40,6 +41,28 @@ def __init__(self, model_path, config, device, lora_path=None, lora_strength=1.0 self.cpu_offload = self.config.get("cpu_offload", False) self.offload_granularity = self.config.get("offload_granularity", "block") + # Initialize sequence parallel group + if self.config.get("seq_parallel", False): + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + else: + self.seq_p_group = None + + if self.config.get("tensor_parallel", False): + self.use_tp = True + self.tp_group = self.config.get("device_mesh").get_group(mesh_dim="tensor_p") + self.tp_rank = dist.get_rank(self.tp_group) + self.tp_size = dist.get_world_size(self.tp_group) + else: + self.tp_group = None + self.use_tp = False + self.tp_rank = 0 + self.tp_size = 1 + + self.padding_multiple = self.config.get("padding_multiple", 1) + + # Track original video sequence length before padding (for sequence parallel) + self.original_video_seq_len = None + # self.model_type = model_type self.remove_keys = ["text_embedding_projection", "audio_vae", "vae", "vocoder", "model.diffusion_model.audio_embeddings_connector", "model.diffusion_model.video_embeddings_connector"] self.lazy_load = self.config.get("lazy_load", False) @@ -96,7 +119,7 @@ def _should_load_weights(self): # Single GPU mode return True elif dist.is_initialized(): - if 
self.config.get("load_from_rank0", False): + if self.use_tp: # Multi-GPU mode, only rank 0 loads if dist.get_rank() == 0: logger.info(f"Loading weights from {self.model_path}") @@ -236,12 +259,9 @@ def _init_weights(self, weight_dict=None): # Load quantized weights weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer) - if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False): + if self.config.get("device_mesh") is not None and self.config.get("tensor_parallel", False): weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader) - if hasattr(self, "adapter_weights_dict"): - weight_dict.update(self.adapter_weights_dict) - self.original_weight_dict = weight_dict else: self.original_weight_dict = weight_dict @@ -273,64 +293,199 @@ def _apply_weights(self, weight_dict=None): gc.collect() def _load_weights_from_rank0(self, weight_dict, is_weight_loader): - logger.info("Loading distributed weights") + """ + Load and distribute weights from rank 0 to all ranks. + + Only supports tensor parallel mode with CUDA device. + CPU offload is not supported. + """ + # CPU offload is not supported + if self.cpu_offload: + raise NotImplementedError("_load_weights_from_rank0 does not support CPU offload. Please set cpu_offload=False.") + + logger.info("Loading distributed weights with tensor parallel (CUDA only)") global_src_rank = 0 - target_device = "cpu" if self.cpu_offload else "cuda" if is_weight_loader: + # Rank 0: prepare weights (split for TP) + processed_weight_dict = {} meta_dict = {} + for key, tensor in weight_dict.items(): - meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype} + # Only process .weight keys for TP splitting (bias is handled separately) + if key.endswith(".weight") and self._is_tp_weight(key): + # Split weights for TP: create one entry per rank with original key name + split_weights = self._split_weight_for_tp(key, tensor, self.tp_size) + # Store all split weights temporarily (will be filtered by rank later) + for rank_idx in range(self.tp_size): + rank_key = f"{key}__tp_rank_{rank_idx}" + processed_weight_dict[rank_key] = split_weights[rank_idx] + if rank_idx == 0: # Use rank 0's shape for meta (all ranks have same shape after split) + meta_dict[key] = {"shape": split_weights[rank_idx].shape, "dtype": split_weights[rank_idx].dtype, "is_tp": True} + + # Also handle bias if it exists (for column split weights) + bias_key = key.replace(".weight", ".bias") + if bias_key in weight_dict and self._get_split_type(key) == "col": + # Column split: bias also needs to be split + bias_tensor = weight_dict[bias_key] + assert bias_tensor.shape[0] % self.tp_size == 0, f"bias dimension ({bias_tensor.shape[0]}) must be divisible by tp_size ({self.tp_size}) for {bias_key}" + chunk_size = bias_tensor.shape[0] // self.tp_size + for rank_idx in range(self.tp_size): + rank_bias_key = f"{bias_key}__tp_rank_{rank_idx}" + start_idx = rank_idx * chunk_size + end_idx = start_idx + chunk_size + processed_weight_dict[rank_bias_key] = bias_tensor[start_idx:end_idx] + if rank_idx == 0: + meta_dict[bias_key] = {"shape": bias_tensor[start_idx:end_idx].shape, "dtype": bias_tensor.dtype, "is_tp": True} + # For row split weights, bias is not split (added after all-reduce) + else: + # Non-TP weights or bias (bias is handled above if it's a TP weight's bias) + # Skip bias keys that are already processed above + if not (key.endswith(".bias") and key.replace(".bias", ".weight") in processed_weight_dict): + processed_weight_dict[key] = tensor + 
meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype, "is_tp": False} obj_list = [meta_dict] dist.broadcast_object_list(obj_list, src=global_src_rank) synced_meta_dict = obj_list[0] + weight_dict = processed_weight_dict # Use processed weights else: obj_list = [None] dist.broadcast_object_list(obj_list, src=global_src_rank) synced_meta_dict = obj_list[0] + # Allocate tensors on CUDA distributed_weight_dict = {} for key, meta in synced_meta_dict.items(): - distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) + is_tp = meta.get("is_tp", False) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + if is_tp: + # TP weight: each rank gets its own slice + distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=device) + else: + # Non-TP weight: all ranks get full weight + distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=device) - if target_device == "cuda": - dist.barrier(device_ids=[torch.cuda.current_device()]) + dist.barrier() + # Distribute weights for key in sorted(synced_meta_dict.keys()): - if is_weight_loader: - distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True) + meta = synced_meta_dict[key] + is_tp = meta.get("is_tp", False) + + if is_tp: + # TP weight: rank 0 sends different slices to each rank + # Use send/recv to ensure each rank gets its own slice + dist.barrier(group=self.tp_group) - if target_device == "cpu": if is_weight_loader: - gpu_tensor = distributed_weight_dict[key].cuda() - dist.broadcast(gpu_tensor, src=global_src_rank) - distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True) - del gpu_tensor - torch.cuda.empty_cache() + # Rank 0: send each rank's slice in order + for rank_idx in range(self.tp_size): + rank_key = f"{key}__tp_rank_{rank_idx}" + if rank_key in weight_dict: + if rank_idx == self.tp_rank: + # Copy to my own buffer + distributed_weight_dict[key].copy_(weight_dict[rank_key], non_blocking=True) + else: + # Send to other ranks + dist.send(weight_dict[rank_key].contiguous(), dst=rank_idx, group=self.tp_group) else: - gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda") - dist.broadcast(gpu_tensor, src=global_src_rank) - distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True) - del gpu_tensor - torch.cuda.empty_cache() - - if distributed_weight_dict[key].is_pinned(): - distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True) + # Other ranks: receive from rank 0 + dist.recv(distributed_weight_dict[key], src=global_src_rank, group=self.tp_group) else: + # Non-TP weight: broadcast to all ranks + if is_weight_loader: + distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True) + dist.broadcast(distributed_weight_dict[key], src=global_src_rank) - if target_device == "cuda": - torch.cuda.synchronize() - else: - for tensor in distributed_weight_dict.values(): - if tensor.is_pinned(): - tensor.copy_(tensor, non_blocking=False) + torch.cuda.synchronize() - logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}") + logger.info(f"Weights distributed across {dist.get_world_size()} devices on CUDA") return distributed_weight_dict + def _is_tp_weight(self, key): + """Check if a weight key needs TP splitting. 
+ + TP weights include: + - Attention layers: to_q, to_k, to_v, to_out.0, q_norm, k_norm + - FFN layers: net.0.proj, net.2 + """ + # Generic patterns that apply to all attention and FFN layers + tp_patterns = [ + ".to_q.", + ".to_k.", + ".to_v.", + ".to_out.0.", + ".q_norm.", + ".k_norm.", + ".net.0.proj.", + ".net.2.", + ] + return any(pattern in key for pattern in tp_patterns) + + def _get_split_type(self, key): + """Determine the split type for a weight key. + + Returns: + "col": Column split (to_q, to_k, to_v, net.0.proj) + "row": Row split (to_out.0, net.2) + "norm": Norm split (q_norm, k_norm) + None: No split needed + """ + if ".q_norm." in key or ".k_norm." in key: + return "norm" + elif ".to_q." in key or ".to_k." in key or ".to_v." in key or ".net.0.proj." in key: + return "col" + elif ".to_out.0." in key or ".net.2." in key: + return "row" + return None + + def _split_weight_for_tp(self, key, weight, tp_size): + """ + Split a weight tensor for tensor parallel. + Returns a list of split weights, one for each rank. + """ + split_type = self._get_split_type(key) + if split_type is None: + # Unknown pattern, don't split + return [weight] * tp_size + + if split_type == "norm": + # 1D weights (norm weights): [hidden_dim] -> [hidden_dim/tp_size] per rank + assert weight.dim() == 1, f"Norm weight should be 1D, got {weight.dim()}D for {key}" + assert weight.shape[0] % tp_size == 0, f"hidden_dim ({weight.shape[0]}) must be divisible by tp_size ({tp_size}) for {key}" + chunk_size = weight.shape[0] // tp_size + return [weight[rank_idx * chunk_size : (rank_idx + 1) * chunk_size] for rank_idx in range(tp_size)] + + # 2D weights (linear layer weights) + assert weight.dim() == 2, f"Linear weight should be 2D, got {weight.dim()}D for {key}" + + # Transpose to [in_dim, out_dim] format for easier splitting + weight_t = weight.t() # [in_dim, out_dim] + + if split_type == "col": + # Column split: [out_dim, in_dim] -> [out_dim/tp_size, in_dim] per rank + # Split along out_dim dimension + assert weight_t.shape[1] % tp_size == 0, f"out_dim ({weight_t.shape[1]}) must be divisible by tp_size ({tp_size}) for {key}" + chunk_size = weight_t.shape[1] // tp_size + split_weights = [] + for rank_idx in range(tp_size): + split_weight = weight_t[:, rank_idx * chunk_size : (rank_idx + 1) * chunk_size].t() # Back to [out_dim/tp_size, in_dim] + split_weights.append(split_weight) + else: # split_type == "row" + # Row split: [out_dim, in_dim] -> [out_dim, in_dim/tp_size] per rank + # Split along in_dim dimension + assert weight_t.shape[0] % tp_size == 0, f"in_dim ({weight_t.shape[0]}) must be divisible by tp_size ({tp_size}) for {key}" + chunk_size = weight_t.shape[0] // tp_size + split_weights = [] + for rank_idx in range(tp_size): + split_weight = weight_t[rank_idx * chunk_size : (rank_idx + 1) * chunk_size, :].t() # Back to [out_dim, in_dim/tp_size] + split_weights.append(split_weight) + + return split_weights + def _init_infer(self): self.pre_infer = self.pre_infer_class(self.config) self.post_infer = self.post_infer_class(self.config) @@ -432,6 +587,221 @@ def infer(self, inputs): def _infer_cond_uncond(self, inputs, infer_condition=True): self.scheduler.infer_condition = infer_condition pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs) + + # Apply sequence parallel pre-processing (only for video) + if self.config.get("seq_parallel", False): + pre_infer_out = self._seq_parallel_pre_process(pre_infer_out) + + # Apply tensor parallel pre-processing (split positional embeddings) + if 
self.config.get("tensor_parallel", False): + pre_infer_out = self._tensor_parallel_pre_process(pre_infer_out) + vx, ax, video_embedded_timestep, audio_embedded_timestep = self.transformer_infer.infer(self.transformer_weights, pre_infer_out) + + # Apply sequence parallel post-processing (only for video) + if self.config.get("seq_parallel", False): + vx = self._seq_parallel_post_process(vx, self.original_video_seq_len) + video_embedded_timestep = self._seq_parallel_post_process(video_embedded_timestep, self.original_video_seq_len) + # Audio remains global, no gather needed + vx, ax = self.post_infer.infer(self.post_weight, vx, ax, video_embedded_timestep, audio_embedded_timestep) return vx, ax + + @torch.no_grad() + def _seq_parallel_pre_process(self, pre_infer_out): + """ + Pre-process for sequence parallel: only split video sequences across ranks. + Audio remains global (not split) as it has fewer tokens. + """ + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + + # Only process video args for sequence parallel + if pre_infer_out.video_args is not None: + # Split x (latent) + vx = pre_infer_out.video_args.x + self.original_video_seq_len = vx.shape[0] # Record original length before padding + multiple = world_size * self.padding_multiple + padding_size = (multiple - (vx.shape[0] % multiple)) % multiple + if padding_size > 0: + vx = F.pad(vx, (0, 0, 0, padding_size)) + pre_infer_out.video_args.x = torch.chunk(vx, world_size, dim=0)[cur_rank] + + # Split positional embeddings (cos_freqs, sin_freqs) for video self-attention + if pre_infer_out.video_args.positional_embeddings is not None: + v_cos, v_sin = pre_infer_out.video_args.positional_embeddings + # For 4D tensors: [batch, num_heads, seq_len, head_dim], seq_len is at dim=2 + if v_cos.dim() == 4: + seq_dim = 2 + elif v_cos.dim() == 2: + seq_dim = 0 + else: + seq_dim = 1 + + seq_len = v_cos.shape[seq_dim] + padding_size = (multiple - (seq_len % multiple)) % multiple + if padding_size > 0: + pad_spec = [0, 0] * (v_cos.dim() - seq_dim - 1) + [0, padding_size] + [0, 0] * seq_dim + v_cos = F.pad(v_cos, pad_spec) + v_sin = F.pad(v_sin, pad_spec) + + pre_infer_out.video_args.positional_embeddings = (torch.chunk(v_cos, world_size, dim=seq_dim)[cur_rank], torch.chunk(v_sin, world_size, dim=seq_dim)[cur_rank]) + + # Split cross-attention positional embeddings for cross-modal attention + if pre_infer_out.video_args.cross_positional_embeddings is not None: + v_cross_cos, v_cross_sin = pre_infer_out.video_args.cross_positional_embeddings + if v_cross_cos.dim() == 4: + seq_dim = 2 + elif v_cross_cos.dim() == 2: + seq_dim = 0 + else: + seq_dim = 1 + + seq_len = v_cross_cos.shape[seq_dim] + padding_size = (multiple - (seq_len % multiple)) % multiple + if padding_size > 0: + pad_spec = [0, 0] * (v_cross_cos.dim() - seq_dim - 1) + [0, padding_size] + [0, 0] * seq_dim + v_cross_cos = F.pad(v_cross_cos, pad_spec) + v_cross_sin = F.pad(v_cross_sin, pad_spec) + + pre_infer_out.video_args.cross_positional_embeddings = (torch.chunk(v_cross_cos, world_size, dim=seq_dim)[cur_rank], torch.chunk(v_cross_sin, world_size, dim=seq_dim)[cur_rank]) + + # Split timestep embeddings (sequence-length dependent) + if pre_infer_out.video_args.timesteps is not None: + v_timesteps = pre_infer_out.video_args.timesteps + padding_size = (multiple - (v_timesteps.shape[0] % multiple)) % multiple + if padding_size > 0: + v_timesteps = F.pad(v_timesteps, (0, 0, 0, padding_size)) + pre_infer_out.video_args.timesteps = 
torch.chunk(v_timesteps, world_size, dim=0)[cur_rank] + + if pre_infer_out.video_args.embedded_timestep is not None: + v_embedded_timestep = pre_infer_out.video_args.embedded_timestep + padding_size = (multiple - (v_embedded_timestep.shape[0] % multiple)) % multiple + if padding_size > 0: + v_embedded_timestep = F.pad(v_embedded_timestep, (0, 0, 0, padding_size)) + pre_infer_out.video_args.embedded_timestep = torch.chunk(v_embedded_timestep, world_size, dim=0)[cur_rank] + + if pre_infer_out.video_args.cross_scale_shift_timestep is not None: + v_cross_ss = pre_infer_out.video_args.cross_scale_shift_timestep + padding_size = (multiple - (v_cross_ss.shape[0] % multiple)) % multiple + if padding_size > 0: + v_cross_ss = F.pad(v_cross_ss, (0, 0, 0, padding_size)) + pre_infer_out.video_args.cross_scale_shift_timestep = torch.chunk(v_cross_ss, world_size, dim=0)[cur_rank] + + if pre_infer_out.video_args.cross_gate_timestep is not None: + v_cross_gate = pre_infer_out.video_args.cross_gate_timestep + padding_size = (multiple - (v_cross_gate.shape[0] % multiple)) % multiple + if padding_size > 0: + v_cross_gate = F.pad(v_cross_gate, (0, 0, 0, padding_size)) + pre_infer_out.video_args.cross_gate_timestep = torch.chunk(v_cross_gate, world_size, dim=0)[cur_rank] + + # Audio remains global - no splitting needed + # Audio has fewer tokens, so we keep it on all ranks + + return pre_infer_out + + @torch.no_grad() + def _tensor_parallel_pre_process(self, pre_infer_out): + """ + Pre-process for tensor parallel: split positional embeddings along head dimension. + + In tensor parallel, QKV projections are split along hidden_dim (which equals num_heads * head_dim), + so positional embeddings need to be split along the head dimension to match. + """ + if not self.config.get("tensor_parallel", False): + return pre_infer_out + + tp_group = self.config.get("device_mesh").get_group(mesh_dim="tensor_p") + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + + # Get num_heads and head_dim from config + v_num_heads = self.config.get("num_attention_heads", 32) + v_head_dim = self.config.get("attention_head_dim", 128) + a_num_heads = self.config.get("audio_num_attention_heads", 32) + a_head_dim = self.config.get("audio_attention_head_dim", 64) + + v_num_heads_per_rank = v_num_heads // tp_size + a_num_heads_per_rank = a_num_heads // tp_size + + def split_pe(pe, num_heads, num_heads_per_rank, tp_rank, tp_size): + """Split positional embeddings along head dimension.""" + if pe is None: + return None + + if isinstance(pe, tuple): + cos_freqs, sin_freqs = pe + + if cos_freqs.dim() == 2 and cos_freqs.shape[0] == num_heads: + # Shape: [num_heads, head_dim] - split along num_heads dimension + cos_freqs_split = cos_freqs[tp_rank * num_heads_per_rank : (tp_rank + 1) * num_heads_per_rank, :] + sin_freqs_split = sin_freqs[tp_rank * num_heads_per_rank : (tp_rank + 1) * num_heads_per_rank, :] + return (cos_freqs_split, sin_freqs_split) + elif cos_freqs.dim() == 4: + # Shape: [B, H, T, D] where H=num_heads, D=head_dim//2 (for SPLIT rope type) + # In apply_split_rotary_emb, if input is 2D [seq_len, num_heads_per_rank * head_dim] + # and PE is 4D [B, H, T, D], it will reshape input to [B, T, H, head_dim] then swapaxes to [B, H, T, head_dim] + # So H in PE should match num_heads_per_rank in the input + # Therefore, we need to split along H dimension: [B, H, T, D] -> [B, H/tp_size, T, D] + assert cos_freqs.shape[1] == num_heads, f"PE head dimension mismatch: cos_freqs.shape[1]={cos_freqs.shape[1]}, 
num_heads={num_heads}" + cos_freqs_split = cos_freqs[:, tp_rank * num_heads_per_rank : (tp_rank + 1) * num_heads_per_rank, :, :] + sin_freqs_split = sin_freqs[:, tp_rank * num_heads_per_rank : (tp_rank + 1) * num_heads_per_rank, :, :] + return (cos_freqs_split, sin_freqs_split) + else: + # For other shapes, split along last dimension (hidden_dim) + head_dim = cos_freqs.shape[-1] // num_heads if cos_freqs.dim() > 1 and cos_freqs.shape[-1] % num_heads == 0 else cos_freqs.shape[-1] + hidden_dim_per_rank = num_heads_per_rank * head_dim + start_idx = tp_rank * hidden_dim_per_rank + end_idx = start_idx + hidden_dim_per_rank + cos_freqs_split = cos_freqs[..., start_idx:end_idx] + sin_freqs_split = sin_freqs[..., start_idx:end_idx] + return (cos_freqs_split, sin_freqs_split) + else: + # pe is a single tensor, split along last dimension + head_dim = pe.shape[-1] // num_heads if pe.dim() > 1 and pe.shape[-1] % num_heads == 0 else pe.shape[-1] + hidden_dim_per_rank = num_heads_per_rank * head_dim + start_idx = tp_rank * hidden_dim_per_rank + end_idx = start_idx + hidden_dim_per_rank + return pe[..., start_idx:end_idx] + + # Process video args + if pre_infer_out.video_args is not None: + # Split positional embeddings for video self-attention + if pre_infer_out.video_args.positional_embeddings is not None: + pre_infer_out.video_args.positional_embeddings = split_pe(pre_infer_out.video_args.positional_embeddings, v_num_heads, v_num_heads_per_rank, tp_rank, tp_size) + + # Split cross-attention positional embeddings + if pre_infer_out.video_args.cross_positional_embeddings is not None: + pre_infer_out.video_args.cross_positional_embeddings = split_pe(pre_infer_out.video_args.cross_positional_embeddings, v_num_heads, v_num_heads_per_rank, tp_rank, tp_size) + + # Process audio args + if pre_infer_out.audio_args is not None: + # Split positional embeddings for audio self-attention + if pre_infer_out.audio_args.positional_embeddings is not None: + pre_infer_out.audio_args.positional_embeddings = split_pe(pre_infer_out.audio_args.positional_embeddings, a_num_heads, a_num_heads_per_rank, tp_rank, tp_size) + + # Split cross-attention positional embeddings + if pre_infer_out.audio_args.cross_positional_embeddings is not None: + pre_infer_out.audio_args.cross_positional_embeddings = split_pe(pre_infer_out.audio_args.cross_positional_embeddings, a_num_heads, a_num_heads_per_rank, tp_rank, tp_size) + + return pre_infer_out + + @torch.no_grad() + def _seq_parallel_post_process(self, x, original_length=None): + """ + Post-process for sequence parallel: gather results from all ranks and remove padding. + + Args: + x: Tensor to gather + original_length: Original sequence length before padding. If provided, truncate to this length. 
+ """ + world_size = dist.get_world_size(self.seq_p_group) + gathered_x = [torch.empty_like(x) for _ in range(world_size)] + dist.all_gather(gathered_x, x, group=self.seq_p_group) + combined_output = torch.cat(gathered_x, dim=0) + + # Remove padding to restore original length + if original_length is not None and combined_output.shape[0] > original_length: + combined_output = combined_output[:original_length] + + return combined_output diff --git a/lightx2v/models/networks/ltx2/weights/transformer_weights.py b/lightx2v/models/networks/ltx2/weights/transformer_weights.py index b87015cd..369b57c2 100755 --- a/lightx2v/models/networks/ltx2/weights/transformer_weights.py +++ b/lightx2v/models/networks/ltx2/weights/transformer_weights.py @@ -1,3 +1,5 @@ +import torch.distributed as dist + from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList from lightx2v.utils.registry_factory import ( ATTN_WEIGHT_REGISTER, @@ -138,9 +140,34 @@ def __init__( self.scale_shift_table_a2v_ca_video, ) + # Check if tensor parallel is enabled + use_tp = config.get("tensor_parallel", False) + if use_tp: + tp_group = config.get("device_mesh").get_group(mesh_dim="tensor_p") + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + else: + tp_group = None + tp_rank = 0 + tp_size = 1 + + # Create attention and FFN modules based on tensor parallel config + if use_tp: + AttentionClass = LTX2AttentionTP + FFNClass = LTX2FFNTP + tp_kwargs = { + "tp_group": tp_group, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + else: + AttentionClass = LTX2Attention + FFNClass = LTX2FFN + tp_kwargs = {} + self.compute_phases = WeightModuleList( [ - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="attn1", block_prefix=block_prefix, @@ -152,8 +179,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="attn2", block_prefix=block_prefix, @@ -165,8 +193,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="audio_attn1", block_prefix=block_prefix, @@ -178,8 +207,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="audio_attn2", block_prefix=block_prefix, @@ -191,8 +221,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="audio_to_video_attn", block_prefix=block_prefix, @@ -204,8 +235,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2Attention( + AttentionClass( block_index=block_index, attn_prefix="video_to_audio_attn", block_prefix=block_prefix, @@ -217,8 +249,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2FFN( + FFNClass( block_index=block_index, ffn_prefix="ff", block_prefix=block_prefix, @@ -230,8 +263,9 @@ def __init__( lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), - LTX2FFN( + FFNClass( block_index=block_index, ffn_prefix="audio_ff", block_prefix=block_prefix, @@ -243,6 +277,7 @@ def __init__( 
lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file, lora_path=lora_path, + **tp_kwargs, ), ] ) @@ -283,6 +318,14 @@ def __init__( model_prefix = "model.diffusion_model" self.add_module("attn_func", ATTN_WEIGHT_REGISTER[self.config["attn_type"]]()) + + # Add parallel attention module for sequence parallelism + if self.config.get("seq_parallel", False): + self.add_module( + "attn_func_parallel", + ATTN_WEIGHT_REGISTER[self.config.get("parallel", {}).get("seq_p_attn_type", "ulysses")](), + ) + self.add_module( f"q_norm", RMS_WEIGHT_REGISTER[self.attn_rms_type]( @@ -413,3 +456,236 @@ def __init__( lora_path=lora_path, ), ) + + +class LTX2AttentionTP(WeightModule): + """ + Tensor Parallel version of LTX2Attention. + + QKV projections are split column-wise, output projection is split row-wise. + """ + + def __init__( + self, + block_index, + attn_prefix, + block_prefix, + task, + mm_type, + config, + tp_group=None, + tp_rank=0, + tp_size=1, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + lora_path=None, + ): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + self.tp_group = tp_group + self.tp_rank = tp_rank + self.tp_size = tp_size + + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + self.attn_rms_type = self.config.get("rms_type", "sgl-kernel") + + block_lora_prefix = "model.diffusion_model.blocks" + model_prefix = "model.diffusion_model" + + self.add_module("attn_func", ATTN_WEIGHT_REGISTER[self.config["attn_type"]]()) + + # Use TP version of norm if tensor parallel is enabled + # Note: In TP, QKV outputs are split, so norm weights must also be split + norm_class = RMS_WEIGHT_REGISTER["TensorParallel"] + norm_kwargs = { + "tp_group": tp_group, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + + self.add_module( + f"q_norm", + norm_class( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.q_norm.weight", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + **norm_kwargs, + ), + ) + self.add_module( + f"k_norm", + norm_class( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.k_norm.weight", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + **norm_kwargs, + ), + ) + # QKV projections: column split + self.add_module( + f"to_q", + MM_WEIGHT_REGISTER["TensorParallel"]( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_q.weight", + bias_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_q.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="col", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + self.add_module( + f"to_k", + MM_WEIGHT_REGISTER["TensorParallel"]( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_k.weight", + bias_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_k.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + 
tp_size=tp_size, + split_dim="col", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + self.add_module( + f"to_v", + MM_WEIGHT_REGISTER["TensorParallel"]( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_v.weight", + bias_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_v.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="col", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + # Output projection: row split (needs all-reduce) + self.add_module( + f"to_out", + MM_WEIGHT_REGISTER["TensorParallel"]( + weight_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_out.0.weight", + bias_name=f"{model_prefix}.{block_prefix}.{block_index}.{attn_prefix}.to_out.0.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="row", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + + +class LTX2FFNTP(WeightModule): + """ + Tensor Parallel version of LTX2FFN. + + First layer (net_0_proj) is split column-wise, second layer (net_2) is split row-wise. + """ + + def __init__( + self, + block_index, + block_prefix, + ffn_prefix, + task, + mm_type, + config, + tp_group=None, + tp_rank=0, + tp_size=1, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + lora_path=None, + ): + super().__init__() + self.block_index = block_index + self.mm_type = mm_type + self.task = task + self.config = config + self.quant_method = config.get("quant_method", None) + self.tp_group = tp_group + self.tp_rank = tp_rank + self.tp_size = tp_size + + self.lazy_load = lazy_load + self.lazy_load_file = lazy_load_file + block_lora_prefix = "model.diffusion_model.blocks" + model_prefix = "model.diffusion_model" + + # First layer: column split + self.add_module( + f"net_0_proj", + MM_WEIGHT_REGISTER["TensorParallel"]( + f"{model_prefix}.{block_prefix}.{block_index}.{ffn_prefix}.net.0.proj.weight", + f"{model_prefix}.{block_prefix}.{block_index}.{ffn_prefix}.net.0.proj.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="col", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) + # Second layer: row split (needs all-reduce) + self.add_module( + f"net_2", + MM_WEIGHT_REGISTER["TensorParallel"]( + f"{model_prefix}.{block_prefix}.{block_index}.{ffn_prefix}.net.2.weight", + f"{model_prefix}.{block_prefix}.{block_index}.{ffn_prefix}.net.2.bias", + mm_type=mm_type, + tp_group=tp_group, + tp_rank=tp_rank, + tp_size=tp_size, + split_dim="row", + create_cuda_buffer=create_cuda_buffer, + create_cpu_buffer=create_cpu_buffer, + lazy_load=self.lazy_load, + lazy_load_file=self.lazy_load_file, + lora_prefix=block_lora_prefix, + lora_path=lora_path, + ), + ) diff --git a/lightx2v/models/runners/ltx2/ltx2_runner.py b/lightx2v/models/runners/ltx2/ltx2_runner.py index aaf686df..794adb69 100755 
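
Note: the LTX2AttentionTP and LTX2FFNTP classes above follow the usual column-then-row sharding. The first projection of each pair (to_q/to_k/to_v, net_0_proj) is split over its output features, the second (to_out, net_2) over its input features, and the per-rank partial outputs are combined with an all-reduce. Below is a minimal single-process sketch in plain PyTorch, not the repository's classes; the shapes are illustrative, the per-rank shards are simulated by slicing, and the all-reduce is replaced by a sum. It shows why this split reproduces the unsharded result.

import torch

torch.manual_seed(0)

hidden, ffn_hidden, tp_size = 64, 256, 2
x = torch.randn(8, hidden)                        # [tokens, hidden]

# Unsharded reference FFN: net.0.proj -> GELU -> net.2
w0 = 0.1 * torch.randn(ffn_hidden, hidden)        # nn.Linear stores [out_features, in_features]
w2 = 0.1 * torch.randn(hidden, ffn_hidden)
ref = torch.nn.functional.gelu(x @ w0.T) @ w2.T

# Tensor-parallel version: shard w0 over output features ("col" split),
# shard w2 over input features ("row" split), and sum the per-rank partial
# outputs (the sum stands in for the all-reduce needed after to_out / net_2).
partials = []
for rank in range(tp_size):
    w0_shard = w0.chunk(tp_size, dim=0)[rank]     # [ffn_hidden // tp_size, hidden]
    w2_shard = w2.chunk(tp_size, dim=1)[rank]     # [hidden, ffn_hidden // tp_size]
    h = torch.nn.functional.gelu(x @ w0_shard.T)  # elementwise activation on local features only
    partials.append(h @ w2_shard.T)               # partial output already has full hidden size
out = torch.stack(partials).sum(dim=0)

print(torch.allclose(ref, out, atol=1e-4))        # True: sharded matches unsharded

The same argument carries over to the attention path: because each rank holds whole heads, softmax and the rotary tables only need the local slice, and only the to_out projection requires the reduction.
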
--- a/lightx2v/models/runners/ltx2/ltx2_runner.py +++ b/lightx2v/models/runners/ltx2/ltx2_runner.py @@ -1,6 +1,8 @@ import gc +import os import torch +import torch.distributed as dist from lightx2v.models.input_encoders.hf.ltx2.model import LTX2TextEncoder from lightx2v.models.networks.ltx2.model import LTX2Model @@ -18,31 +20,6 @@ torch_device_module = getattr(torch, AI_DEVICE) -def parse_images_arg(images_str: str) -> list: - if not images_str or images_str.strip() == "": - return [] - - result = [] - for item in images_str.split(","): - parts = item.strip().split(":") - if len(parts) != 3: - raise ValueError(f"Invalid image conditioning format: '{item}'. Expected format: 'image_path:frame_idx:strength'") - - image_path = parts[0].strip() - try: - frame_idx = int(parts[1].strip()) - strength = float(parts[2].strip()) - except ValueError as e: - raise ValueError(f"Invalid image conditioning format: '{item}'. frame_idx must be int, strength must be float. Error: {e}") - - if strength < 0.0 or strength > 1.0: - raise ValueError(f"Invalid strength value {strength} for image '{image_path}'. Strength must be between 0.0 and 1.0.") - - result.append((image_path, frame_idx, strength)) - - return result - - @RUNNER_REGISTER("ltx2") class LTX2Runner(DefaultRunner): def __init__(self, config): @@ -117,7 +94,14 @@ def load_vae(self): ckpt_path = os.path.join(self.config["model_path"], "transformer") # Video VAE - video_vae = LTX2VideoVAE(checkpoint_path=ckpt_path, device=vae_device, dtype=GET_DTYPE(), load_encoder=self.config["task"] == "i2av", cpu_offload=vae_offload) + video_vae = LTX2VideoVAE( + checkpoint_path=ckpt_path, + device=vae_device, + dtype=GET_DTYPE(), + load_encoder=self.config["task"] == "i2av", + use_tiling=self.config.get("use_tiling_vae", False), + cpu_offload=vae_offload, + ) # Audio VAE audio_vae = LTX2AudioVAE(checkpoint_path=ckpt_path, device=vae_device, dtype=GET_DTYPE(), cpu_offload=vae_offload) @@ -163,9 +147,7 @@ def _run_input_encoder_local_t2av(self): def _run_input_encoder_local_i2av(self): self.input_info.video_latent_shape, self.input_info.audio_latent_shape = self.get_latent_shape_with_target_hw() text_encoder_output = self.run_text_encoder(self.input_info) - # Prepare image conditioning if provided - logger.info(f"🖼️ I2AV mode: processing {len(self.input_info.images)} image conditioning(s)") - self.video_denoise_mask, self.initial_video_latent = self._prepare_image_conditioning() + self.video_denoise_mask, self.initial_video_latent = self.run_vae_encoder() torch_device_module.empty_cache() gc.collect() @@ -173,7 +155,13 @@ def _run_input_encoder_local_i2av(self): "text_encoder_output": text_encoder_output, } - def _prepare_image_conditioning(self): + @ProfilingContext4DebugL1( + "Run VAE Encoder", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, + metrics_labels=["LTX2Runner"], + ) + def run_vae_encoder(self): """ Prepare image conditioning by loading images and encoding them to latents. 
@@ -182,8 +170,6 @@ def _prepare_image_conditioning(self): - video_denoise_mask: Mask indicating which frames to denoise (unpatchified, shape [1, F, H, W]) - initial_video_latent: Initial latent with conditioned frames (unpatchified, shape [C, F, H, W]) """ - logger.info(f"🖼️ Preparing {len(self.input_info.images)} image conditioning(s)") - # Get latent shape C, F, H, W = self.input_info.video_latent_shape target_height = self.input_info.target_shape[0] if self.input_info.target_shape and len(self.input_info.target_shape) == 2 else self.config["target_height"] @@ -211,8 +197,12 @@ def _prepare_image_conditioning(self): ) # Process each image conditioning - images = parse_images_arg(self.input_info.images) - for image_path, frame_idx, strength in images: + image_paths = self.input_info.image_path.split(",") # image_path1,image_path2,image_path3 + for frame_idx, image_path in enumerate(image_paths): + if not isinstance(self.input_info.image_strength, list): + strength = self.input_info.image_strength + else: + strength = self.input_info.image_strength[frame_idx] logger.info(f" 📷 Loading image: {image_path} for frame {frame_idx} with strength {strength}") # Load and preprocess image @@ -224,12 +214,9 @@ def _prepare_image_conditioning(self): device=AI_DEVICE, ) - # Encode image to latent space - # image shape: [1, C, 1, H, W] with torch.no_grad(): encoded_latent = self.video_vae.encode(image) - # Remove batch dimension: [1, C, 1, H_latent, W_latent] -> [C, 1, H_latent, W_latent] encoded_latent = encoded_latent.squeeze(0) # Verify frame index is valid diff --git a/lightx2v/models/runners/wan/wan_distill_runner.py b/lightx2v/models/runners/wan/wan_distill_runner.py index a7b10284..8abf606a 100755 --- a/lightx2v/models/runners/wan/wan_distill_runner.py +++ b/lightx2v/models/runners/wan/wan_distill_runner.py @@ -21,7 +21,7 @@ def load_transformer(self): if not lora_configs: model = WanDistillModel(**wan_model_kwargs) else: - model = build_wan_model_with_lora(WanModel, self.config, wan_model_kwargs, lora_configs, model_typ="wan2.1") + model = build_wan_model_with_lora(WanModel, self.config, wan_model_kwargs, lora_configs, model_type="wan2.1") return model def init_scheduler(self): diff --git a/lightx2v/models/video_encoders/hf/ltx2/model.py b/lightx2v/models/video_encoders/hf/ltx2/model.py index 43c3ffb3..31fb0f87 100755 --- a/lightx2v/models/video_encoders/hf/ltx2/model.py +++ b/lightx2v/models/video_encoders/hf/ltx2/model.py @@ -34,16 +34,19 @@ def __init__( device: torch.device, dtype: torch.dtype = torch.bfloat16, load_encoder: bool = True, + use_tiling: bool = False, cpu_offload: bool = False, ): self.checkpoint_path = checkpoint_path self.device = device self.dtype = dtype self.load_encoder_flag = load_encoder + self.use_tiling = use_tiling self.loader = SafetensorsModelStateDictLoader() self.encoder = None self.decoder = None self.cpu_offload = cpu_offload + self.grid_table = {} # Cache for 2D grid calculations self.load() def load(self) -> tuple[VideoEncoder | None, VideoDecoder | None]: @@ -75,11 +78,27 @@ def load(self) -> tuple[VideoEncoder | None, VideoDecoder | None]: self.decoder = decoder.to(self.device).eval() def encode(self, video_frames: torch.Tensor) -> torch.Tensor: + """ + Encode video frames to latent space. 
+ Args: + video_frames: Input video tensor [1, C, T, H, W] or [C, T, H, W] + Returns: + Encoded latent tensor [C, F, H_latent, W_latent] + """ + # Ensure video has batch dimension + if video_frames.dim() == 4: + video_frames = video_frames.unsqueeze(0) + if self.cpu_offload: self.encoder = self.encoder.to(AI_DEVICE) + out = self.encoder(video_frames) + if out.dim() == 5: + out = out.squeeze(0) + if self.cpu_offload: self.encoder = self.encoder.to("cpu") + return out def decode( @@ -88,6 +107,10 @@ def decode( tiling_config: TilingConfig | None = None, generator: torch.Generator | None = None, ) -> Iterator[torch.Tensor]: + # If tiling is enabled but no tiling_config is provided, use the default config + if self.use_tiling and tiling_config is None: + tiling_config = TilingConfig.default() + if self.cpu_offload: self.decoder = self.decoder.to(AI_DEVICE) try: diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index f7bee3b8..9a67893f 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -28,7 +28,7 @@ from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.set_config import set_config, set_parallel_config -from lightx2v.utils.utils import seed_all +from lightx2v.utils.utils import seed_all, validate_config_paths def dict_like(cls): @@ -170,8 +170,9 @@ def create_generator( ) config = set_config(self) - print(config) + validate_config_paths(config) self.runner = self._init_runner(config) + print(self.runner.config) logger.info(f"Initializing {self.model_cls} runner for {self.task} task...") logger.info(f"Model path: {self.model_path}") logger.info("LightGenerator initialized successfully!") @@ -382,7 +383,7 @@ def generate( negative_prompt, save_result_path, image_path=None, - images=None, + image_strength=None, last_frame_path=None, audio_path=None, src_ref_images=None, @@ -392,9 +393,10 @@ def generate( target_shape=[], ): # Run inference (following LightX2V pattern) + # Note: image_path supports comma-separated paths for multiple images + # image_strength can be a scalar (float/int) or a list matching the number of images self.seed = seed self.image_path = image_path - self.images = images self.last_frame_path = last_frame_path self.audio_path = audio_path self.src_ref_images = src_ref_images @@ -405,6 +407,8 @@ def generate( self.save_result_path = save_result_path self.return_result_tensor = return_result_tensor self.target_shape = target_shape + self.image_strength = image_strength + input_info = init_empty_input_info(self.task) seed_all(self.seed) update_input_info_from_dict(input_info, self) diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py index c483bffa..b796dfb3 100755 --- a/lightx2v/utils/input_info.py +++ b/lightx2v/utils/input_info.py @@ -176,12 +176,9 @@ class I2AVInputInfo: prompt_enhanced: str = field(default_factory=str) negative_prompt: str = field(default_factory=str) image_path: str = field(default_factory=str) + image_strength: float = field(default_factory=float) save_result_path: str = field(default_factory=str) return_result_tensor: bool = field(default_factory=lambda: False) - # Image conditioning: list of (image_path, frame_idx, strength) tuples - # frame_idx: which frame to replace with the image (0-indexed) - # strength: conditioning strength (0.0-1.0, typically 1.0 for full replacement) - images: list = field(default_factory=list) # list[tuple[str, int, float]] # shape related resize_mode: str = field(default_factory=str) original_shape: list =
field(default_factory=list) diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index 69ab2e62..fe81dcb0 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -130,16 +130,29 @@ def set_config(args): def set_parallel_config(config): if config["parallel"]: - cfg_p_size = config["parallel"].get("cfg_p_size", 1) - seq_p_size = config["parallel"].get("seq_p_size", 1) - assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size" - config["device_mesh"] = init_device_mesh(AI_DEVICE, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p")) + tensor_p_size = config["parallel"].get("tensor_p_size", 1) + + if tensor_p_size > 1: + # Tensor parallel only: 1D mesh + assert tensor_p_size == dist.get_world_size(), f"tensor_p_size ({tensor_p_size}) must be equal to world_size ({dist.get_world_size()})" + config["device_mesh"] = init_device_mesh(AI_DEVICE, (tensor_p_size,), mesh_dim_names=("tensor_p",)) + config["tensor_parallel"] = True + config["seq_parallel"] = False + config["cfg_parallel"] = False + else: + # Original 2D mesh for cfg_p and seq_p + cfg_p_size = config["parallel"].get("cfg_p_size", 1) + seq_p_size = config["parallel"].get("seq_p_size", 1) + assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size ({cfg_p_size}) * seq_p_size ({seq_p_size}) must be equal to world_size ({dist.get_world_size()})" + config["device_mesh"] = init_device_mesh(AI_DEVICE, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p")) + config["tensor_parallel"] = False + + if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1: + config["seq_parallel"] = True + + if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1: + config["cfg_parallel"] = True - if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1: - config["seq_parallel"] = True - - if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1: - config["cfg_parallel"] = True # warmup dist _a = torch.zeros([1]).to(f"{AI_DEVICE}:{dist.get_rank()}") dist.all_reduce(_a) diff --git a/lightx2v/utils/utils.py b/lightx2v/utils/utils.py index 3560ad32..2ed9d0bd 100755 --- a/lightx2v/utils/utils.py +++ b/lightx2v/utils/utils.py @@ -598,6 +598,11 @@ def validate_task_arguments(args: "argparse.Namespace") -> None: Raises: AssertionError: If required arguments are missing or invalid for the task """ + # Check model_path exists + model_path = getattr(args, "model_path", None) + if model_path: + check_path_exists(model_path) + task = args.task # Define required file paths for each task @@ -610,6 +615,7 @@ def validate_task_arguments(args: "argparse.Namespace") -> None: "animate": {"required_paths": ["image_path"], "description": "Animate task requires --image_path"}, "t2v": {"required_paths": [], "description": "Text-to-Video task"}, "t2i": {"required_paths": [], "description": "Text-to-Image task"}, + "i2av": {"required_paths": ["image_path"], "description": "Image-to-Audio-Video task requires --image_path"}, } if task not in task_requirements: @@ -635,3 +641,46 @@ def validate_task_arguments(args: "argparse.Namespace") -> None: check_path_exists(path_value) logger.info(f"✓ Task '{task}' arguments validated successfully") + + +def validate_config_paths(config: dict) -> 
None: + """ + Validate checkpoint paths in config dictionary. + + Args: + config: Configuration dictionary + + Raises: + FileNotFoundError: If any checkpoint path in config does not exist + """ + # Check dit_quantized_ckpt or dit_original_ckpt + if "dit_quantized_ckpt" in config and config["dit_quantized_ckpt"] is not None: + check_path_exists(config["dit_quantized_ckpt"]) + logger.debug(f"✓ Verified dit_quantized_ckpt: {config['dit_quantized_ckpt']}") + + if "dit_original_ckpt" in config and config["dit_original_ckpt"] is not None: + check_path_exists(config["dit_original_ckpt"]) + logger.debug(f"✓ Verified dit_original_ckpt: {config['dit_original_ckpt']}") + + # For wan2.2, check high and low noise checkpoints + model_cls = config.get("model_cls", "") + if model_cls and "wan2.2" in model_cls: + # Check high noise checkpoints + if "high_noise_original_ckpt" in config and config["high_noise_original_ckpt"] is not None: + check_path_exists(config["high_noise_original_ckpt"]) + logger.debug(f"✓ Verified high_noise_original_ckpt: {config['high_noise_original_ckpt']}") + + if "high_noise_quantized_ckpt" in config and config["high_noise_quantized_ckpt"] is not None: + check_path_exists(config["high_noise_quantized_ckpt"]) + logger.debug(f"✓ Verified high_noise_quantized_ckpt: {config['high_noise_quantized_ckpt']}") + + # Check low noise checkpoints + if "low_noise_original_ckpt" in config and config["low_noise_original_ckpt"] is not None: + check_path_exists(config["low_noise_original_ckpt"]) + logger.debug(f"✓ Verified low_noise_original_ckpt: {config['low_noise_original_ckpt']}") + + if "low_noise_quantized_ckpt" in config and config["low_noise_quantized_ckpt"] is not None: + check_path_exists(config["low_noise_quantized_ckpt"]) + logger.debug(f"✓ Verified low_noise_quantized_ckpt: {config['low_noise_quantized_ckpt']}") + + logger.info("✓ Config checkpoint paths validated successfully") diff --git a/scripts/ltx2/run_ltx2_i2av.sh b/scripts/ltx2/run_ltx2_i2av.sh old mode 100644 new mode 100755 index b384f8c2..811d7cad --- a/scripts/ltx2/run_ltx2_i2av.sh +++ b/scripts/ltx2/run_ltx2_i2av.sh @@ -4,6 +4,7 @@ lightx2v_path=/path/to/LightX2V model_path=Lightricks/LTX-2 + export CUDA_VISIBLE_DEVICES=0 # set environment variables @@ -12,9 +13,10 @@ source ${lightx2v_path}/scripts/base/base.sh python -m lightx2v.infer \ --model_cls ltx2 \ --task i2av \ ---images "${lightx2v_path}/assets/inputs/imgs/woman.jpeg:0:1.0" \ +--image_path "${lightx2v_path}/assets/inputs/imgs/woman.jpeg" \ --model_path $model_path \ --config_json ${lightx2v_path}/configs/ltx2/ltx2.json \ ---prompt "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood." \ +--prompt "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." 
\ --negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." \ ---save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av.mp4 +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av.mp4 \ +--image_strength 1.0 diff --git a/scripts/ltx2/run_ltx2_i2av_tp.sh b/scripts/ltx2/run_ltx2_i2av_tp.sh new file mode 100755 index 00000000..546a4402 --- /dev/null +++ b/scripts/ltx2/run_ltx2_i2av_tp.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=Lightricks/LTX-2 + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls ltx2 \ +--task i2av \ +--image_path "${lightx2v_path}/assets/inputs/imgs/woman.jpeg" \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2_tp.json \ +--prompt "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." \ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
\ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av_tp.mp4 \ +--image_strength 1.0 diff --git a/scripts/ltx2/run_ltx2_i2av_ulysses.sh b/scripts/ltx2/run_ltx2_i2av_ulysses.sh new file mode 100755 index 00000000..86772677 --- /dev/null +++ b/scripts/ltx2/run_ltx2_i2av_ulysses.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=Lightricks/LTX-2 + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls ltx2 \ +--task i2av \ +--image_path "${lightx2v_path}/assets/inputs/imgs/woman.jpeg" \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2_ulysses.json \ +--prompt "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance." \ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
\ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_i2av_ulysses.mp4 \ +--image_strength 1.0 diff --git a/scripts/ltx2/run_ltx2_t2av.sh b/scripts/ltx2/run_ltx2_t2av.sh index f91931f4..a30d639c 100755 --- a/scripts/ltx2/run_ltx2_t2av.sh +++ b/scripts/ltx2/run_ltx2_t2av.sh @@ -14,7 +14,7 @@ python -m lightx2v.infer \ --model_cls ltx2 \ --task t2av \ --model_path $model_path \ ---config_json ${lightx2v_path}/configs/ltx2/ltx2_distill_fp8.json \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2.json \ --prompt "A beautiful sunset over the ocean" \ --negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." \ --save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_t2av.mp4 diff --git a/scripts/ltx2/run_ltx2_t2av_cfg_parallel.sh b/scripts/ltx2/run_ltx2_t2av_cfg_parallel.sh old mode 100644 new mode 100755 diff --git a/scripts/ltx2/run_ltx2_t2av_tp.sh b/scripts/ltx2/run_ltx2_t2av_tp.sh new file mode 100755 index 00000000..073df2f6 --- /dev/null +++ b/scripts/ltx2/run_ltx2_t2av_tp.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=Lightricks/LTX-2 + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls ltx2 \ +--task t2av \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2_tp.json \ +--prompt "A beautiful sunset over the ocean" \ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, 
cinematic oversaturation, stylized filters, or AI artifacts." \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_t2av_tp.mp4 diff --git a/scripts/ltx2/run_ltx2_t2av_ulysses.sh b/scripts/ltx2/run_ltx2_t2av_ulysses.sh new file mode 100755 index 00000000..284d7019 --- /dev/null +++ b/scripts/ltx2/run_ltx2_t2av_ulysses.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# set path and first +lightx2v_path=/path/to/LightX2V +model_path=Lightricks/LTX-2 + +export CUDA_VISIBLE_DEVICES=0,1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ +--model_cls ltx2 \ +--task t2av \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/ltx2/ltx2_ulysses.json \ +--prompt "A beautiful sunset over the ocean" \ +--negative_prompt "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." \ +--save_result_path ${lightx2v_path}/save_results/output_lightx2v_ltx2_t2av_ulysses.mp4
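
For reference, the split_pe helper added earlier in this patch shards the rotary tables per tensor-parallel rank, slicing the heads dimension when an explicit (B, H, S, D) layout is detected and otherwise slicing the flattened last dimension in chunks of num_heads_per_rank * head_dim. The standalone sketch below (illustrative shapes in plain PyTorch, not the model's actual tensors) checks that the two slicing schemes select the same local heads.

import torch

num_heads, head_dim, seq_len, tp_size = 8, 16, 10, 2
heads_per_rank = num_heads // tp_size

cos = torch.randn(1, num_heads, seq_len, head_dim)   # assumed (B, H, S, D) rope table
cos_flat = cos.permute(0, 2, 1, 3).reshape(1, seq_len, num_heads * head_dim)

for rank in range(tp_size):
    # 4-D layout: slice the heads dimension for this rank
    local_4d = cos[:, rank * heads_per_rank:(rank + 1) * heads_per_rank]
    # Flattened layout: slice the last dimension by heads_per_rank * head_dim
    width = heads_per_rank * head_dim
    local_flat = cos_flat[..., rank * width:(rank + 1) * width]
    # Both slices describe the same local heads
    restored = local_flat.reshape(1, seq_len, heads_per_rank, head_dim).permute(0, 2, 1, 3)
    print(rank, torch.allclose(restored, local_4d))  # True for every rank
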
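
Likewise, _seq_parallel_post_process all-gathers the per-rank shards along the sequence dimension and truncates the result back to original_length, undoing the padding added so the sequence divides evenly across ranks. A single-process stand-in is sketched below; the pad-then-shard pre-processing shown here is an assumption about the runner rather than code from this patch, and dist.all_gather is simulated by keeping every shard in a Python list.

import torch

world_size = 4
original_length = 10
x = torch.arange(original_length).float().unsqueeze(-1)   # [10, 1] token sequence

# Pre-process (assumed): pad so the length divides evenly, then give each
# rank a contiguous shard along dim 0.
pad = (-original_length) % world_size
padded = torch.cat([x, torch.zeros(pad, 1)], dim=0)        # [12, 1]
shards = list(padded.chunk(world_size, dim=0))             # one [3, 1] shard per rank

# Post-process (mirrors _seq_parallel_post_process): gather, concatenate
# along the sequence dimension, then drop the padding.
combined = torch.cat(shards, dim=0)
if combined.shape[0] > original_length:
    combined = combined[:original_length]

print(torch.equal(combined, x))                            # True: original sequence restored
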