diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py
index 4fe15596..d4731fd1 100644
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -4,6 +4,7 @@
 from einops import repeat, reduce
 from typing import Union
 from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core.device.npu_compatible_device import get_device_type
 from ..utils.lora import GeneralLoRALoader
 from ..models.model_loader import ModelPool
 from ..utils.controlnet import ControlNetInput
@@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module):
 
     def __init__(
         self,
-        device="cuda", torch_dtype=torch.float16,
+        device=get_device_type(), torch_dtype=torch.float16,
         height_division_factor=64, width_division_factor=64,
         time_division_factor=None, time_division_remainder=None,
     ):
diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py
index be2ee587..c394a031 100644
--- a/diffsynth/models/dinov3_image_encoder.py
+++ b/diffsynth/models/dinov3_image_encoder.py
@@ -2,6 +2,8 @@
 from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
 import torch
 
+from ..core.device.npu_compatible_device import get_device_type
+
 
 class DINOv3ImageEncoder(DINOv3ViTModel):
     def __init__(self):
@@ -70,7 +72,7 @@ def __init__(self):
             }
         )
 
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         inputs = self.processor(images=image, return_tensors="pt")
         pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
         bool_masked_pos = None
diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py
index 6d657238..dbe1c21b 100644
--- a/diffsynth/models/longcat_video_dit.py
+++ b/diffsynth/models/longcat_video_dit.py
@@ -9,6 +9,7 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 
 from .wan_video_dit import flash_attention
+from ..core.device.npu_compatible_device import get_device_type
 from ..core.gradient import gradient_checkpoint_forward
 
@@ -373,7 +374,7 @@ def forward(self, x, t, latent_shape):
         B, N, C = x.shape
         T, _, _ = latent_shape
 
-        with amp.autocast('cuda', dtype=torch.float32):
+        with amp.autocast(get_device_type(), dtype=torch.float32):
             shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
             x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
             x = self.linear(x)
@@ -583,7 +584,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
         T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
 
         # compute modulation params in fp32
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
             shift_msa, scale_msa, gate_msa, \
                 shift_mlp, scale_mlp, gate_mlp = \
                 self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +603,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
         else:
             x_s = attn_outputs
 
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
             x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
         x = x.to(x_dtype)
 
@@ -615,7 +616,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
         # ffn with modulation
         x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
         x_s = self.ffn(x_m)
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
             x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
         x = x.to(x_dtype)
 
@@ -797,7 +798,7 @@ def forward(
 
         hidden_states = self.x_embedder(hidden_states) # [B, N, C]
 
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
             t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
 
             encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
diff --git a/diffsynth/models/nexus_gen_ar_model.py b/diffsynth/models/nexus_gen_ar_model.py
index d5a29735..b647786a 100644
--- a/diffsynth/models/nexus_gen_ar_model.py
+++ b/diffsynth/models/nexus_gen_ar_model.py
@@ -583,7 +583,7 @@ def _sample(
         is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
         is_compileable = is_compileable and not self.generation_config.disable_compile
         if is_compileable and (
-            self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
+            self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
         ):
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
             model_forward = self.get_compiled_call(generation_config.compile_config)
diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py
index 87df8556..509eff40 100644
--- a/diffsynth/models/siglip2_image_encoder.py
+++ b/diffsynth/models/siglip2_image_encoder.py
@@ -2,6 +2,8 @@
 from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
 import torch
 
+from diffsynth.core.device.npu_compatible_device import get_device_type
+
 
 class Siglip2ImageEncoder(SiglipVisionTransformer):
     def __init__(self):
@@ -47,7 +49,7 @@ def __init__(self):
             }
         )
 
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
         pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
         pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
         output_attentions = False
diff --git a/diffsynth/models/step1x_text_encoder.py b/diffsynth/models/step1x_text_encoder.py
index d0fe2215..5d144236 100644
--- a/diffsynth/models/step1x_text_encoder.py
+++ b/diffsynth/models/step1x_text_encoder.py
@@ -1,10 +1,11 @@
 import torch
 from typing import Optional, Union
 from .qwen_image_text_encoder import QwenImageTextEncoder
+from ..core.device.npu_compatible_device import get_device_type, get_torch_device
 
 
 class Step1xEditEmbedder(torch.nn.Module):
-    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
+    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
         super().__init__()
         self.max_length = max_length
         self.dtype = dtype
@@ -77,13 +78,13 @@ def forward(self, caption, ref_images):
             self.max_length,
             self.model.config.hidden_size,
             dtype=torch.bfloat16,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
         masks = torch.zeros(
             len(text_list),
             self.max_length,
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
         )
 
         def split_string(s):
@@ -158,7 +159,7 @@ def split_string(s):
             else:
                 token_list.append(token_each)
 
-            new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
+            new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
 
             new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
 
@@ -167,15 +168,15 @@ def split_string(s):
             inputs.input_ids = (
                 torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
                 .unsqueeze(0)
-                .to("cuda")
+                .to(get_device_type())
             )
-            inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
+            inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
 
             outputs = self.model_forward(
                 self.model,
                 input_ids=inputs.input_ids,
                 attention_mask=inputs.attention_mask,
-                pixel_values=inputs.pixel_values.to("cuda"),
-                image_grid_thw=inputs.image_grid_thw.to("cuda"),
+                pixel_values=inputs.pixel_values.to(get_device_type()),
+                image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
                 output_hidden_states=True,
             )
@@ -188,7 +189,7 @@ def split_string(s):
             masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
                 (min(self.max_length, emb.shape[1] - 217)),
                 dtype=torch.long,
-                device=torch.cuda.current_device(),
+                device=get_torch_device().current_device(),
             )
 
         return embs, masks
diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py
index f157f38d..bb490675 100644
--- a/diffsynth/models/z_image_dit.py
+++ b/diffsynth/models/z_image_dit.py
@@ -8,7 +8,7 @@
 from torch.nn import RMSNorm
 
 from ..core.attention import attention_forward
-from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
 from ..core.gradient import gradient_checkpoint_forward
 
 
@@ -40,7 +40,7 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
     @staticmethod
     def timestep_embedding(t, dim, max_period=10000):
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(get_device_type(), enabled=False):
             half = dim // 2
             freqs = torch.exp(
                 -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
             )
@@ -105,7 +105,7 @@ def forward(self, hidden_states, freqs_cis, attention_mask):
 
         # Apply RoPE
         def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast("cuda", enabled=False):
+            with torch.amp.autocast(get_device_type(), enabled=False):
                 x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
                 freqs_cis = freqs_cis.unsqueeze(2)
                 x_out = torch.view_as_real(x * freqs_cis).flatten(3)
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index 8b004694..5ecbb20a 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -6,6 +6,7 @@
 import numpy as np
 from typing import Union, List, Optional, Tuple
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -18,7 +19,7 @@
 
 
 class Flux2ImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -42,7 +43,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index 1ee5635e..bfc53e50 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -6,6 +6,7 @@
 import numpy as np
 from transformers import CLIPTokenizer, T5TokenizerFast
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[Con
 
 
 class FluxImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ def enable_lora_merger(self):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
         tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ def encode_prompt(
         text_encoder_2,
         prompt,
         positive=True,
-        device="cuda",
+        device=get_device_type(),
         t5_sequence_length=512,
     ):
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ def encode_prompt(
         text_encoder_2,
         prompt,
         positive=True,
-        device="cuda",
+        device=get_device_type(),
         t5_sequence_length=512,
     ):
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controlle
 
 
 class InfinitYou(torch.nn.Module):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__()
         from facexlib.recognition import init_recognition_model
         from insightface.app import FaceAnalysis
diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index 4bfa00ef..75cfbee7 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -6,6 +6,7 @@
 import numpy as np
 from math import prod
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -22,7 +23,7 @@
 
 
 class QwenImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -60,7 +61,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
     @staticmethod
    def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
         processor_config: ModelConfig = None,
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index ca59d2a0..866ac18e 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -11,6 +11,7 @@
 from typing_extensions import Literal
 from transformers import Wav2Vec2Processor
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@
 
 class WanVideoPipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
             time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ def enable_usp(self):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
         audio_processor_config: ModelConfig = None,
@@ -960,7 +961,7 @@ def __init__(self):
             onload_model_names=("vae",)
         )
 
-    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
+    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
         if mask_pixel_values is None:
             msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
         else:
diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py
index 9ba182ae..2c5b6873 100644
--- a/diffsynth/pipelines/z_image.py
+++ b/diffsynth/pipelines/z_image.py
@@ -6,6 +6,7 @@
 import numpy as np
 from typing import Union, List, Optional, Tuple, Iterable, Dict
 
+from ..core.device.npu_compatible_device import get_device_type
 from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..core.data.operators import ImageCropAndResize
@@ -25,7 +26,7 @@
 
 
 class ZImagePipeline(BasePipeline):
-    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
         super().__init__(
             device=device, torch_dtype=torch_dtype,
             height_division_factor=16, width_division_factor=16,
@@ -58,7 +59,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
     @staticmethod
     def from_pretrained(
         torch_dtype: torch.dtype = torch.bfloat16,
-        device: Union[str, torch.device] = "cuda",
+        device: Union[str, torch.device] = get_device_type(),
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
         vram_limit: float = None,
diff --git a/diffsynth/utils/controlnet/annotator.py b/diffsynth/utils/controlnet/annotator.py
index 06553e06..cb737385 100644
--- a/diffsynth/utils/controlnet/annotator.py
+++ b/diffsynth/utils/controlnet/annotator.py
@@ -1,12 +1,13 @@
 from typing_extensions import Literal, TypeAlias
+from diffsynth.core.device.npu_compatible_device import get_device_type
 
 Processor_id: TypeAlias = Literal[
     "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
 ]
 
 
 class Annotator:
-    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
+    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
         if not skip_processor:
             if processor_id == "canny":
                 from controlnet_aux.processor import CannyDetector
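
The patch relies on a `diffsynth/core/device/npu_compatible_device.py` helper module exposing `IS_NPU_AVAILABLE`, `get_device_type()`, and `get_torch_device()`, which is not included in this excerpt. A minimal sketch of what such a module could look like, assuming `torch_npu` is the optional Ascend backend:

```python
# Hypothetical sketch of diffsynth/core/device/npu_compatible_device.py;
# the real module is not shown in this diff, so names and behaviour are assumptions.
import torch

try:
    import torch_npu  # registers the Ascend "npu" backend with PyTorch
    IS_NPU_AVAILABLE = torch_npu.npu.is_available()
except ImportError:
    torch_npu = None
    IS_NPU_AVAILABLE = False


def get_device_type() -> str:
    # Prefer NPU when torch_npu is usable, then CUDA, then CPU.
    if IS_NPU_AVAILABLE:
        return "npu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def get_torch_device():
    # Return the device namespace (torch_npu.npu or torch.cuda) so call sites can
    # write get_torch_device().current_device() instead of torch.cuda.current_device().
    return torch_npu.npu if IS_NPU_AVAILABLE else torch.cuda
```

Note that `device=get_device_type()` as a default argument is evaluated once at import time, so the chosen backend is fixed when the module loads; callers can still pass an explicit `device` to override it.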