From 19ce3048c1dd7e7a64db3d0d6908f08f2cf9c70a Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Fri, 9 Jan 2026 18:06:41 +0800 Subject: [PATCH 1/5] [model][NPU]:Wan model rope use torch.complex64 in NPU --- diffsynth/models/wan_video_dit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index daafa7a68..43cd601e6 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -5,6 +5,8 @@ from typing import Tuple, Optional from einops import rearrange from .wan_video_camera_controller import SimpleAdapter +from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE + try: import flash_attn_interface FLASH_ATTN_3_AVAILABLE = True @@ -92,6 +94,7 @@ def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) + freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) From 3b662da31e49e2cf9196d8608f7ba0c6c71875ec Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Fri, 9 Jan 2026 18:11:40 +0800 Subject: [PATCH 2/5] [model][NPU]:Wan model rope use torch.complex64 in NPU --- diffsynth/utils/xfuser/xdit_context_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index b7fa72d92..d365cfe3b 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -5,7 +5,7 @@ get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention -from ...core.device import parse_nccl_backend, parse_device_type +from ...core.device import parse_nccl_backend, parse_device_type, IS_NPU_AVAILABLE def initialize_usp(device_type): @@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads): sp_rank = get_sequence_parallel_rank() freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] - + freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) From 544c391936b6b9c301b99b070996f97a57217871 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Mon, 12 Jan 2026 11:24:11 +0800 Subject: [PATCH 3/5] [model][NPU]:Wan model rope use torch.complex64 in NPU --- docs/en/Pipeline_Usage/GPU_support.md | 2 +- docs/zh/Pipeline_Usage/GPU_support.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/en/Pipeline_Usage/GPU_support.md b/docs/en/Pipeline_Usage/GPU_support.md index 6c27de778..aba570649 100644 --- a/docs/en/Pipeline_Usage/GPU_support.md +++ b/docs/en/Pipeline_Usage/GPU_support.md @@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5) ``` ### Training -NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_scripts`, for example `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`. +NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`. In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models. diff --git a/docs/zh/Pipeline_Usage/GPU_support.md b/docs/zh/Pipeline_Usage/GPU_support.md index b955f5600..8124147e2 100644 --- a/docs/zh/Pipeline_Usage/GPU_support.md +++ b/docs/zh/Pipeline_Usage/GPU_support.md @@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5) ``` ### 训练 -当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_scripts`目录下,例如 `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`。 +当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。 在NPU训练脚本中,添加了可以优化性能的NPU特有环境变量,并针对特定模型开启了相关参数。 From 6be244233a5706d0cf7e0fc8f019566f8f0dca8f Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Mon, 12 Jan 2026 11:34:41 +0800 Subject: [PATCH 4/5] [model][NPU]:Wan model rope use torch.complex64 in NPU --- diffsynth/core/device/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/core/device/__init__.py b/diffsynth/core/device/__init__.py index 8373471cf..889d6823a 100644 --- a/diffsynth/core/device/__init__.py +++ b/diffsynth/core/device/__init__.py @@ -1,2 +1,2 @@ from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name -from .npu_compatible_device import IS_NPU_AVAILABLE +from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE From d16877e69548523f2ea23c4fff530bdd81b31cfa Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Tue, 13 Jan 2026 11:17:51 +0800 Subject: [PATCH 5/5] [model][NPU]:Wan model rope use torch.complex64 in NPU --- diffsynth/models/wan_video_dit.py | 3 +-- diffsynth/utils/xfuser/xdit_context_parallel.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 43cd601e6..738622302 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -5,7 +5,6 @@ from typing import Tuple, Optional from einops import rearrange from .wan_video_camera_controller import SimpleAdapter -from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE try: import flash_attn_interface @@ -94,7 +93,7 @@ def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) - freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs + freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index d365cfe3b..21dc3b33c 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -5,7 +5,7 @@ get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention -from ...core.device import parse_nccl_backend, parse_device_type, IS_NPU_AVAILABLE +from ...core.device import parse_nccl_backend, parse_device_type def initialize_usp(device_type): @@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads): sp_rank = get_sequence_parallel_rank() freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] - freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank + freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype)