4 changes: 2 additions & 2 deletions configs/ltx2/ltx2.json
@@ -1,7 +1,6 @@
{
"infer_steps": 40,
"target_video_length": 121,
"text_len": 512,
"target_height": 512,
"target_width": 768,
"attn_type": "sage_attn2",
@@ -14,5 +13,6 @@
"audio_fps": 24000,
"audio_mel_bins":16,
"double_precision_rope": true,
"dit_original_ckpt": "/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-dev.safetensors"
"use_tiling_vae": false,
"dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors"
}
2 changes: 1 addition & 1 deletion configs/ltx2/ltx2_distill_fp8.json
@@ -13,7 +13,7 @@
"audio_fps": 24000,
"audio_mel_bins":16,
"double_precision_rope": true,
"dit_quantized_ckpt": "/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors",
"dit_quantized_ckpt": "Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors",
"dit_quantized": true,
"dit_quant_scheme": "fp8-pertensor",
"skip_fp8_block_index" : [0, 43, 44, 45, 46, 47]
1 change: 0 additions & 1 deletion configs/ltx2/ltx2_fp8.json
@@ -1,7 +1,6 @@
{
"infer_steps": 40,
"target_video_length": 121,
"text_len": 512,
"target_height": 512,
"target_width": 768,
"attn_type": "sage_attn2",
20 changes: 20 additions & 0 deletions configs/ltx2/ltx2_tp.json
@@ -0,0 +1,20 @@
{
"infer_steps": 40,
"target_video_length": 121,
"target_height": 512,
"target_width": 768,
"attn_type": "sage_attn2",
"sample_guide_scale": 4,
"sample_shift": [2.05, 0.95],
"enable_cfg": true,
"cpu_offload": false,
"num_channels_latents": 128,
"fps": 24,
"audio_fps": 24000,
"audio_mel_bins":16,
"double_precision_rope": true,
"dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors",
"parallel": {
"tensor_p_size": 2
}
}
21 changes: 21 additions & 0 deletions configs/ltx2/ltx2_ulysses.json
@@ -0,0 +1,21 @@
{
"infer_steps": 40,
"target_video_length": 121,
"target_height": 512,
"target_width": 768,
"attn_type": "sage_attn2",
"sample_guide_scale": 4,
"sample_shift": [2.05, 0.95],
"enable_cfg": true,
"cpu_offload": false,
"num_channels_latents": 128,
"fps": 24,
"audio_fps": 24000,
"audio_mel_bins":16,
"double_precision_rope": true,
"dit_original_ckpt": "Lightricks/LTX-2/ltx-2-19b-dev.safetensors",
"parallel": {
"seq_p_size": 2,
"seq_p_attn_type": "ulysses"
}
}
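The "seq_p_attn_type": "ulysses" setting shards the sequence across ranks and uses an all-to-all so that each rank attends over the full sequence for a subset of heads. A minimal single-device sketch (illustration only, not code from this PR; LightX2V's actual exchange may differ) emulating that resharding:

import torch

# Toy shapes: 2 "ranks", 8 tokens, 4 heads.
P, seq, heads, dim = 2, 8, 4, 3
x = torch.randn(seq, heads, dim)
seq_shards = list(x.chunk(P, dim=0))  # what each rank holds before attention
# Emulate the all-to-all: rank r collects head group r from every sequence
# shard, ending up with the full sequence for heads/P of the heads.
head_shards = [torch.cat([s.chunk(P, dim=1)[r] for s in seq_shards], dim=0) for r in range(P)]
assert torch.cat(head_shards, dim=1).equal(x)  # nothing is lost, only resharded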
12 changes: 9 additions & 3 deletions examples/ltx2/ltxt_i2av.py
@@ -24,15 +24,21 @@
)

seed = 42
images = "/path/to/woman.jpeg:0:1.0"
prompt = "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood."
image_path = "/path/to/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg"
image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths
prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance."
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
save_result_path = "/path/to/save_results/output.mp4"

# Note: image_strength can also be set in config_json
# For scalar: image_strength = 1.0 (all images use same strength)
# For list: image_strength = [1.0, 0.8] (must match number of images)

pipe.generate(
seed=seed,
prompt=prompt,
images=images,
image_path=image_path,
image_strength=image_strength,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
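For the multi-image case described in the comments above, the call might look like the following sketch (placeholder file names, not assets shipped with the repo):

pipe.generate(
    seed=seed,
    prompt=prompt,
    image_path="first.jpg,second.jpg",  # comma-separated paths, as noted above
    image_strength=[1.0, 0.8],          # one strength per image; a scalar applies to all
    negative_prompt=negative_prompt,
    save_result_path=save_result_path,
)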
17 changes: 11 additions & 6 deletions examples/ltx2/ltxt_i2av_distilled_fp8.py
@@ -1,14 +1,14 @@
from lightx2v import LightX2VPipeline

pipe = LightX2VPipeline(
model_path="Lightricks/LTX-2/",
model_path="/data/nvme0/gushiqiao/models/official_models/LTX-2/",
model_cls="ltx2",
task="i2av",
)

pipe.enable_quantize(
dit_quantized=True,
dit_quantized_ckpt="Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors",
dit_quantized_ckpt="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors",
Comment on lines +4 to +11 (Contributor review, severity: medium):
Hardcoded local paths are being used for model_path and dit_quantized_ckpt. This is not portable and will cause issues for other developers. It's recommended to use relative paths, environment variables, or a mechanism to load these from a local configuration file that is not committed to the repository.

quant_scheme="fp8-pertensor",
skip_fp8_block_index=[0, 43, 44, 45, 46, 47],
)
@@ -35,15 +35,20 @@
)

seed = 42
images = "/path/to/woman.jpeg:0:1.0"
prompt = "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood."
image_path = "/data/nvme0/gushiqiao/models/code/LightX2V/assets/inputs/imgs/woman.jpeg" # For multiple images, use comma-separated paths: "path1.jpg,path2.jpg"
image_strength = 1.0 # Scalar: use same strength for all images, or list: [1.0, 0.8] for different strengths
prompt = "A young woman with wavy, shoulder-length light brown hair is singing and dancing joyfully outdoors on a foggy day. She wears a cozy pink turtleneck sweater, swaying gracefully to the music with animated expressions and bright, piercing blue eyes. Her movements are fluid and energetic as she twirls and gestures expressively. A wooden fence and a misty, grassy field fade into the background, creating a dreamy atmosphere for her lively performance."
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
save_result_path = "/path/to/save_results/output.mp4"
save_result_path = "/data/nvme0/gushiqiao/models/code/LightX2V/save_results/output_lightx2v_ltx2_i2av_distilled_fp8.mp4"

# Note: image_strength can also be set in config_json
# For scalar: image_strength = 1.0 (all images use same strength)
# For list: image_strength = [1.0, 0.8] (must match number of images)
pipe.generate(
seed=seed,
prompt=prompt,
images=images,
image_path=image_path,
image_strength=image_strength,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
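One way to address the reviewer's portability concern is an environment variable, sketched below; LTX2_MODEL_ROOT is a name invented for illustration, not a variable the repo defines:

import os

from lightx2v import LightX2VPipeline

model_root = os.environ.get("LTX2_MODEL_ROOT", "Lightricks/LTX-2")
pipe = LightX2VPipeline(
    model_path=model_root,
    model_cls="ltx2",
    task="i2av",
)
pipe.enable_quantize(
    dit_quantized=True,
    dit_quantized_ckpt=os.path.join(model_root, "ltx-2-19b-distilled-fp8.safetensors"),
    quant_scheme="fp8-pertensor",
    skip_fp8_block_index=[0, 43, 44, 45, 46, 47],
)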
19 changes: 10 additions & 9 deletions examples/ltx2/ltxt_t2av_distilled_fp8.py
@@ -1,10 +1,10 @@
from lightx2v import LightX2VPipeline

pipe = LightX2VPipeline(model_path="Lightricks/LTX-2", model_cls="ltx2", task="t2av")
pipe = LightX2VPipeline(model_path="/data/nvme0/gushiqiao/models/official_models/LTX-2", model_cls="ltx2", task="t2av")

pipe.enable_quantize(
dit_quantized=True,
dit_quantized_ckpt="Lightricks/LTX-2/ltx-2-19b-distilled-fp8.safetensors",
dit_quantized_ckpt="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors",
Comment on lines +3 to +7 (Contributor review, severity: medium):
This file contains hardcoded local paths for model_path and dit_quantized_ckpt. This makes the example not runnable for other users. Please consider using relative paths from the repository root, or loading paths from environment variables to improve portability.

quant_scheme="fp8-pertensor",
skip_fp8_block_index=[0, 43, 44, 45, 46, 47],
)
@@ -33,11 +33,12 @@
seed = 42
prompt = "A beautiful sunset over the ocean"
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
save_result_path = "/path/to/save_results/output.mp4"
save_result_path = "/data/nvme0/gushiqiao/models/code/LightX2V/save_results/output_lightx2v_ltx2_t2av_distilled_fp8.mp4"

pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
for seed in range(10):
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
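The reviewer's other suggestion, paths relative to the repository root, could be sketched as follows (assuming the script sits two levels below the root, e.g. examples/ltx2/):

from pathlib import Path

repo_root = Path(__file__).resolve().parents[2]  # examples/ltx2/script.py -> repo root
save_result_path = str(repo_root / "save_results" / "output_t2av.mp4")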
96 changes: 96 additions & 0 deletions lightx2v/common/ops/mm/mm_weight.py
@@ -2,6 +2,7 @@
from abc import ABCMeta, abstractmethod

import torch
import torch.distributed as dist
from loguru import logger
from safetensors import safe_open

@@ -92,6 +93,8 @@
except ImportError:
marlin_cuda_quant = None

import torch.distributed as dist


class MMWeightTemplate(metaclass=ABCMeta):
def __init__(
@@ -2143,3 +2146,96 @@ def apply(self, input_tensor):
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor


@MM_WEIGHT_REGISTER("TensorParallel")
class MMWeightTP(MMWeightTemplate):
"""
Tensor Parallel wrapper for any MMWeight type.

This is a generic wrapper that can wrap any MMWeight implementation (Default, fp8, int8, etc.)
and add tensor parallelism support by:
1. Handling weight splitting in load() method
2. Adding all-reduce for row-wise split in apply() method

Supports column-wise and row-wise weight splitting:
- Column split: weight [in_dim, out_dim] -> [in_dim, out_dim/tp_size] per rank
- Row split: weight [in_dim, out_dim] -> [in_dim/tp_size, out_dim] per rank
"""

def __init__(
self,
weight_name,
bias_name,
mm_type="Default",
tp_group=None,
tp_rank=0,
tp_size=1,
split_dim="col",
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.tp_group = tp_group
self.tp_rank = tp_rank
self.tp_size = tp_size
self.split_dim = split_dim # "col" for column split, "row" for row split
assert split_dim in ["col", "row"], f"split_dim must be 'col' or 'row', got {split_dim}"

self._mm = MM_WEIGHT_REGISTER.get(mm_type, MMWeight)(
weight_name=weight_name,
bias_name=bias_name,
create_cuda_buffer=create_cuda_buffer,
create_cpu_buffer=create_cpu_buffer,
lazy_load=lazy_load,
lazy_load_file=lazy_load_file,
is_post_adapter=is_post_adapter,
lora_prefix=lora_prefix,
lora_path=lora_path,
)
self._row_split_bias = None

def load(self, weight_dict):
"""Load weights using internal MMWeight's load method.

Note: Weights in weight_dict are already split by _load_weights_from_rank0.
The format is [out_dim/tp_size, in_dim] for column split or [out_dim, in_dim/tp_size] for row split.
MMWeight.load will handle the transposition via create_default_tensors.

For row split, bias is not split and should be added after all-reduce.
We temporarily remove bias from _mm to prevent it from being added before all-reduce.
"""
self._mm.load(weight_dict)
if self.split_dim == "row" and self.bias_name is not None and self.bias_name in weight_dict:
self._row_split_bias = self._mm.bias.clone()
self._mm.bias = None

def apply(self, input_tensor):
"""Apply matrix multiplication with tensor parallel support."""
# Use internal MMWeight's apply method (handles fp8, int8, etc.)
# For row split, _mm.bias is None, so bias won't be added here
output = self._mm.apply(input_tensor)

# For row split, need all-reduce to combine results from all ranks
if self.split_dim == "row" and self.tp_size > 1 and self.tp_group is not None:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.tp_group)
# Add bias after all-reduce (bias is not split for row split)
if self._row_split_bias is not None:
output = output + self._row_split_bias

return output
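A standalone sketch (illustration only, not part of the PR) of the algebra behind the two split modes, including why the row split all-reduces before adding the unsplit bias:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)  # [batch, in_dim]
w = torch.randn(8, 6)  # [in_dim, out_dim], math convention (stored transposed on disk)
b = torch.randn(6)
ref = x @ w + b

# Column split: each rank holds all of in_dim but only out_dim/tp_size columns, so
# per-rank outputs are concatenated and the (equally split) bias needs no reduction.
col_out = torch.cat([x @ wc + bc for wc, bc in zip(w.chunk(2, dim=1), b.chunk(2))], dim=-1)
assert torch.allclose(col_out, ref, atol=1e-5)

# Row split: each rank holds in_dim/tp_size rows, so per-rank outputs are partial
# sums; all_reduce(SUM) combines them and the unsplit bias is added once afterwards.
row_out = sum(xp @ wr for xp, wr in zip(x.chunk(2, dim=1), w.chunk(2, dim=0))) + b
assert torch.allclose(row_out, ref, atol=1e-5)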
60 changes: 60 additions & 0 deletions lightx2v/common/ops/norm/rms_norm_weight.py
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod

import torch
import torch.distributed as dist
from loguru import logger
from safetensors import safe_open

@@ -176,6 +177,65 @@ def apply(self, input_tensor):
return input_tensor


@RMS_WEIGHT_REGISTER("TensorParallel")
class RMSWeightTP(RMSWeightTemplate):
"""
RMSNorm weight module with tensor parallelism support.

The weight is split along the hidden dimension to match the split QKV outputs.
"""

def __init__(
self,
weight_name,
tp_group=None,
tp_rank=0,
tp_size=1,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
self.tp_group = tp_group
self.tp_rank = tp_rank
self.tp_size = tp_size

def apply(self, input_tensor):
local_sum = input_tensor.pow(2).sum(-1, keepdim=True)

# All-reduce to get global sum
if self.tp_size > 1 and self.tp_group is not None:
dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=self.tp_group)

# Compute global mean: global_sum / hidden_dim
hidden_dim = input_tensor.shape[-1] * self.tp_size
global_mean = local_sum / hidden_dim

# Apply normalization with global mean
if self.sensitive_layer_dtype != self.infer_dtype:
input_tensor = input_tensor * torch.rsqrt(global_mean.float() + self.eps).to(self.infer_dtype)
input_tensor = (input_tensor * self._get_actual_weight()).to(self.infer_dtype)
else:
input_tensor = input_tensor * torch.rsqrt(global_mean + self.eps)
input_tensor = input_tensor * self._get_actual_weight()
return input_tensor


@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
def __init__(
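The key point in RMSWeightTP.apply is that the mean of squares is taken over the full hidden dimension rather than the local shard. A single-device sketch (illustration only) checking that the sharded math reproduces the unsharded norm:

import torch

eps = 1e-6
x = torch.randn(2, 16)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)  # unsharded, unit weight

shards = x.chunk(2, dim=-1)  # each "rank" holds hidden_dim / tp_size features
global_sum = sum(s.pow(2).sum(-1, keepdim=True) for s in shards)  # emulates all_reduce(SUM)
global_mean = global_sum / x.shape[-1]  # local_dim * tp_size
out = torch.cat([s * torch.rsqrt(global_mean + eps) for s in shards], dim=-1)
assert torch.allclose(out, ref, atol=1e-5)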