diff --git a/contrib/models/Wan2.1-T2V-1.3B/README.md b/contrib/models/Wan2.1-T2V-1.3B/README.md new file mode 100644 index 00000000..512f9473 --- /dev/null +++ b/contrib/models/Wan2.1-T2V-1.3B/README.md @@ -0,0 +1,125 @@ +# WAN 2.1 Text-to-Video 1.3B — Neuron Port + +NeuronX implementation of [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers). +All neural network components (T5 encoder, transformer backbone, VAE decoder) run on AWS Trainium. + +## Prerequisites + +- trn2.48xlarge instance with Neuron SDK +- Python packages: `diffusers`, `transformers`, `torch_neuronx`, `neuronx_distributed` +- Model weights downloaded and `WAN_MODEL_DIR` set: + ```bash + export WAN_MODEL_DIR=/home/ubuntu/models/wan2.1-t2v-1.3b + ``` + +## Compilation + +NEFFs must be compiled before inference. Each backbone NEFF is compiled for a fixed frame count. + +| Component | Command | Time | +|-----------|---------|------| +| T5 + 13f backbone + VAE | `NEURON_RT_VISIBLE_CORES=0-7 python3 scripts/trace_all.py --component=all` | ~65 min | +| 49f backbone (CP=8) | `NEURON_RT_VISIBLE_CORES=0-7 torchrun --nproc_per_node=8 scripts/compile_backbone_cp8.py` | ~95 min | + +To compile individual components: +```bash +# T5 only (~2 min) +NEURON_RT_VISIBLE_CORES=0-3 NEURON_RT_VIRTUAL_CORE_SIZE=1 python3 scripts/trace_all.py --component=t5 + +# 13f backbone only (~48 min) +NEURON_RT_VISIBLE_CORES=0-3 python3 scripts/trace_all.py --component=backbone + +# VAE blocks only (~15 min) +NEURON_RT_VISIBLE_CORES=0-3 python3 scripts/trace_all.py --component=vae +``` + +## Inference + +### 13 Frames (~43s, all on Neuron) + +```bash +# Step 1: T5 encoding (LNC=1, separate process) +NEURON_RT_VISIBLE_CORES=0-3 NEURON_RT_VIRTUAL_CORE_SIZE=1 \ + python3 scripts/run_pipeline.py --step=t5 --prompt="a cat on a beach at sunset" + +# Step 2: Backbone + VAE decode +NEURON_RT_VISIBLE_CORES=0-3 \ + python3 scripts/run_pipeline.py --step=generate --seed=42 +``` + +Output: 13 frames at 480×832 in `output/`. + +### 49 Frames (~102s, no crosshatch) + +```bash +# Step 1: T5 encoding (same as above) +NEURON_RT_VISIBLE_CORES=0-3 NEURON_RT_VIRTUAL_CORE_SIZE=1 \ + python3 scripts/run_pipeline.py --step=t5 --prompt="a cat on a beach at sunset" + +# Step 2: Backbone denoising (CP=8, cores 0-7) +NEURON_RT_VISIBLE_CORES=0-7 \ + torchrun --nproc_per_node=8 scripts/validate_backbone_cp8.py + +# Step 3: VAE decode (cores 8-15, reads latents from /tmp/latents_49f.pt) +NEURON_RT_VISIBLE_CORES=8-15 \ + python3 scripts/decode_vae.py --latents=/tmp/latents_49f.pt --num-frames=13 +``` + +Output: 49 frames at 480×832 in `output_49f/`. + +## File Structure + +``` +├── src/ +│ ├── modeling_wan.py # Standalone model, VAE wrappers, XLA compat +│ └── __init__.py +├── scripts/ +│ ├── run_pipeline.py # E2E 13-frame inference +│ ├── trace_all.py # Compile T5, backbone, VAE blocks +│ ├── compile_backbone_cp8.py # Compile 49-frame CP=8 backbone +│ └── validate_backbone_cp8.py # 49-frame backbone denoising +├── compiled_model/ # Pre-compiled NEFFs +│ ├── t5/ +│ ├── vae_blocks/ # 55 blocks (frame-count-agnostic) +│ ├── 13f/backbone/ +│ └── 49f/backbone_cp8/ # 16 NEFFs (2 per rank × 8) +└── test/integration/test_model.py +``` + +## Key Concepts + +- **Pre-computed RoPE:** XLA silently drops inputs used only for `.shape`. RoPE is computed on CPU and passed as an explicit input. +- **Block-by-block VAE:** The causal 3D VAE can't be traced as one unit. Each block is traced with cache I/O and chained at runtime. VAE blocks are frame-count-agnostic. +- **CP=8 for 49 frames:** The 5M instruction limit prevents compiling the full backbone at seq_len=20,280. Context parallelism splits the sequence across 8 ranks, with the model split into 2 NEFFs of 15 blocks each. +- **Different core sets:** T5 (LNC=1) and backbone (LNC=2) can't share a process. For 49 frames, backbone uses cores 0-7 and VAE uses cores 8-15. +- **Fixed frame count per NEFF:** Backbone NEFFs are shape-specific. Different frame counts require recompilation. This is standard for Neuron/XLA (Flux on NxDI does the same for different resolutions). + +## Performance + +| Config | Backbone | VAE | Total | Cores | vs CPU | +|--------|----------|-----|-------|-------|--------| +| 13 frames | 36.5s | 6.6s | 43s | 2 | 1.4× | +| 49 frames | 70.7s | 31.8s | 102s | 10 | 13.5× | + +## Known Limitations + +- **Crosshatch at 13 frames:** VAE temporal upsampling artifact at low frame counts. Use ≥49 frames for clean output. +- **Instruction limit:** Backbone at seq_len>6,240 exceeds 5M limit at TP=1. Requires CP for longer sequences. +- **LNC mismatch:** T5 and backbone must run in separate processes. + +## Validation + +| Component | Cosine vs CPU | +|-----------|--------------| +| Backbone (single step) | 0.9998 | +| T5 encoder | 0.998 | +| VAE (per block) | ≥0.9998 | +| VAE (full decode) | 1.0006 | + +## Compatibility + +| Instance | 13 Frames | 49 Frames | +|----------|-----------|-----------| +| trn2.48xlarge | ✅ | ✅ | + +**Last Updated:** 2026-03-23 diff --git a/contrib/models/Wan2.1-T2V-1.3B/src/__init__.py b/contrib/models/Wan2.1-T2V-1.3B/src/__init__.py new file mode 100644 index 00000000..74b7655f --- /dev/null +++ b/contrib/models/Wan2.1-T2V-1.3B/src/__init__.py @@ -0,0 +1,25 @@ +from .modeling_wan import ( + NeuronWanTransformer3DModel, + T5Wrapper, + make_decoder_xla_compatible, + VAE_BLOCK_ORDER, + ConvInCached, + BlockCached, + NormConvOutCached, + NoCacheWrapper, + UpsampleSpatialOnly, + UpsampleFirstChunk, +) + +__all__ = [ + "NeuronWanTransformer3DModel", + "T5Wrapper", + "make_decoder_xla_compatible", + "VAE_BLOCK_ORDER", + "ConvInCached", + "BlockCached", + "NormConvOutCached", + "NoCacheWrapper", + "UpsampleSpatialOnly", + "UpsampleFirstChunk", +] diff --git a/contrib/models/Wan2.1-T2V-1.3B/src/modeling_wan.py b/contrib/models/Wan2.1-T2V-1.3B/src/modeling_wan.py new file mode 100644 index 00000000..b2d77079 --- /dev/null +++ b/contrib/models/Wan2.1-T2V-1.3B/src/modeling_wan.py @@ -0,0 +1,488 @@ +""" +NeuronX implementation of WAN 2.1 T2V 1.3B — standalone model. + +Mirrors the HF diffusers WanTransformer3DModel architecture with: +- XLA-compatible operations (no negative indexing, no nearest-exact) +- Pre-computed RoPE as explicit inputs (avoids XLA unused-input bug) +- Same state dict keys as HF for direct weight loading + +Usage: + model = NeuronWanTransformer3DModel.from_config(hf_config) + model.load_hf_weights(hf_state_dict) + output = model(hidden_states, timestep, encoder_hidden_states, rope_cos, rope_sin) +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# ─── Embeddings ────────────────────────────────────────────────────────────── + +class TimestepEmbedding(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.linear_1 = nn.Linear(in_dim, out_dim) + self.linear_2 = nn.Linear(out_dim, out_dim) + + def forward(self, x): + return self.linear_2(F.silu(self.linear_1(x))) + + +class Timesteps(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, timesteps): + half = self.dim // 2 + freqs = torch.exp(-math.log(10000) * torch.arange(0, half, dtype=torch.float32, device=timesteps.device) / half) + args = timesteps[:, None].float() * freqs[None] + return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + +class TextProjection(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.linear_1 = nn.Linear(in_dim, out_dim) + self.linear_2 = nn.Linear(out_dim, out_dim) + + def forward(self, x): + return self.linear_2(F.silu(self.linear_1(x))) + + +class WanConditionEmbedder(nn.Module): + """Timestep + text conditioning. Produces temb and projects text embeddings.""" + + def __init__(self, dim, freq_dim, text_dim): + super().__init__() + self.timesteps_proj = Timesteps(freq_dim) + self.time_embedder = TimestepEmbedding(freq_dim, dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, dim * 6) + self.text_embedder = TextProjection(text_dim, dim) + + def forward(self, timestep, encoder_hidden_states, **kwargs): + temb = self.time_embedder(self.timesteps_proj(timestep).to(dtype=self.time_embedder.linear_1.weight.dtype)) + temb = temb.type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + return temb, timestep_proj, encoder_hidden_states + + +# ─── Attention ─────────────────────────────────────────────────────────────── + +class WanAttention(nn.Module): + """Multi-head attention with across-heads RMSNorm on Q/K.""" + + def __init__(self, dim, num_heads, dim_head, eps=1e-6): + super().__init__() + self.num_heads = num_heads + self.dim_head = dim_head + inner_dim = num_heads * dim_head + + self.to_q = nn.Linear(dim, inner_dim, bias=True) + self.to_k = nn.Linear(dim, inner_dim, bias=True) + self.to_v = nn.Linear(dim, inner_dim, bias=True) + self.to_out = nn.ModuleList([nn.Linear(inner_dim, dim, bias=True), nn.Dropout(0.0)]) + self.norm_q = nn.RMSNorm(inner_dim, eps=eps, elementwise_affine=True) + self.norm_k = nn.RMSNorm(inner_dim, eps=eps, elementwise_affine=True) + + def forward(self, hidden_states, encoder_hidden_states=None, rotary_emb=None): + kv_input = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + query = self.norm_q(self.to_q(hidden_states)) + key = self.norm_k(self.to_k(kv_input)) + value = self.to_v(kv_input) + + query = query.unflatten(2, (self.num_heads, self.dim_head)) + key = key.unflatten(2, (self.num_heads, self.dim_head)) + value = value.unflatten(2, (self.num_heads, self.dim_head)) + + if rotary_emb is not None: + query = _apply_rotary_emb(query, *rotary_emb) + key = _apply_rotary_emb(key, *rotary_emb) + + out = F.scaled_dot_product_attention( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), + ).transpose(1, 2).flatten(2, 3) + + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + +def _apply_rotary_emb(x, freqs_cos, freqs_sin): + """Apply rotary position embedding. x: [B, S, H, D], freqs: [B, S, 1, D].""" + x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(x) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(x) + + +# ─── Transformer Block ─────────────────────────────────────────────────────── + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, x): + return F.layer_norm(x.float(), self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps).type_as(x) + + +class WanTransformerBlock(nn.Module): + def __init__(self, dim, ffn_dim, num_heads, eps=1e-6, cross_attn_norm=True): + super().__init__() + dim_head = dim // num_heads + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanAttention(dim, num_heads, dim_head, eps) + self.attn2 = WanAttention(dim, num_heads, dim_head, eps) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.ffn = WanFeedForward(dim, ffn_dim) + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward(self, hidden_states, encoder_hidden_states, temb, rotary_emb): + shift_msa, scale_msa, gate_msa, c_shift, c_scale, c_gate = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + norm_hs = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_out = self.attn1(norm_hs, rotary_emb=rotary_emb) + hidden_states = (hidden_states.float() + attn_out * gate_msa).type_as(hidden_states) + + norm_hs = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_out = self.attn2(norm_hs, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_out + + norm_hs = (self.norm3(hidden_states.float()) * (1 + c_scale) + c_shift).type_as(hidden_states) + ff_out = self.ffn(norm_hs) + hidden_states = (hidden_states.float() + ff_out.float() * c_gate).type_as(hidden_states) + return hidden_states + + +class WanFeedForward(nn.Module): + """GELU feed-forward: Linear -> GELU -> Dropout -> Linear.""" + def __init__(self, dim, ffn_dim): + super().__init__() + self.net = nn.ModuleList([ + GELU(dim, ffn_dim), + nn.Dropout(0.0), + nn.Linear(ffn_dim, dim, bias=True), + ]) + + def forward(self, x): + for layer in self.net: + x = layer(x) + return x + + +class GELU(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.proj = nn.Linear(in_dim, out_dim, bias=True) + self.gelu = nn.GELU(approximate="tanh") + + def forward(self, x): + return self.gelu(self.proj(x)) + + +# ─── Full Model ────────────────────────────────────────────────────────────── + +class NeuronWanTransformer3DModel(nn.Module): + """Standalone WAN 2.1 backbone for Neuron. + + Takes pre-computed RoPE as inputs to avoid XLA unused-input bug. + State dict keys match HF WanTransformer3DModel for direct weight loading. + """ + + def __init__(self, num_heads=12, dim_head=128, num_layers=30, ffn_dim=8960, + in_channels=16, out_channels=16, text_dim=4096, freq_dim=256, + patch_size=(1, 2, 2), eps=1e-6, cross_attn_norm=True): + super().__init__() + dim = num_heads * dim_head + self.patch_size = patch_size + + self.patch_embedding = nn.Conv3d(in_channels, dim, kernel_size=patch_size, stride=patch_size) + self.condition_embedder = WanConditionEmbedder(dim, freq_dim, text_dim) + self.blocks = nn.ModuleList([ + WanTransformerBlock(dim, ffn_dim, num_heads, eps, cross_attn_norm) + for _ in range(num_layers) + ]) + self.norm_out = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, hidden_states, timestep, encoder_hidden_states, rope_cos, rope_sin): + b, c, f, h, w = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = f // p_t, h // p_h, w // p_w + + rotary_emb = (rope_cos, rope_sin) + + hidden_states = self.patch_embedding(hidden_states).flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states = self.condition_embedder( + timestep.expand(hidden_states.shape[0]), encoder_hidden_states) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(b, ppf, pph, ppw, p_t, p_h, p_w, -1) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + return hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + @classmethod + def from_pretrained(cls, path, torch_dtype=torch.bfloat16): + """Load from HF WanTransformer3DModel pretrained weights.""" + from diffusers.models.transformers.transformer_wan import WanTransformer3DModel as HF + hf = HF.from_pretrained(path, torch_dtype=torch_dtype) + cfg = hf.config + model = cls( + num_heads=cfg.num_attention_heads, dim_head=cfg.attention_head_dim, + num_layers=cfg.num_layers, ffn_dim=cfg.ffn_dim, + in_channels=cfg.in_channels, out_channels=cfg.out_channels, + text_dim=cfg.text_dim, freq_dim=cfg.freq_dim, + patch_size=tuple(cfg.patch_size), eps=cfg.eps, + cross_attn_norm=cfg.cross_attn_norm, + ).to(torch_dtype) + model.load_state_dict(hf.state_dict(), strict=False) + # Compute RoPE helper + model._rope = hf.rope + return model + + def compute_rope(self, dummy_latent): + """Pre-compute RoPE on CPU. Call once, pass results to forward().""" + return self._rope(dummy_latent) + + +# ─── VAE Decoder XLA Compatibility ─────────────────────────────────────────── + +def make_decoder_xla_compatible(decoder): + """Monkey-patch a WAN VAE decoder to fix XLA incompatibilities. + + 1. Replace x[:,:,-N:,:,:] with safe indexing (XLA negative index bug) + 2. Replace mode='nearest-exact' with mode='nearest' (XLA custom-call bug) + + Call once after loading VAE, before tracing. No diffusers source changes. + """ + from diffusers.models.autoencoders.autoencoder_kl_wan import ( + WanResample, WanResidualBlock, WanDecoder3d, CACHE_T, + ) + _patch_resample_forward(WanResample, CACHE_T) + _patch_residual_block_forward(WanResidualBlock, CACHE_T) + _patch_decoder_forward(WanDecoder3d, CACHE_T) + for m in decoder.modules(): + if isinstance(m, nn.Upsample) and m.mode == "nearest-exact": + m.mode = "nearest" + + +def _patch_resample_forward(cls, CACHE_T): + def safe_forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, max(0, x.shape[2] - CACHE_T):, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + cache_x = torch.cat([feat_cache[idx][:, :, max(0, feat_cache[idx].shape[2] - 1), :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, max(0, x.shape[2] - 1):, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + return x + cls.forward = safe_forward + + +def _patch_residual_block_forward(cls, CACHE_T): + def safe_forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.conv_shortcut(x) + x = self.norm1(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, max(0, x.shape[2] - CACHE_T):, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, max(0, feat_cache[idx].shape[2] - 1), :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.dropout(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, max(0, x.shape[2] - CACHE_T):, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, max(0, feat_cache[idx].shape[2] - 1), :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + return x + h + cls.forward = safe_forward + + +def _patch_decoder_forward(cls, CACHE_T): + def safe_forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, max(0, x.shape[2] - CACHE_T):, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, max(0, feat_cache[idx].shape[2] - 1), :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx) + for up_block in self.up_blocks: + x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk) + x = self.nonlinearity(self.norm_out(x)) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, max(0, x.shape[2] - CACHE_T):, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, max(0, feat_cache[idx].shape[2] - 1), :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + cls.forward = safe_forward + + +# ─── VAE Block Wrappers (for Neuron tracing) ───────────────────────────────── + +VAE_BLOCK_ORDER = [ + ("conv_in", 0, 1), ("mid_block", 1, 4), + ("up0_resnet0", 5, 2), ("up0_resnet1", 7, 2), ("up0_resnet2", 9, 2), ("up0_upsample", 11, 1), + ("up1_resnet0", 12, 2), ("up1_resnet1", 14, 2), ("up1_resnet2", 16, 2), ("up1_upsample", 18, 1), + ("up2_resnet0", 19, 2), ("up2_resnet1", 21, 2), ("up2_resnet2", 23, 2), ("up2_upsample", 25, 0), + ("up3_resnet0", 25, 2), ("up3_resnet1", 27, 2), ("up3_resnet2", 29, 2), ("norm_conv_out", 31, 1), +] + + +class ConvInCached(nn.Module): + def __init__(self, conv): + super().__init__() + self.conv = conv + def forward(self, x, cache): + out = self.conv(x, cache_x=cache) + new_cache = torch.cat([cache, x], dim=2)[:, :, max(0, cache.shape[2] + x.shape[2] - 2):, :, :] + return out, new_cache + + +class BlockCached(nn.Module): + def __init__(self, block, cache_start, cache_count): + super().__init__() + self.block = block + self.cs = cache_start + self.cc = cache_count + def forward(self, x, *caches): + fc = [None] * 33 + for i, c in enumerate(caches): + fc[self.cs + i] = c + idx = [self.cs] + out = self.block(x, feat_cache=fc, feat_idx=idx) + return (out,) + tuple(fc[self.cs + i] for i in range(self.cc)) + + +class NormConvOutCached(nn.Module): + def __init__(self, decoder): + super().__init__() + self.norm = decoder.norm_out + self.act = decoder.nonlinearity + self.conv = decoder.conv_out + def forward(self, x, cache): + h = self.act(self.norm(x)) + out = self.conv(h, cache_x=cache) + new_cache = torch.cat([cache, h], dim=2)[:, :, max(0, cache.shape[2] + h.shape[2] - 2):, :, :] + return out, new_cache + + +class NoCacheWrapper(nn.Module): + def __init__(self, block): + super().__init__() + self.block = block + def forward(self, x): + return self.block(x) + + +class UpsampleSpatialOnly(nn.Module): + """Frame-0 upsampler: spatial only, no time_conv.""" + def __init__(self, block): + super().__init__() + self.resample = block.resample + def forward(self, x): + b, c, t, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + return x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + +class UpsampleFirstChunk(nn.Module): + """Frame-1 upsampler: time_conv without cache (Rep path).""" + def __init__(self, block): + super().__init__() + self.block = block + def forward(self, x): + b, c, t, h, w = x.size() + x = self.block.time_conv(x) + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t2 = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t2, c, h, w) + x = self.block.resample(x) + return x.view(b, t2, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + +# ─── T5 Wrapper ────────────────────────────────────────────────────────────── + +class T5Wrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + def forward(self, input_ids, attention_mask): + return self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state diff --git a/contrib/models/Wan2.1-T2V-1.3B/src/scripts/compile_backbone_cp8.py b/contrib/models/Wan2.1-T2V-1.3B/src/scripts/compile_backbone_cp8.py new file mode 100644 index 00000000..500c86f6 --- /dev/null +++ b/contrib/models/Wan2.1-T2V-1.3B/src/scripts/compile_backbone_cp8.py @@ -0,0 +1,178 @@ +"""CP=8 WAN backbone split into 2 NEFFs of 15 blocks each.""" +import torch, torch.nn as nn, torch.nn.functional as F, os, sys, time + +sys.path.insert(0, "/mnt/work/wan/all_on_neuron_checkpt/contrib/models/wan2.1-t2v-1.3b/src") + +import neuronx_distributed +import torch_neuronx +from neuronx_distributed.parallel_layers.parallel_state import ( + initialize_model_parallel, get_tensor_model_parallel_size, get_tensor_model_parallel_rank, +) +from neuronx_distributed.parallel_layers.mappings import gather_from_tensor_model_parallel_region + +if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="xla") +initialize_model_parallel(tensor_model_parallel_size=8) +rank = get_tensor_model_parallel_rank() +cp = get_tensor_model_parallel_size() + +class CPAttn(nn.Module): + def __init__(self, dim, heads, dim_head, eps=1e-6): + super().__init__() + self.heads, self.dim_head = heads, dim_head + inner = dim_head * heads + self.to_q = nn.Linear(dim, inner, bias=True) + self.to_k = nn.Linear(dim, inner, bias=True) + self.to_v = nn.Linear(dim, inner, bias=True) + self.to_out = nn.Sequential(nn.Linear(inner, dim, bias=True), nn.Dropout(0.0)) + self.norm_q = nn.RMSNorm(inner, eps=eps, elementwise_affine=True) + self.norm_k = nn.RMSNorm(inner, eps=eps, elementwise_affine=True) + + def _gather_seq(self, t): + return gather_from_tensor_model_parallel_region(t.transpose(1, 2)).transpose(1, 2) + + def forward(self, x, enc=None, rcl=None, rsl=None, rcf=None, rsf=None): + kv = enc if enc is not None else x + q = self.norm_q(self.to_q(x)) + k = self.norm_k(self.to_k(kv)) + v = self.to_v(kv) + if enc is None: + k = self._gather_seq(k); v = self._gather_seq(v) + q = q.unflatten(2, (self.heads, self.dim_head)).transpose(1, 2) + k = k.unflatten(2, (self.heads, self.dim_head)).transpose(1, 2) + v = v.unflatten(2, (self.heads, self.dim_head)).transpose(1, 2) + if rcl is not None and enc is None: + def rope(x, c, s): + c = c.transpose(1, 2); s = s.transpose(1, 2) + x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1) + out = torch.empty_like(x) + out[..., 0::2] = x1 * c[..., 0::2] - x2 * s[..., 1::2] + out[..., 1::2] = x1 * s[..., 1::2] + x2 * c[..., 0::2] + return out.type_as(x) + q = rope(q, rcl, rsl); k = rope(k, rcf, rsf) + out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).flatten(2, 3) + return self.to_out(out) + +class CPBlock(nn.Module): + def __init__(self, dim=1536, ffn_dim=8960, heads=12, eps=1e-6): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = CPAttn(dim, heads, dim // heads, eps) + self.attn2 = CPAttn(dim, heads, dim // heads, eps) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=True) + self.norm3 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.ffn_up = nn.Linear(dim, ffn_dim, bias=True) + self.ffn_down = nn.Linear(ffn_dim, dim, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward(self, x, enc, temb, rcl, rsl, rcf, rsf): + s = (self.scale_shift_table + temb.float()).chunk(6, dim=1) + h = (self.norm1(x.float()) * (1 + s[1]) + s[0]).type_as(x) + x = (x.float() + self.attn1(h, rcl=rcl, rsl=rsl, rcf=rcf, rsf=rsf) * s[2]).type_as(x) + h = self.norm2(x.float()).type_as(x) + x = x + self.attn2(h, enc=enc) + h = (self.norm3(x.float()) * (1 + s[4]) + s[3]).type_as(x) + x = (x.float() + self.ffn_down(F.gelu(self.ffn_up(h), approximate="tanh")).float() * s[5]).type_as(x) + return x + +# NEFF 1: patch_embed + condition + blocks[0:15] + pre-gather RoPE +class NEFF1(nn.Module): + def __init__(self): + super().__init__() + self.patch_embedding = nn.Conv3d(16, 1536, kernel_size=(1,2,2), stride=(1,2,2)) + self.blocks = nn.ModuleList([CPBlock() for _ in range(15)]) + self.condition_embedder = None + + def forward(self, hidden_states, timestep, enc, rope_cos, rope_sin): + x = self.patch_embedding(hidden_states).flatten(2).transpose(1, 2) + temb, tp, enc, _ = self.condition_embedder(timestep.expand(x.shape[0]), enc, None, timestep_seq_len=None) + tp = tp.unflatten(1, (6, -1)) + full_seq = x.shape[1] + local_seq = full_seq // cp + x = x[:, rank * local_seq:(rank + 1) * local_seq] + rcl = rope_cos[:, rank * local_seq:(rank + 1) * local_seq] + rsl = rope_sin[:, rank * local_seq:(rank + 1) * local_seq] + rcf = gather_from_tensor_model_parallel_region(rcl.squeeze(2).transpose(1,2)).transpose(1,2).unsqueeze(2) + rsf = gather_from_tensor_model_parallel_region(rsl.squeeze(2).transpose(1,2)).transpose(1,2).unsqueeze(2) + for block in self.blocks: + x = block(x, enc, tp, rcl, rsl, rcf, rsf) + return x, temb, enc, tp, rcl, rsl, rcf, rsf + +# NEFF 2: blocks[15:30] + norm + proj + gather +class NEFF2(nn.Module): + def __init__(self): + super().__init__() + self.blocks = nn.ModuleList([CPBlock() for _ in range(15)]) + self.norm_out = nn.LayerNorm(1536, elementwise_affine=False) + self.proj_out = nn.Linear(1536, 64, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, 1536) / 1536**0.5) + + def forward(self, x, temb, enc, tp, rcl, rsl, rcf, rsf): + for block in self.blocks: + x = block(x, enc, tp, rcl, rsl, rcf, rsf) + x = gather_from_tensor_model_parallel_region(x.transpose(1, 2)).transpose(1, 2) + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + x = (self.norm_out(x.float()) * (1 + scale) + shift).type_as(x) + return self.proj_out(x) + +# Load weights +from diffusers.models.transformers.transformer_wan import WanTransformer3DModel +hf = WanTransformer3DModel.from_pretrained('/mnt/work/wan/agent_artifacts/data/transformer', torch_dtype=torch.bfloat16) + +n1 = NEFF1().to(torch.bfloat16).eval() +n2 = NEFF2().to(torch.bfloat16).eval() +n1.condition_embedder = hf.condition_embedder + +hf_sd = hf.state_dict() +for neff, prefix, block_range in [(n1, "", range(15)), (n2, "", range(15, 30))]: + mapped = {} + for k, v in hf_sd.items(): + mk = k.replace(".ffn.net.0.proj.", ".ffn_up.").replace(".ffn.net.2.", ".ffn_down.") + if ".to_out.1." in k: continue + # Remap block indices for NEFF2 + if neff is n2: + for bi in block_range: + mk = mk.replace(f"blocks.{bi}.", f"blocks.{bi-15}.") + if mk in neff.state_dict(): mapped[mk] = v + neff.load_state_dict(mapped, strict=False) + +del hf, hf_sd + +# Trace +h = torch.randn(1, 16, 13, 60, 104, dtype=torch.bfloat16) +t = torch.tensor([999], dtype=torch.int64) +e = torch.randn(1, 512, 4096, dtype=torch.bfloat16) +hf2 = WanTransformer3DModel.from_pretrained('/mnt/work/wan/agent_artifacts/data/transformer', torch_dtype=torch.bfloat16) +rc, rs = hf2.rope(h); del hf2 + +CC = "--model-type=transformer -O1 --auto-cast=none --internal-hlo2tensorizer-options='--verify-hlo=false'" + +print(f"[Rank {rank}] Tracing NEFF1 (15 blocks)...") +t0 = time.time() +try: + tr1 = torch_neuronx.trace(n1, (h, t, e, rc, rs), compiler_args=CC) + print(f"[Rank {rank}] NEFF1 COMPILED in {time.time()-t0:.0f}s!") +except Exception as ex: + print(f"[Rank {rank}] NEFF1 FAILED: {str(ex)[:150]}") + sys.exit(1) + +# Get NEFF1 output shapes for NEFF2 input +r1 = tr1(h, t, e, rc, rs) +x_mid = r1[0] +print(f"[Rank {rank}] NEFF1 output: x={x_mid.shape}") + +print(f"[Rank {rank}] Tracing NEFF2 (15 blocks)...") +t0 = time.time() +try: + tr2 = torch_neuronx.trace(n2, r1, compiler_args=CC) + print(f"[Rank {rank}] NEFF2 COMPILED in {time.time()-t0:.0f}s!") + final = tr2(*r1) + print(f"[Rank {rank}] Final output: {final.shape}") + + out_dir = "/mnt/work/wan/agent_artifacts/data/traced_wan_cp8_49f" + os.makedirs(out_dir, exist_ok=True) + torch.jit.save(tr1, f"{out_dir}/neff1_rank{rank}.pt") + torch.jit.save(tr2, f"{out_dir}/neff2_rank{rank}.pt") + print(f"[Rank {rank}] Saved!") +except Exception as ex: + print(f"[Rank {rank}] NEFF2 FAILED: {str(ex)[:150]}") diff --git a/contrib/models/Wan2.1-T2V-1.3B/src/scripts/run_pipeline.py b/contrib/models/Wan2.1-T2V-1.3B/src/scripts/run_pipeline.py new file mode 100644 index 00000000..794bdbd5 --- /dev/null +++ b/contrib/models/Wan2.1-T2V-1.3B/src/scripts/run_pipeline.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" +WAN 2.1 T2V 1.3B — Full E2E pipeline on Neuron. + +Assumes traced NEFFs already exist in this directory: + - traced_t5_lnc1.pt + - traced_wan_480x832.pt + - vae_blocks_cached/*.pt + +Usage (two-step due to LNC mismatch): + + # Step 1: T5 encoding (LNC=1) + NEURON_RT_VISIBLE_CORES=0-3 NEURON_RT_VIRTUAL_CORE_SIZE=1 \ + python3 run_pipeline.py --step=t5 --prompt="a cat on a beach at sunset" + + # Step 2: Backbone + VAE (LNC=2) + NEURON_RT_VISIBLE_CORES=0-3 \ + python3 run_pipeline.py --step=generate --seed=42 --steps=20 --cfg=5.0 + +Or run both steps: + python3 run_pipeline.py --prompt="a cat on a beach at sunset" --seed=42 +""" + +import argparse +import os +import subprocess +import sys +import time + +import torch +import torch_neuronx # must import before torch.jit.load for Neuron models +import numpy as np +from pathlib import Path +from PIL import Image + +SCRIPT_DIR = Path(__file__).parent +MODEL_DIR = Path(os.environ.get("WAN_MODEL_DIR", "/mnt/work/wan/agent_artifacts/data")) +EMBEDS_PATH = Path("/tmp/t5_embeds.pt") + + +def step_t5(prompt: str, negative_prompt: str = ""): + """Encode prompt with T5 on Neuron (LNC=1). Saves embeddings to /tmp.""" + print(f"[T5] Encoding: {prompt!r}") + t0 = time.time() + + traced_t5 = torch.jit.load(str(SCRIPT_DIR / "traced_t5_lnc1.pt")) + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR / "tokenizer")) + + def encode(text): + tokens = tokenizer(text, max_length=512, padding="max_length", + truncation=True, return_tensors="pt") + ids = tokens.input_ids.to(torch.int64) + mask = tokens.attention_mask.to(torch.int64) + embeds = traced_t5(ids, mask) + # Zero out padding positions (critical — HF pipeline does this) + seq_len = mask.sum(dim=1, keepdim=True).unsqueeze(-1) + positions = torch.arange(512).unsqueeze(0).unsqueeze(-1) + embeds = embeds * (positions < seq_len).to(embeds.dtype) + return embeds + + with torch.no_grad(): + pe = encode(prompt) + ne = encode(negative_prompt) + + torch.save({"pe": pe, "ne": ne}, str(EMBEDS_PATH)) + print(f"[T5] Done in {time.time()-t0:.1f}s -> {EMBEDS_PATH}") + + +def step_generate(seed: int = 42, steps: int = 20, cfg: float = 5.0): + """Denoise with backbone on Neuron + decode with VAE hybrid CPU/Neuron.""" + from diffusers import AutoencoderKLWan, UniPCMultistepScheduler + from diffusers.models.transformers.transformer_wan import WanTransformer3DModel + from modeling import VAE_BLOCK_ORDER + + # Load embeddings + embeds = torch.load(str(EMBEDS_PATH)) + pe, ne = embeds["pe"], embeds["ne"] + + # Load backbone + print("[Backbone] Loading traced model...") + traced_bb = torch.jit.load(str(SCRIPT_DIR / "traced_wan_480x832.pt")) + + # Pre-compute RoPE on CPU + m = WanTransformer3DModel.from_pretrained(str(MODEL_DIR / "transformer"), torch_dtype=torch.bfloat16) + rc, rs = m.rope(torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16)) + del m + + # Scheduler + scheduler = UniPCMultistepScheduler.from_pretrained(str(MODEL_DIR / "scheduler")) + + # Denoise + print(f"[Backbone] Denoising ({steps} steps, seed={seed})...") + gen = torch.Generator("cpu").manual_seed(seed) + latents = torch.randn(1, 16, 4, 60, 104, dtype=torch.float32, generator=gen) + scheduler.set_timesteps(steps, device="cpu") + + t0 = time.time() + for i, tt in enumerate(scheduler.timesteps): + li = latents.to(torch.bfloat16) + c = traced_bb(li, tt.expand(1), pe, rc, rs) + u = traced_bb(li, tt.expand(1), ne, rc, rs) + latents = scheduler.step(u + cfg * (c - u), tt, latents, return_dict=False)[0] + if i % 5 == 0: + print(f" step {i}/{steps}") + bb_time = time.time() - t0 + print(f"[Backbone] {bb_time:.1f}s") + + # VAE decode (hybrid) + print("[VAE] Loading...") + vae = AutoencoderKLWan.from_pretrained(str(MODEL_DIR / "vae"), torch_dtype=torch.bfloat16) + vae.eval() + + # Make decoder XLA-compatible (monkey-patches forward methods, no source changes) + from modeling import make_decoder_xla_compatible + make_decoder_xla_compatible(vae.decoder) + + blocks_dir = SCRIPT_DIR / "vae_blocks_cached" + B = {f.stem: torch.jit.load(str(f)) for f in sorted(blocks_dir.glob("*.pt"))} + + lm = torch.tensor(vae.config.latents_mean).view(1, 16, 1, 1, 1).to(torch.bfloat16) + ls = 1.0 / torch.tensor(vae.config.latents_std).view(1, 16, 1, 1, 1).to(torch.bfloat16) + x = vae.post_quant_conv(latents.to(torch.bfloat16) / ls + lm) + + print("[VAE] CPU frames 0+1...") + t0 = time.time() + vae.clear_cache() + frames = [] + for i in range(2): + vae._conv_idx = [0] + kw = dict(first_chunk=True) if i == 0 else {} + with torch.no_grad(): + frames.append(vae.decoder(x[:, :, i:i+1, :, :], + feat_cache=vae._feat_map, feat_idx=vae._conv_idx, **kw)) + + print("[VAE] Neuron frames 2+3...") + cache = [v.clone() if isinstance(v, torch.Tensor) else v for v in vae._feat_map] + for fi in range(2, 4): + h = x[:, :, fi:fi+1, :, :] + for name, cs, cc in VAE_BLOCK_ORDER: + c_in = tuple(cache[cs + i] for i in range(cc)) + result = B[name](h, *c_in) + if cc > 0: + h = result[0] + for i in range(cc): + cache[cs + i] = result[1 + i] + else: + h = result + frames.append(h) + + vae_time = time.time() - t0 + print(f"[VAE] {vae_time:.1f}s") + + # Save frames + video = torch.cat(frames, dim=2) + video = ((video.clamp(-1, 1) + 1) / 2 * 255).to(torch.uint8)[0].permute(1, 2, 3, 0).cpu().numpy() + + out_dir = SCRIPT_DIR / "output" + out_dir.mkdir(exist_ok=True) + for i in range(video.shape[0]): + Image.fromarray(video[i]).save(out_dir / f"frame_{i:03d}.png") + + print(f"\nDone! {video.shape[0]} frames at {video.shape[2]}x{video.shape[1]}") + print(f" Backbone: {bb_time:.1f}s") + print(f" VAE: {vae_time:.1f}s") + print(f" Total: {bb_time + vae_time:.1f}s") + print(f" Output: {out_dir}/") + + +def main(): + parser = argparse.ArgumentParser(description="WAN 2.1 T2V on Neuron") + parser.add_argument("--step", choices=["t5", "generate", "both"], default="both") + parser.add_argument("--prompt", default="a cat sitting on a beach at sunset") + parser.add_argument("--negative-prompt", default="") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--steps", type=int, default=20) + parser.add_argument("--cfg", type=float, default=5.0) + args = parser.parse_args() + + if args.step in ("t5", "both"): + if args.step == "both": + # Run T5 in subprocess with LNC=1 + env = os.environ.copy() + env["NEURON_RT_VISIBLE_CORES"] = env.get("NEURON_RT_VISIBLE_CORES", "0-3") + env["NEURON_RT_VIRTUAL_CORE_SIZE"] = "1" + subprocess.run([ + sys.executable, __file__, + "--step=t5", f"--prompt={args.prompt}", f"--negative-prompt={args.negative_prompt}", + ], env=env, check=True) + else: + step_t5(args.prompt, args.negative_prompt) + + if args.step in ("generate", "both"): + step_generate(args.seed, args.steps, args.cfg) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Wan2.1-T2V-1.3B/src/scripts/trace_all.py b/contrib/models/Wan2.1-T2V-1.3B/src/scripts/trace_all.py new file mode 100644 index 00000000..ab2ff5e0 --- /dev/null +++ b/contrib/models/Wan2.1-T2V-1.3B/src/scripts/trace_all.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +""" +Trace all Neuron NEFFs from scratch (backbone, T5, VAE blocks). + +Prerequisites: + - Model weights at WAN_MODEL_DIR (default: /mnt/work/wan/agent_artifacts/data) + - VAE patches applied (run: python3 -c "from modeling import apply_vae_patches; apply_vae_patches('')") + +Usage: + # Trace backbone + VAE (LNC=2) + NEURON_RT_VISIBLE_CORES=0-3 python3 trace_all.py --component=backbone + NEURON_RT_VISIBLE_CORES=0-3 python3 trace_all.py --component=vae + + # Trace T5 (LNC=1) + NEURON_RT_VISIBLE_CORES=0-3 NEURON_RT_VIRTUAL_CORE_SIZE=1 python3 trace_all.py --component=t5 + + # Or all at once (handles LNC switching via subprocess): + python3 trace_all.py --component=all +""" + +import argparse +import os +import subprocess +import sys +import time +from pathlib import Path + +import torch +import torch_neuronx + +from modeling import ( + WanBackboneWrapper, T5Wrapper, + ConvInCached, BlockCached, NormConvOutCached, NoCacheWrapper, + VAE_BLOCK_ORDER, make_decoder_xla_compatible, +) + +SCRIPT_DIR = Path(__file__).parent +MODEL_DIR = Path(os.environ.get("WAN_MODEL_DIR", "/mnt/work/wan/agent_artifacts/data")) +CC_TRANSFORMER = "--model-type=transformer -O1 --auto-cast=none --internal-hlo2tensorizer-options='--verify-hlo=false'" +CC_VAE = "--model-type=unet-inference -O1 --auto-cast=none --internal-hlo2tensorizer-options='--verify-hlo=false'" + + +def trace_backbone(): + from diffusers.models.transformers.transformer_wan import WanTransformer3DModel + + print("[Backbone] Loading weights...") + model = WanTransformer3DModel.from_pretrained(str(MODEL_DIR / "transformer"), torch_dtype=torch.bfloat16) + model.eval() + + wrapper = WanBackboneWrapper(model) + h = torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16) + t = torch.tensor([999], dtype=torch.int64) + e = torch.randn(1, 512, 4096, dtype=torch.bfloat16) + rc, rs = model.rope(h) + + print("[Backbone] Tracing (this takes ~3 min)...") + t0 = time.time() + traced = torch_neuronx.trace(wrapper, (h, t, e, rc, rs), compiler_args=CC_TRANSFORMER) + print(f"[Backbone] Compiled in {time.time()-t0:.0f}s") + + out = SCRIPT_DIR / "traced_wan_480x832.pt" + torch.jit.save(traced, str(out)) + print(f"[Backbone] Saved: {out}") + + +def trace_t5(): + from transformers import T5EncoderModel + + print("[T5] Loading weights...") + t5 = T5EncoderModel.from_pretrained(str(MODEL_DIR / "text_encoder"), torch_dtype=torch.bfloat16) + t5.eval() + + wrapper = T5Wrapper(t5) + ids = torch.ones(1, 512, dtype=torch.int64) + mask = torch.cat([torch.ones(1, 9, dtype=torch.int64), torch.zeros(1, 503, dtype=torch.int64)], dim=1) + + print("[T5] Tracing (LNC=1, ~2 min)...") + t0 = time.time() + traced = torch_neuronx.trace(wrapper, (ids, mask), + compiler_args=CC_TRANSFORMER + " --logical-nc-config=1") + print(f"[T5] Compiled in {time.time()-t0:.0f}s") + + out = SCRIPT_DIR / "traced_t5_lnc1.pt" + torch.jit.save(traced, str(out)) + print(f"[T5] Saved: {out}") + + +def trace_vae(): + from diffusers import AutoencoderKLWan + + print("[VAE] Loading weights...") + vae = AutoencoderKLWan.from_pretrained(str(MODEL_DIR / "vae"), torch_dtype=torch.bfloat16) + vae.eval() + d = vae.decoder + + # Make decoder XLA-compatible (monkey-patches forward methods, no source changes) + make_decoder_xla_compatible(d) + + # CPU warmup to populate cache + z = torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16) + lm = torch.tensor(vae.config.latents_mean).view(1, 16, 1, 1, 1).to(torch.bfloat16) + ls = 1.0 / torch.tensor(vae.config.latents_std).view(1, 16, 1, 1, 1).to(torch.bfloat16) + x = vae.post_quant_conv(z / ls + lm) + + vae.clear_cache() + for i in range(2): + vae._conv_idx = [0] + kw = dict(first_chunk=True) if i == 0 else {} + vae.decoder(x[:, :, i:i+1, :, :], feat_cache=vae._feat_map, feat_idx=vae._conv_idx, **kw) + + cache = [v.clone() if isinstance(v, torch.Tensor) else v for v in vae._feat_map] + + # Get per-block input shapes by running frame 2 on CPU + h = x[:, :, 2:3, :, :] + block_inputs = {} + idx = [0] + + with torch.no_grad(): + # conv_in + block_inputs["conv_in"] = h.clone() + h = d.conv_in(h, cache_x=cache[0]) + cache[0] = torch.cat([cache[0], block_inputs["conv_in"]], dim=2)[:, :, -2:, :, :] + idx[0] = 1 + + # mid_block + block_inputs["mid_block"] = h.clone() + h = d.mid_block(h, feat_cache=cache, feat_idx=idx) + + # up_blocks + for i in range(4): + ub = d.up_blocks[i] + for j in range(3): + block_inputs[f"up{i}_resnet{j}"] = h.clone() + h = ub.resnets[j](h, feat_cache=cache, feat_idx=idx) + if hasattr(ub, "upsamplers") and ub.upsamplers and len(ub.upsamplers) > 0: + block_inputs[f"up{i}_upsample"] = h.clone() + h = ub.upsamplers[0](h, feat_cache=cache, feat_idx=idx) + + # norm_conv_out + block_inputs["norm_conv_out"] = h.clone() + + # Re-populate cache for tracing + vae.clear_cache() + for i in range(2): + vae._conv_idx = [0] + kw = dict(first_chunk=True) if i == 0 else {} + vae.decoder(x[:, :, i:i+1, :, :], feat_cache=vae._feat_map, feat_idx=vae._conv_idx, **kw) + cache = [v.clone() if isinstance(v, torch.Tensor) else v for v in vae._feat_map] + + # Trace each block + out_dir = SCRIPT_DIR / "vae_blocks_cached" + out_dir.mkdir(exist_ok=True) + + def get_module(name): + if name == "conv_in": return d.conv_in + if name == "mid_block": return d.mid_block + if name == "norm_conv_out": return None + parts = name.split("_") + i = int(parts[0][2:]) + if "resnet" in name: + return d.up_blocks[i].resnets[int(parts[1][-1])] + return d.up_blocks[i].upsamplers[0] + + total = len(VAE_BLOCK_ORDER) + for bi, (name, cs, cc) in enumerate(VAE_BLOCK_ORDER): + pct = int((bi / total) * 100) + h_in = block_inputs[name] + + if name == "conv_in": + wrapper = ConvInCached(d.conv_in) + c_in = (cache[cs],) + elif name == "norm_conv_out": + wrapper = NormConvOutCached(d) + c_in = (cache[cs],) + elif cc == 0: + wrapper = NoCacheWrapper(get_module(name)) + c_in = () + else: + wrapper = BlockCached(get_module(name), cs, cc) + c_in = tuple(cache[cs + i] for i in range(cc)) + + with torch.no_grad(): + cpu_result = wrapper(h_in, *c_in) + + cpu_out = cpu_result[0] if isinstance(cpu_result, tuple) and cc > 0 else cpu_result + print(f"[{pct:3d}%] {name}: {list(h_in.shape)} -> {list(cpu_out.shape)}...", end=" ", flush=True) + + traced = torch_neuronx.trace(wrapper, (h_in, *c_in), compiler_args=CC_VAE) + + import torch.nn.functional as F + n_result = traced(h_in, *c_in) + n_out = n_result[0] if isinstance(n_result, tuple) and cc > 0 else n_result + cos = F.cosine_similarity(cpu_out.flatten().float(), n_out.flatten().float(), dim=0) + print(f"cos={cos.item():.4f}") + + torch.jit.save(traced, str(out_dir / f"{name}.pt")) + + # Update cache for next block + if cc > 0 and isinstance(cpu_result, tuple): + for i in range(cc): + cache[cs + i] = cpu_result[1 + i] + + print(f"[100%] All {total} VAE blocks traced -> {out_dir}/") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--component", choices=["backbone", "t5", "vae", "all"], required=True) + args = parser.parse_args() + + if args.component == "all": + env = os.environ.copy() + cores = env.get("NEURON_RT_VISIBLE_CORES", "0-3") + + # T5 in subprocess (LNC=1) + env["NEURON_RT_VISIBLE_CORES"] = cores + env["NEURON_RT_VIRTUAL_CORE_SIZE"] = "1" + subprocess.run([sys.executable, __file__, "--component=t5"], env=env, check=True) + + # Backbone + VAE (LNC=2) + env["NEURON_RT_VISIBLE_CORES"] = cores + env.pop("NEURON_RT_VIRTUAL_CORE_SIZE", None) + subprocess.run([sys.executable, __file__, "--component=backbone"], env=env, check=True) + subprocess.run([sys.executable, __file__, "--component=vae"], env=env, check=True) + elif args.component == "backbone": + trace_backbone() + elif args.component == "t5": + trace_t5() + elif args.component == "vae": + trace_vae() + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Wan2.1-T2V-1.3B/src/scripts/validate_backbone_cp8.py b/contrib/models/Wan2.1-T2V-1.3B/src/scripts/validate_backbone_cp8.py new file mode 100644 index 00000000..86ad9b25 --- /dev/null +++ b/contrib/models/Wan2.1-T2V-1.3B/src/scripts/validate_backbone_cp8.py @@ -0,0 +1,97 @@ +"""Validate CP=8 backbone: cosine vs CPU, then full denoising loop.""" +import torch, torch.nn.functional as F, os, sys, time, numpy as np +from PIL import Image + +sys.path.insert(0, "/mnt/work/wan/all_on_neuron_checkpt/contrib/models/wan2.1-t2v-1.3b/src") + +import neuronx_distributed +import torch_neuronx +from neuronx_distributed.parallel_layers.parallel_state import ( + initialize_model_parallel, get_tensor_model_parallel_rank, get_tensor_model_parallel_size, +) + +if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="xla") +initialize_model_parallel(tensor_model_parallel_size=8) +rank = get_tensor_model_parallel_rank() + +# Load NEFFs +sd = "/mnt/work/wan/agent_artifacts/data/traced_wan_cp8_49f" +tr1 = torch.jit.load(f"{sd}/neff1_rank{rank}.pt") +tr2 = torch.jit.load(f"{sd}/neff2_rank{rank}.pt") +print(f"[Rank {rank}] NEFFs loaded") + +# Inputs +from diffusers.models.transformers.transformer_wan import WanTransformer3DModel +hf = WanTransformer3DModel.from_pretrained('/mnt/work/wan/agent_artifacts/data/transformer', torch_dtype=torch.bfloat16) + +latent = torch.randn(1, 16, 13, 60, 104, dtype=torch.bfloat16) +t = torch.tensor([999], dtype=torch.int64) +rc, rs = hf.rope(latent) + +# Load T5 embeddings +embeds = torch.load('/tmp/t5_embeds.pt') +pe = embeds['pe'] + +# CPU reference (single step) +if rank == 0: + from modeling_wan import WanBackboneWrapper + wrapper = WanBackboneWrapper(hf) + with torch.no_grad(): + cpu_out = wrapper(latent, t, pe, rc, rs) + print(f"[Rank 0] CPU output: {cpu_out.shape}") + +# Neuron forward (2 NEFFs) +r1 = tr1(latent, t, pe, rc, rs) +final_flat = tr2(*r1) + +# Unpatchify +b = 1; ppf, pph, ppw = 13, 30, 52 +p_t, p_h, p_w = 1, 2, 2 +out = final_flat.reshape(b, ppf, pph, ppw, p_t, p_h, p_w, -1).permute(0, 7, 1, 4, 2, 5, 3, 6) +neuron_out = out.flatten(6, 7).flatten(4, 5).flatten(2, 3) + +if rank == 0: + cos = F.cosine_similarity(cpu_out.flatten().float(), neuron_out.flatten().float(), dim=0) + print(f"[Rank 0] Cosine vs CPU: {cos.item():.6f}") + +# Full denoising loop +if rank == 0: + print(f"[Rank 0] Starting 20-step denoising...") + +from diffusers import UniPCMultistepScheduler +scheduler = UniPCMultistepScheduler.from_pretrained('/mnt/work/wan/agent_artifacts/data/scheduler') + +ne = embeds['ne'] +gen = torch.Generator("cpu").manual_seed(42) +latents = torch.randn(1, 16, 13, 60, 104, dtype=torch.float32, generator=gen) +scheduler.set_timesteps(20, device="cpu") + +t0 = time.time() +for si, tt in enumerate(scheduler.timesteps): + li = latents.to(torch.bfloat16) + ts = tt.expand(1) + + # CFG: cond + r1c = tr1(li, ts, pe, rc, rs) + fc = tr2(*r1c) + cond = fc.reshape(1, 13, 30, 52, 1, 2, 2, -1).permute(0, 7, 1, 4, 2, 5, 3, 6).flatten(6,7).flatten(4,5).flatten(2,3) + + # CFG: uncond + r1u = tr1(li, ts, ne, rc, rs) + fu = tr2(*r1u) + uncond = fu.reshape(1, 13, 30, 52, 1, 2, 2, -1).permute(0, 7, 1, 4, 2, 5, 3, 6).flatten(6,7).flatten(4,5).flatten(2,3) + + noise_pred = uncond + 5.0 * (cond - uncond) + latents = scheduler.step(noise_pred, tt, latents, return_dict=False)[0] + + if rank == 0 and si % 5 == 0: + print(f" step {si}/20") + +bb_time = time.time() - t0 +if rank == 0: + print(f"[Rank 0] Backbone: {bb_time:.1f}s") + torch.save(latents, "/tmp/latents_49f.pt") + print(f"[Rank 0] Latents saved to /tmp/latents_49f.pt") + +del hf diff --git a/contrib/models/Wan2.1-T2V-1.3B/test/__init__.py b/contrib/models/Wan2.1-T2V-1.3B/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Wan2.1-T2V-1.3B/test/integration/__init__.py b/contrib/models/Wan2.1-T2V-1.3B/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Wan2.1-T2V-1.3B/test/integration/test_model.py b/contrib/models/Wan2.1-T2V-1.3B/test/integration/test_model.py new file mode 100644 index 00000000..3a29b656 --- /dev/null +++ b/contrib/models/Wan2.1-T2V-1.3B/test/integration/test_model.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +""" +Integration tests for WAN 2.1 T2V 1.3B NeuronX implementation. +""" + +import pytest +import torch +import torch_neuronx # must import before torch.jit.load for Neuron models +import torch.nn.functional as F +import json +import os +import time +from pathlib import Path + +MODEL_PATH = os.environ.get("WAN_MODEL_DIR", "/home/ubuntu/models/wan2.1-t2v-1.3b/") +COMPILED_PATH = os.environ.get("WAN_COMPILED_DIR", "/home/ubuntu/neuron_models/wan2.1-t2v-1.3b/") + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_wan import ( + WanBackboneWrapper, T5Wrapper, + ConvInCached, BlockCached, NormConvOutCached, NoCacheWrapper, + VAE_BLOCK_ORDER, make_decoder_xla_compatible, +) + + +@pytest.fixture(scope="module") +def vae(): + from diffusers import AutoencoderKLWan + v = AutoencoderKLWan.from_pretrained(f"{MODEL_PATH}/vae", torch_dtype=torch.bfloat16) + v.eval() + make_decoder_xla_compatible(v.decoder) + return v + + +@pytest.fixture(scope="module") +def traced_backbone(): + p = Path(COMPILED_PATH) / "traced_wan_480x832.pt" + if not p.exists(): + pytest.skip(f"Traced backbone not found: {p}") + return torch.jit.load(str(p)) + + +@pytest.fixture(scope="module") +def traced_vae_blocks(): + d = Path(COMPILED_PATH) / "vae_blocks_cached" + if not d.exists(): + pytest.skip(f"VAE blocks not found: {d}") + return {f.stem: torch.jit.load(str(f)) for f in sorted(d.glob("*.pt"))} + + +@pytest.fixture(scope="module") +def rope(): + from diffusers.models.transformers.transformer_wan import WanTransformer3DModel + m = WanTransformer3DModel.from_pretrained(f"{MODEL_PATH}/transformer", torch_dtype=torch.bfloat16) + rc, rs = m.rope(torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16)) + del m + return rc, rs + + +def test_backbone_loads(traced_backbone): + """Traced backbone loads successfully.""" + assert traced_backbone is not None + + +def test_backbone_output_shape(traced_backbone, rope): + """Backbone produces correct output shape.""" + h = torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16) + t = torch.tensor([999], dtype=torch.int64) + e = torch.randn(1, 512, 4096, dtype=torch.bfloat16) + rc, rs = rope + out = traced_backbone(h, t, e, rc, rs) + assert out.shape == (1, 16, 4, 60, 104) + + +def test_backbone_cosine_vs_cpu(traced_backbone, rope): + """Backbone output matches CPU reference (cosine > 0.999).""" + from diffusers.models.transformers.transformer_wan import WanTransformer3DModel + + model = WanTransformer3DModel.from_pretrained(f"{MODEL_PATH}/transformer", torch_dtype=torch.bfloat16) + model.eval() + wrapper = WanBackboneWrapper(model) + + h = torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16) + t = torch.tensor([999], dtype=torch.int64) + e = torch.randn(1, 512, 4096, dtype=torch.bfloat16) + rc, rs = rope + + with torch.no_grad(): + cpu_out = wrapper(h, t, e, rc, rs) + neuron_out = traced_backbone(h, t, e, rc, rs) + + cos = F.cosine_similarity(cpu_out.flatten().float(), neuron_out.flatten().float(), dim=0) + assert cos.item() > 0.999, f"Backbone cosine {cos.item():.6f} < 0.999" + + +def test_vae_blocks_load(traced_vae_blocks): + """All 18 VAE blocks load.""" + assert len(traced_vae_blocks) == 18 + + +def test_vae_hybrid_decode_cosine(vae, traced_vae_blocks): + """Hybrid VAE decode matches CPU reference (cosine > 0.999).""" + z = torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16) + lm = torch.tensor(vae.config.latents_mean).view(1, 16, 1, 1, 1).to(torch.bfloat16) + ls = 1.0 / torch.tensor(vae.config.latents_std).view(1, 16, 1, 1, 1).to(torch.bfloat16) + x = vae.post_quant_conv(z / ls + lm) + + # CPU reference + vae.clear_cache() + cpu_frames = [] + for i in range(4): + vae._conv_idx = [0] + kw = dict(first_chunk=True) if i == 0 else {} + with torch.no_grad(): + cpu_frames.append(vae.decoder(x[:, :, i:i+1, :, :], + feat_cache=vae._feat_map, feat_idx=vae._conv_idx, **kw)) + cpu_video = torch.cat(cpu_frames, dim=2) + + # Hybrid: CPU 0+1, Neuron 2+3 + vae.clear_cache() + frames = [] + for i in range(2): + vae._conv_idx = [0] + kw = dict(first_chunk=True) if i == 0 else {} + with torch.no_grad(): + frames.append(vae.decoder(x[:, :, i:i+1, :, :], + feat_cache=vae._feat_map, feat_idx=vae._conv_idx, **kw)) + + cache = [v.clone() if isinstance(v, torch.Tensor) else v for v in vae._feat_map] + for fi in range(2, 4): + h = x[:, :, fi:fi+1, :, :] + for name, cs, cc in VAE_BLOCK_ORDER: + c_in = tuple(cache[cs + i] for i in range(cc)) + result = traced_vae_blocks[name](h, *c_in) + if cc > 0: + h = result[0] + for i in range(cc): + cache[cs + i] = result[1 + i] + else: + h = result + frames.append(h) + + neuron_video = torch.cat(frames, dim=2) + cos = F.cosine_similarity(cpu_video.flatten().float(), neuron_video.flatten().float(), dim=0) + assert cos.item() > 0.999, f"VAE cosine {cos.item():.6f} < 0.999" + + +def test_vae_output_frame_count(vae, traced_vae_blocks): + """Hybrid decode produces 13 output frames from 4 latent frames.""" + z = torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16) + lm = torch.tensor(vae.config.latents_mean).view(1, 16, 1, 1, 1).to(torch.bfloat16) + ls = 1.0 / torch.tensor(vae.config.latents_std).view(1, 16, 1, 1, 1).to(torch.bfloat16) + x = vae.post_quant_conv(z / ls + lm) + + vae.clear_cache() + frames = [] + for i in range(2): + vae._conv_idx = [0] + kw = dict(first_chunk=True) if i == 0 else {} + with torch.no_grad(): + frames.append(vae.decoder(x[:, :, i:i+1, :, :], + feat_cache=vae._feat_map, feat_idx=vae._conv_idx, **kw)) + + cache = [v.clone() if isinstance(v, torch.Tensor) else v for v in vae._feat_map] + for fi in range(2, 4): + h = x[:, :, fi:fi+1, :, :] + for name, cs, cc in VAE_BLOCK_ORDER: + c_in = tuple(cache[cs + i] for i in range(cc)) + result = traced_vae_blocks[name](h, *c_in) + if cc > 0: + h = result[0] + for i in range(cc): + cache[cs + i] = result[1 + i] + else: + h = result + frames.append(h) + + video = torch.cat(frames, dim=2) + assert video.shape == (1, 3, 13, 480, 832), f"Expected (1,3,13,480,832), got {video.shape}" + + +if __name__ == "__main__": + print("Run with: pytest", __file__, "--capture=tee-sys") diff --git a/contrib/models/Wan2.1-T2V-1.3B/test/unit/__init__.py b/contrib/models/Wan2.1-T2V-1.3B/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b