From 19ce3048c1dd7e7a64db3d0d6908f08f2cf9c70a Mon Sep 17 00:00:00 2001
From: feng0w0
Date: Fri, 9 Jan 2026 18:06:41 +0800
Subject: [PATCH 1/7] [model][NPU]:Wan model rope use torch.complex64 in NPU

---
 diffsynth/models/wan_video_dit.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index daafa7a68..43cd601e6 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -5,6 +5,8 @@ from typing import Tuple, Optional
 from einops import rearrange
 from .wan_video_camera_controller import SimpleAdapter
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
+
 try:
     import flash_attn_interface
     FLASH_ATTN_3_AVAILABLE = True
@@ -92,6 +94,7 @@ def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
+    freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
     x_out = torch.view_as_real(x_out * freqs).flatten(2)
     return x_out.to(x.dtype)

From 3b662da31e49e2cf9196d8608f7ba0c6c71875ec Mon Sep 17 00:00:00 2001
From: feng0w0
Date: Fri, 9 Jan 2026 18:11:40 +0800
Subject: [PATCH 2/7] [model][NPU]:Wan model rope use torch.complex64 in NPU

---
 diffsynth/utils/xfuser/xdit_context_parallel.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py
index b7fa72d92..d365cfe3b 100644
--- a/diffsynth/utils/xfuser/xdit_context_parallel.py
+++ b/diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -5,7 +5,7 @@ get_sequence_parallel_world_size, get_sp_group)
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 
-from ...core.device import parse_nccl_backend, parse_device_type
+from ...core.device import parse_nccl_backend, parse_device_type, IS_NPU_AVAILABLE
 
 
 def initialize_usp(device_type):
@@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
     sp_rank = get_sequence_parallel_rank()
     freqs = pad_freqs(freqs, s_per_rank * sp_size)
     freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
-
+    freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank
     x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
     return x_out.to(x.dtype)

From a3c2744a4370d5159000eef56b9460d89e68776a Mon Sep 17 00:00:00 2001
From: feng0w0
Date: Thu, 15 Jan 2026 20:04:54 +0800
Subject: [PATCH 3/7] [NPU]:Replace 'cuda' in the project with abstract interfaces

---
 diffsynth/core/__init__.py | 1 +
 diffsynth/core/npu_patch/__init__.py | 5 +++++
 .../core/npu_patch/npu_autocast_patch.py | 21 +++++++++++++++++++
 diffsynth/diffusion/base_pipeline.py | 3 ++-
 diffsynth/models/dinov3_image_encoder.py | 4 +++-
 diffsynth/models/longcat_video_dit.py | 21 ++++++++++++++-----
 diffsynth/models/nexus_gen_ar_model.py | 2 +-
 diffsynth/models/siglip2_image_encoder.py | 4 +++-
 diffsynth/models/step1x_text_encoder.py | 19 +++++++++--------
 diffsynth/models/z_image_dit.py | 6 +++---
 diffsynth/pipelines/flux2_image.py | 5 +++--
 diffsynth/pipelines/flux_image.py | 11 +++++-----
 diffsynth/pipelines/qwen_image.py | 5 +++--
 diffsynth/pipelines/wan_video.py | 7 ++++---
 diffsynth/pipelines/z_image.py | 5 +++--
 diffsynth/utils/controlnet/annotator.py | 3 ++-
 16 files changed, 86 insertions(+), 36 deletions(-)
 create mode 100644 diffsynth/core/npu_patch/__init__.py
 create mode 100644 diffsynth/core/npu_patch/npu_autocast_patch.py

diff --git
a/diffsynth/core/__init__.py b/diffsynth/core/__init__.py index 6c0a6c877..4d5f440d3 100644 --- a/diffsynth/core/__init__.py +++ b/diffsynth/core/__init__.py @@ -4,3 +4,4 @@ from .loader import * from .vram import * from .device import * +from .npu_patch import * diff --git a/diffsynth/core/npu_patch/__init__.py b/diffsynth/core/npu_patch/__init__.py new file mode 100644 index 000000000..eb1df930b --- /dev/null +++ b/diffsynth/core/npu_patch/__init__.py @@ -0,0 +1,5 @@ +from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE +from .npu_autocast_patch import npu_autocast_patch + +if IS_NPU_AVAILABLE: + npu_autocast_patch() diff --git a/diffsynth/core/npu_patch/npu_autocast_patch.py b/diffsynth/core/npu_patch/npu_autocast_patch.py new file mode 100644 index 000000000..08b1caffa --- /dev/null +++ b/diffsynth/core/npu_patch/npu_autocast_patch.py @@ -0,0 +1,21 @@ +import torch +from contextlib import contextmanager + + +def npu_autocast_patch_wrapper(func): + @contextmanager + def wrapper(*args, **kwargs): + flag = False + if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"): + if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32): + flag = True + with func(*args, **kwargs) as ctx: + if flag: + torch.npu.set_autocast_enabled(True) + yield ctx + return wrapper + + +def npu_autocast_patch(): + torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast) + torch.autocast = npu_autocast_patch_wrapper(torch.autocast) diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index 4fe155963..d4731fd18 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -4,6 +4,7 @@ from einops import repeat, reduce from typing import Union from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type +from ..core.device.npu_compatible_device import get_device_type from ..utils.lora import GeneralLoRALoader from ..models.model_loader import ModelPool from ..utils.controlnet import ControlNetInput @@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module): def __init__( self, - device="cuda", torch_dtype=torch.float16, + device=get_device_type(), torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64, time_division_factor=None, time_division_remainder=None, ): diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py index be2ee5876..c394a0315 100644 --- a/diffsynth/models/dinov3_image_encoder.py +++ b/diffsynth/models/dinov3_image_encoder.py @@ -2,6 +2,8 @@ from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig import torch +from ..core.device.npu_compatible_device import get_device_type + class DINOv3ImageEncoder(DINOv3ViTModel): def __init__(self): @@ -70,7 +72,7 @@ def __init__(self): } ) - def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()): inputs = self.processor(images=image, return_tensors="pt") pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device) bool_masked_pos = None diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py index 6d6572387..ebcc9d094 100644 --- a/diffsynth/models/longcat_video_dit.py +++ b/diffsynth/models/longcat_video_dit.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from .wan_video_dit import flash_attention +from 
..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type from ..core.gradient import gradient_checkpoint_forward @@ -373,7 +374,9 @@ def forward(self, x, t, latent_shape): B, N, C = x.shape T, _, _ = latent_shape - with amp.autocast('cuda', dtype=torch.float32): + with amp.autocast(get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C] x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) x = self.linear(x) @@ -583,7 +586,9 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W. # compute modulation params in fp32 - with amp.autocast(device_type='cuda', dtype=torch.float32): + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) shift_msa, scale_msa, gate_msa, \ shift_mlp, scale_mlp, gate_mlp = \ self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] @@ -602,7 +607,9 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return else: x_s = attn_outputs - with amp.autocast(device_type='cuda', dtype=torch.float32): + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x.to(x_dtype) @@ -615,7 +622,9 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return # ffn with modulation x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) x_s = self.ffn(x_m) - with amp.autocast(device_type='cuda', dtype=torch.float32): + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x.to(x_dtype) @@ -797,7 +806,9 @@ def forward( hidden_states = self.x_embedder(hidden_states) # [B, N, C] - with amp.autocast(device_type='cuda', dtype=torch.float32): + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t] encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C] diff --git a/diffsynth/models/nexus_gen_ar_model.py b/diffsynth/models/nexus_gen_ar_model.py index d5a29735e..b647786aa 100644 --- a/diffsynth/models/nexus_gen_ar_model.py +++ b/diffsynth/models/nexus_gen_ar_model.py @@ -583,7 +583,7 @@ def _sample( is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache is_compileable = is_compileable and not self.generation_config.disable_compile if is_compileable and ( - self.device.type == "cuda" or generation_config.compile_config._compile_all_devices + self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices ): os.environ["TOKENIZERS_PARALLELISM"] = "0" model_forward = self.get_compiled_call(generation_config.compile_config) diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py index 87df85561..509eff409 100644 --- a/diffsynth/models/siglip2_image_encoder.py +++ b/diffsynth/models/siglip2_image_encoder.py @@ -2,6 +2,8 @@ from 
transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast import torch +from diffsynth.core.device.npu_compatible_device import get_device_type + class Siglip2ImageEncoder(SiglipVisionTransformer): def __init__(self): @@ -47,7 +49,7 @@ def __init__(self): } ) - def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()): pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"] pixel_values = pixel_values.to(device=device, dtype=torch_dtype) output_attentions = False diff --git a/diffsynth/models/step1x_text_encoder.py b/diffsynth/models/step1x_text_encoder.py index d0fe22157..5d144236a 100644 --- a/diffsynth/models/step1x_text_encoder.py +++ b/diffsynth/models/step1x_text_encoder.py @@ -1,10 +1,11 @@ import torch from typing import Optional, Union from .qwen_image_text_encoder import QwenImageTextEncoder +from ..core.device.npu_compatible_device import get_device_type, get_torch_device class Step1xEditEmbedder(torch.nn.Module): - def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"): + def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()): super().__init__() self.max_length = max_length self.dtype = dtype @@ -77,13 +78,13 @@ def forward(self, caption, ref_images): self.max_length, self.model.config.hidden_size, dtype=torch.bfloat16, - device=torch.cuda.current_device(), + device=get_torch_device().current_device(), ) masks = torch.zeros( len(text_list), self.max_length, dtype=torch.long, - device=torch.cuda.current_device(), + device=get_torch_device().current_device(), ) def split_string(s): @@ -158,7 +159,7 @@ def split_string(s): else: token_list.append(token_each) - new_txt_ids = torch.cat(token_list, dim=1).to("cuda") + new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type()) new_txt_ids = new_txt_ids.to(old_inputs_ids.device) @@ -167,15 +168,15 @@ def split_string(s): inputs.input_ids = ( torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) .unsqueeze(0) - .to("cuda") + .to(get_device_type()) ) - inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda") + inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type()) outputs = self.model_forward( self.model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, - pixel_values=inputs.pixel_values.to("cuda"), - image_grid_thw=inputs.image_grid_thw.to("cuda"), + pixel_values=inputs.pixel_values.to(get_device_type()), + image_grid_thw=inputs.image_grid_thw.to(get_device_type()), output_hidden_states=True, ) @@ -188,7 +189,7 @@ def split_string(s): masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( (min(self.max_length, emb.shape[1] - 217)), dtype=torch.long, - device=torch.cuda.current_device(), + device=get_torch_device().current_device(), ) return embs, masks diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index f157f38d2..bb4906751 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -8,7 +8,7 @@ from torch.nn import RMSNorm from ..core.attention import attention_forward -from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE +from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type from ..core.gradient import gradient_checkpoint_forward @@ -40,7 +40,7 @@ def __init__(self, 
out_size, mid_size=None, frequency_embedding_size=256): @staticmethod def timestep_embedding(t, dim, max_period=10000): - with torch.amp.autocast("cuda", enabled=False): + with torch.amp.autocast(get_device_type(), enabled=False): half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half @@ -105,7 +105,7 @@ def forward(self, hidden_states, freqs_cis, attention_mask): # Apply RoPE def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - with torch.amp.autocast("cuda", enabled=False): + with torch.amp.autocast(get_device_type(), enabled=False): x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x * freqs_cis).flatten(3) diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 8b0046949..5ecbb20a1 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -6,6 +6,7 @@ import numpy as np from typing import Union, List, Optional, Tuple +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput @@ -18,7 +19,7 @@ class Flux2ImagePipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, @@ -42,7 +43,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), vram_limit: float = None, diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 1ee5635ee..bfc53e505 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -6,6 +6,7 @@ import numpy as np from transformers import CLIPTokenizer, T5TokenizerFast +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput @@ -55,7 +56,7 @@ def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[Con class FluxImagePipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, @@ -117,7 +118,7 @@ def enable_lora_merger(self): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"), tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", 
origin_file_pattern="tokenizer_2/"), @@ -377,7 +378,7 @@ def encode_prompt( text_encoder_2, prompt, positive=True, - device="cuda", + device=get_device_type(), t5_sequence_length=512, ): pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) @@ -558,7 +559,7 @@ def encode_prompt( text_encoder_2, prompt, positive=True, - device="cuda", + device=get_device_type(), t5_sequence_length=512, ): pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) @@ -793,7 +794,7 @@ def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controlle class InfinitYou(torch.nn.Module): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__() from facexlib.recognition import init_recognition_model from insightface.app import FaceAnalysis diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 4bfa00ef0..75cfbee77 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -6,6 +6,7 @@ import numpy as np from math import prod +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput @@ -22,7 +23,7 @@ class QwenImagePipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, @@ -60,7 +61,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), processor_config: ModelConfig = None, diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index ca59d2a00..866ac18e3 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -11,6 +11,7 @@ from typing_extensions import Literal from transformers import Wav2Vec2Processor +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit @@ -30,7 +31,7 @@ class WanVideoPipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 @@ -98,7 +99,7 @@ def enable_usp(self): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), audio_processor_config: ModelConfig = None, @@ -960,7 +961,7 @@ def __init__(self): 
onload_model_names=("vae",) ) - def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()): if mask_pixel_values is None: msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) else: diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index 9ba182ae5..2c5b68730 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -6,6 +6,7 @@ import numpy as np from typing import Union, List, Optional, Tuple, Iterable, Dict +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..core.data.operators import ImageCropAndResize @@ -25,7 +26,7 @@ class ZImagePipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, @@ -58,7 +59,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), vram_limit: float = None, diff --git a/diffsynth/utils/controlnet/annotator.py b/diffsynth/utils/controlnet/annotator.py index 06553e06d..cb737385f 100644 --- a/diffsynth/utils/controlnet/annotator.py +++ b/diffsynth/utils/controlnet/annotator.py @@ -1,12 +1,13 @@ from typing_extensions import Literal, TypeAlias +from diffsynth.core.device.npu_compatible_device import get_device_type Processor_id: TypeAlias = Literal[ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint" ] class Annotator: - def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False): + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False): if not skip_processor: if processor_id == "canny": from controlnet_aux.processor import CannyDetector From 209a350c0f6dcaccfcd5be3c3b73e28b8813bed0 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Thu, 15 Jan 2026 20:33:01 +0800 Subject: [PATCH 4/7] [NPU]:Replace 'cuda' in the project with abstract interfaces --- diffsynth/core/__init__.py | 1 + diffsynth/core/npu_patch/__init__.py | 5 +++++ .../core/npu_patch/npu_autocast_patch.py | 21 +++++++++++++++++++ diffsynth/diffusion/base_pipeline.py | 3 ++- diffsynth/models/dinov3_image_encoder.py | 4 +++- diffsynth/models/longcat_video_dit.py | 21 ++++++++++++++----- diffsynth/models/nexus_gen_ar_model.py | 2 +- diffsynth/models/siglip2_image_encoder.py | 4 +++- diffsynth/models/step1x_text_encoder.py | 19 +++++++++-------- diffsynth/models/wan_video_dit.py | 1 - diffsynth/models/z_image_dit.py | 6 +++--- diffsynth/pipelines/flux2_image.py | 5 +++-- diffsynth/pipelines/flux_image.py | 11 +++++----- diffsynth/pipelines/qwen_image.py | 5 +++-- diffsynth/pipelines/wan_video.py | 7 ++++--- diffsynth/pipelines/z_image.py | 5 +++-- diffsynth/utils/controlnet/annotator.py | 3 ++- 
.../utils/xfuser/xdit_context_parallel.py | 1 - 18 files changed, 86 insertions(+), 38 deletions(-) create mode 100644 diffsynth/core/npu_patch/__init__.py create mode 100644 diffsynth/core/npu_patch/npu_autocast_patch.py diff --git a/diffsynth/core/__init__.py b/diffsynth/core/__init__.py index 6c0a6c877..4d5f440d3 100644 --- a/diffsynth/core/__init__.py +++ b/diffsynth/core/__init__.py @@ -4,3 +4,4 @@ from .loader import * from .vram import * from .device import * +from .npu_patch import * diff --git a/diffsynth/core/npu_patch/__init__.py b/diffsynth/core/npu_patch/__init__.py new file mode 100644 index 000000000..eb1df930b --- /dev/null +++ b/diffsynth/core/npu_patch/__init__.py @@ -0,0 +1,5 @@ +from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE +from .npu_autocast_patch import npu_autocast_patch + +if IS_NPU_AVAILABLE: + npu_autocast_patch() diff --git a/diffsynth/core/npu_patch/npu_autocast_patch.py b/diffsynth/core/npu_patch/npu_autocast_patch.py new file mode 100644 index 000000000..08b1caffa --- /dev/null +++ b/diffsynth/core/npu_patch/npu_autocast_patch.py @@ -0,0 +1,21 @@ +import torch +from contextlib import contextmanager + + +def npu_autocast_patch_wrapper(func): + @contextmanager + def wrapper(*args, **kwargs): + flag = False + if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"): + if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32): + flag = True + with func(*args, **kwargs) as ctx: + if flag: + torch.npu.set_autocast_enabled(True) + yield ctx + return wrapper + + +def npu_autocast_patch(): + torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast) + torch.autocast = npu_autocast_patch_wrapper(torch.autocast) diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index 4fe155963..d4731fd18 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -4,6 +4,7 @@ from einops import repeat, reduce from typing import Union from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type +from ..core.device.npu_compatible_device import get_device_type from ..utils.lora import GeneralLoRALoader from ..models.model_loader import ModelPool from ..utils.controlnet import ControlNetInput @@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module): def __init__( self, - device="cuda", torch_dtype=torch.float16, + device=get_device_type(), torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64, time_division_factor=None, time_division_remainder=None, ): diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py index be2ee5876..c394a0315 100644 --- a/diffsynth/models/dinov3_image_encoder.py +++ b/diffsynth/models/dinov3_image_encoder.py @@ -2,6 +2,8 @@ from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig import torch +from ..core.device.npu_compatible_device import get_device_type + class DINOv3ImageEncoder(DINOv3ViTModel): def __init__(self): @@ -70,7 +72,7 @@ def __init__(self): } ) - def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()): inputs = self.processor(images=image, return_tensors="pt") pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device) bool_masked_pos = None diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py index 6d6572387..ebcc9d094 
100644 --- a/diffsynth/models/longcat_video_dit.py +++ b/diffsynth/models/longcat_video_dit.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from .wan_video_dit import flash_attention +from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type from ..core.gradient import gradient_checkpoint_forward @@ -373,7 +374,9 @@ def forward(self, x, t, latent_shape): B, N, C = x.shape T, _, _ = latent_shape - with amp.autocast('cuda', dtype=torch.float32): + with amp.autocast(get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C] x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) x = self.linear(x) @@ -583,7 +586,9 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W. # compute modulation params in fp32 - with amp.autocast(device_type='cuda', dtype=torch.float32): + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) shift_msa, scale_msa, gate_msa, \ shift_mlp, scale_mlp, gate_mlp = \ self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] @@ -602,7 +607,9 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return else: x_s = attn_outputs - with amp.autocast(device_type='cuda', dtype=torch.float32): + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x.to(x_dtype) @@ -615,7 +622,9 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return # ffn with modulation x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) x_s = self.ffn(x_m) - with amp.autocast(device_type='cuda', dtype=torch.float32): + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x.to(x_dtype) @@ -797,7 +806,9 @@ def forward( hidden_states = self.x_embedder(hidden_states) # [B, N, C] - with amp.autocast(device_type='cuda', dtype=torch.float32): + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + if IS_NPU_AVAILABLE: + torch.npu.set_autocast_enabled(True) t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t] encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C] diff --git a/diffsynth/models/nexus_gen_ar_model.py b/diffsynth/models/nexus_gen_ar_model.py index d5a29735e..b647786aa 100644 --- a/diffsynth/models/nexus_gen_ar_model.py +++ b/diffsynth/models/nexus_gen_ar_model.py @@ -583,7 +583,7 @@ def _sample( is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache is_compileable = is_compileable and not self.generation_config.disable_compile if is_compileable and ( - self.device.type == "cuda" or generation_config.compile_config._compile_all_devices + self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices ): os.environ["TOKENIZERS_PARALLELISM"] = "0" model_forward = self.get_compiled_call(generation_config.compile_config) diff --git 
a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py index 87df85561..509eff409 100644 --- a/diffsynth/models/siglip2_image_encoder.py +++ b/diffsynth/models/siglip2_image_encoder.py @@ -2,6 +2,8 @@ from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast import torch +from diffsynth.core.device.npu_compatible_device import get_device_type + class Siglip2ImageEncoder(SiglipVisionTransformer): def __init__(self): @@ -47,7 +49,7 @@ def __init__(self): } ) - def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()): pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"] pixel_values = pixel_values.to(device=device, dtype=torch_dtype) output_attentions = False diff --git a/diffsynth/models/step1x_text_encoder.py b/diffsynth/models/step1x_text_encoder.py index d0fe22157..5d144236a 100644 --- a/diffsynth/models/step1x_text_encoder.py +++ b/diffsynth/models/step1x_text_encoder.py @@ -1,10 +1,11 @@ import torch from typing import Optional, Union from .qwen_image_text_encoder import QwenImageTextEncoder +from ..core.device.npu_compatible_device import get_device_type, get_torch_device class Step1xEditEmbedder(torch.nn.Module): - def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"): + def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()): super().__init__() self.max_length = max_length self.dtype = dtype @@ -77,13 +78,13 @@ def forward(self, caption, ref_images): self.max_length, self.model.config.hidden_size, dtype=torch.bfloat16, - device=torch.cuda.current_device(), + device=get_torch_device().current_device(), ) masks = torch.zeros( len(text_list), self.max_length, dtype=torch.long, - device=torch.cuda.current_device(), + device=get_torch_device().current_device(), ) def split_string(s): @@ -158,7 +159,7 @@ def split_string(s): else: token_list.append(token_each) - new_txt_ids = torch.cat(token_list, dim=1).to("cuda") + new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type()) new_txt_ids = new_txt_ids.to(old_inputs_ids.device) @@ -167,15 +168,15 @@ def split_string(s): inputs.input_ids = ( torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) .unsqueeze(0) - .to("cuda") + .to(get_device_type()) ) - inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda") + inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type()) outputs = self.model_forward( self.model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, - pixel_values=inputs.pixel_values.to("cuda"), - image_grid_thw=inputs.image_grid_thw.to("cuda"), + pixel_values=inputs.pixel_values.to(get_device_type()), + image_grid_thw=inputs.image_grid_thw.to(get_device_type()), output_hidden_states=True, ) @@ -188,7 +189,7 @@ def split_string(s): masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( (min(self.max_length, emb.shape[1] - 217)), dtype=torch.long, - device=torch.cuda.current_device(), + device=get_torch_device().current_device(), ) return embs, masks diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 43cd601e6..cfee25869 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -94,7 +94,6 @@ def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s 
n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) - freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index f157f38d2..bb4906751 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -8,7 +8,7 @@ from torch.nn import RMSNorm from ..core.attention import attention_forward -from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE +from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type from ..core.gradient import gradient_checkpoint_forward @@ -40,7 +40,7 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): @staticmethod def timestep_embedding(t, dim, max_period=10000): - with torch.amp.autocast("cuda", enabled=False): + with torch.amp.autocast(get_device_type(), enabled=False): half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half @@ -105,7 +105,7 @@ def forward(self, hidden_states, freqs_cis, attention_mask): # Apply RoPE def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - with torch.amp.autocast("cuda", enabled=False): + with torch.amp.autocast(get_device_type(), enabled=False): x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x * freqs_cis).flatten(3) diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 8b0046949..5ecbb20a1 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -6,6 +6,7 @@ import numpy as np from typing import Union, List, Optional, Tuple +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput @@ -18,7 +19,7 @@ class Flux2ImagePipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, @@ -42,7 +43,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), vram_limit: float = None, diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 1ee5635ee..bfc53e505 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -6,6 +6,7 @@ import numpy as np from transformers import CLIPTokenizer, T5TokenizerFast +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput @@ -55,7 +56,7 @@ def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[Con class 
FluxImagePipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, @@ -117,7 +118,7 @@ def enable_lora_merger(self): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"), tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"), @@ -377,7 +378,7 @@ def encode_prompt( text_encoder_2, prompt, positive=True, - device="cuda", + device=get_device_type(), t5_sequence_length=512, ): pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) @@ -558,7 +559,7 @@ def encode_prompt( text_encoder_2, prompt, positive=True, - device="cuda", + device=get_device_type(), t5_sequence_length=512, ): pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) @@ -793,7 +794,7 @@ def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controlle class InfinitYou(torch.nn.Module): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__() from facexlib.recognition import init_recognition_model from insightface.app import FaceAnalysis diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 4bfa00ef0..75cfbee77 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -6,6 +6,7 @@ import numpy as np from math import prod +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput @@ -22,7 +23,7 @@ class QwenImagePipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, @@ -60,7 +61,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), processor_config: ModelConfig = None, diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index ca59d2a00..866ac18e3 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -11,6 +11,7 @@ from typing_extensions import Literal from transformers import Wav2Vec2Processor +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit @@ -30,7 +31,7 @@ class WanVideoPipeline(BasePipeline): - def __init__(self, 
device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 @@ -98,7 +99,7 @@ def enable_usp(self): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), audio_processor_config: ModelConfig = None, @@ -960,7 +961,7 @@ def __init__(self): onload_model_names=("vae",) ) - def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()): if mask_pixel_values is None: msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) else: diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index 9ba182ae5..2c5b68730 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -6,6 +6,7 @@ import numpy as np from typing import Union, List, Optional, Tuple, Iterable, Dict +from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..core.data.operators import ImageCropAndResize @@ -25,7 +26,7 @@ class ZImagePipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, @@ -58,7 +59,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, - device: Union[str, torch.device] = "cuda", + device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), vram_limit: float = None, diff --git a/diffsynth/utils/controlnet/annotator.py b/diffsynth/utils/controlnet/annotator.py index 06553e06d..cb737385f 100644 --- a/diffsynth/utils/controlnet/annotator.py +++ b/diffsynth/utils/controlnet/annotator.py @@ -1,12 +1,13 @@ from typing_extensions import Literal, TypeAlias +from diffsynth.core.device.npu_compatible_device import get_device_type Processor_id: TypeAlias = Literal[ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint" ] class Annotator: - def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False): + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False): if not skip_processor: if processor_id == "canny": from controlnet_aux.processor import CannyDetector diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index d365cfe3b..4a1cd14a4 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -50,7 +50,6 @@ def rope_apply(x, freqs, 
num_heads): sp_rank = get_sequence_parallel_rank() freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] - freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) From 5c0b07d939d2d2537af3b25b77349f79997cf8b5 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Thu, 15 Jan 2026 20:34:52 +0800 Subject: [PATCH 5/7] [NPU]:Replace 'cuda' in the project with abstract interfaces --- diffsynth/models/wan_video_dit.py | 1 - diffsynth/utils/xfuser/xdit_context_parallel.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index cfee25869..aca7d676f 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -5,7 +5,6 @@ from typing import Tuple, Optional from einops import rearrange from .wan_video_camera_controller import SimpleAdapter -from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE try: import flash_attn_interface diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 4a1cd14a4..b7fa72d92 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -5,7 +5,7 @@ get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention -from ...core.device import parse_nccl_backend, parse_device_type, IS_NPU_AVAILABLE +from ...core.device import parse_nccl_backend, parse_device_type def initialize_usp(device_type): @@ -50,6 +50,7 @@ def rope_apply(x, freqs, num_heads): sp_rank = get_sequence_parallel_rank() freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype) From dce77ec4d1ffd0123c54d75ea81647ec0746a0b6 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Thu, 15 Jan 2026 20:35:41 +0800 Subject: [PATCH 6/7] [NPU]:Replace 'cuda' in the project with abstract interfaces --- diffsynth/models/wan_video_dit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index aca7d676f..daafa7a68 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -5,7 +5,6 @@ from typing import Tuple, Optional from einops import rearrange from .wan_video_camera_controller import SimpleAdapter - try: import flash_attn_interface FLASH_ATTN_3_AVAILABLE = True From ad91d416018da6cefa4b15ba7e826a45aa93a472 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Fri, 16 Jan 2026 10:28:24 +0800 Subject: [PATCH 7/7] [NPU]:Replace 'cuda' in the project with abstract interfaces --- diffsynth/core/__init__.py | 1 - diffsynth/core/npu_patch/__init__.py | 5 ----- .../core/npu_patch/npu_autocast_patch.py | 21 ------------------- diffsynth/models/longcat_video_dit.py | 12 +---------- 4 files changed, 1 insertion(+), 38 deletions(-) delete mode 100644 diffsynth/core/npu_patch/__init__.py delete mode 100644 diffsynth/core/npu_patch/npu_autocast_patch.py diff --git a/diffsynth/core/__init__.py b/diffsynth/core/__init__.py index 4d5f440d3..6c0a6c877 100644 --- a/diffsynth/core/__init__.py +++ b/diffsynth/core/__init__.py @@ -4,4 +4,3 @@ from .loader import * from .vram import * from .device import * -from .npu_patch import * diff --git 
a/diffsynth/core/npu_patch/__init__.py b/diffsynth/core/npu_patch/__init__.py deleted file mode 100644 index eb1df930b..000000000 --- a/diffsynth/core/npu_patch/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE -from .npu_autocast_patch import npu_autocast_patch - -if IS_NPU_AVAILABLE: - npu_autocast_patch() diff --git a/diffsynth/core/npu_patch/npu_autocast_patch.py b/diffsynth/core/npu_patch/npu_autocast_patch.py deleted file mode 100644 index 08b1caffa..000000000 --- a/diffsynth/core/npu_patch/npu_autocast_patch.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -from contextlib import contextmanager - - -def npu_autocast_patch_wrapper(func): - @contextmanager - def wrapper(*args, **kwargs): - flag = False - if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"): - if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32): - flag = True - with func(*args, **kwargs) as ctx: - if flag: - torch.npu.set_autocast_enabled(True) - yield ctx - return wrapper - - -def npu_autocast_patch(): - torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast) - torch.autocast = npu_autocast_patch_wrapper(torch.autocast) diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py index ebcc9d094..dbe1c21b9 100644 --- a/diffsynth/models/longcat_video_dit.py +++ b/diffsynth/models/longcat_video_dit.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from .wan_video_dit import flash_attention -from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type +from ..core.device.npu_compatible_device import get_device_type from ..core.gradient import gradient_checkpoint_forward @@ -375,8 +375,6 @@ def forward(self, x, t, latent_shape): T, _, _ = latent_shape with amp.autocast(get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C] x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) x = self.linear(x) @@ -587,8 +585,6 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return # compute modulation params in fp32 with amp.autocast(device_type=get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) shift_msa, scale_msa, gate_msa, \ shift_mlp, scale_mlp, gate_mlp = \ self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] @@ -608,8 +604,6 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return x_s = attn_outputs with amp.autocast(device_type=get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x.to(x_dtype) @@ -623,8 +617,6 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) x_s = self.ffn(x_m) with amp.autocast(device_type=get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x.to(x_dtype) @@ -807,8 +799,6 @@ def forward( hidden_states = self.x_embedder(hidden_states) # [B, N, C] with amp.autocast(device_type=get_device_type(), dtype=torch.float32): - if 
IS_NPU_AVAILABLE:
-                torch.npu.set_autocast_enabled(True)
             t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
 
         encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
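
Series note: every patch above imports IS_NPU_AVAILABLE, get_device_type, or get_torch_device from diffsynth/core/device/npu_compatible_device.py, but that module itself never appears in the diffs. The sketch below is only an assumption of what such a module provides, reconstructed from how the call sites use it (defaults like device=get_device_type(), device=get_torch_device().current_device(), and torch.amp.autocast(get_device_type(), ...)); the real implementation in the repository may differ.

```python
# Sketch only: diffsynth/core/device/npu_compatible_device.py is not part of this
# series, so the bodies below are assumptions inferred from the call sites in the
# patches, not the repository's actual implementation.
import torch

try:
    # torch_npu is Ascend's PyTorch plugin; importing it registers the torch.npu
    # namespace used elsewhere in the series (torch.npu.set_autocast_enabled, ...).
    import torch_npu  # noqa: F401
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False


def get_device_type() -> str:
    """Device string used as the default for pipelines and encoders,
    e.g. WanVideoPipeline(device=get_device_type())."""
    if IS_NPU_AVAILABLE:
        return "npu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def get_torch_device():
    """Return the torch.npu or torch.cuda module so call sites can use
    backend-agnostic helpers such as get_torch_device().current_device()."""
    return torch.npu if IS_NPU_AVAILABLE else torch.cuda
```

When reading the series end-to-end, note that it partially supersedes itself: the torch.complex64 casts added to rope_apply in patches 1-2 are removed again in patches 4-5, and patch 7 deletes the npu_patch autocast wrapper along with the explicit torch.npu.set_autocast_enabled(True) calls, so the final state keeps only the device abstraction plus the "cuda"/"npu" device-type check in nexus_gen_ar_model.py.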