3 changes: 2 additions & 1 deletion diffsynth/diffusion/base_pipeline.py
@@ -4,6 +4,7 @@
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
@@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module):

def __init__(
self,
device="cuda", torch_dtype=torch.float16,
device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
4 changes: 3 additions & 1 deletion diffsynth/models/dinov3_image_encoder.py
@@ -2,6 +2,8 @@
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch

+from ..core.device.npu_compatible_device import get_device_type


class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
@@ -70,7 +72,7 @@ def __init__(self):
}
)

-def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None
11 changes: 6 additions & 5 deletions diffsynth/models/longcat_video_dit.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
+from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward


@@ -373,7 +374,7 @@ def forward(self, x, t, latent_shape):
B, N, C = x.shape
T, _, _ = latent_shape

-with amp.autocast('cuda', dtype=torch.float32):
+with amp.autocast(get_device_type(), dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
@@ -583,7 +584,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.

# compute modulation params in fp32
-with amp.autocast(device_type='cuda', dtype=torch.float32):
+with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +603,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
else:
x_s = attn_outputs

-with amp.autocast(device_type='cuda', dtype=torch.float32):
+with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)

@@ -615,7 +616,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
-with amp.autocast(device_type='cuda', dtype=torch.float32):
+with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)

@@ -797,7 +798,7 @@ def forward(

hidden_states = self.x_embedder(hidden_states) # [B, N, C]

-with amp.autocast(device_type='cuda', dtype=torch.float32):
+with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]

encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
2 changes: 1 addition & 1 deletion diffsynth/models/nexus_gen_ar_model.py
@@ -583,7 +584,7 @@ def _sample(
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
is_compileable = is_compileable and not self.generation_config.disable_compile
if is_compileable and (
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
):
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
4 changes: 3 additions & 1 deletion diffsynth/models/siglip2_image_encoder.py
@@ -2,6 +2,8 @@
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch

+from diffsynth.core.device.npu_compatible_device import get_device_type


class Siglip2ImageEncoder(SiglipVisionTransformer):
def __init__(self):
@@ -47,7 +49,7 @@ def __init__(self):
}
)

-def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
output_attentions = False
19 changes: 10 additions & 9 deletions diffsynth/models/step1x_text_encoder.py
@@ -1,10 +1,11 @@
import torch
from typing import Optional, Union
from .qwen_image_text_encoder import QwenImageTextEncoder
+from ..core.device.npu_compatible_device import get_device_type, get_torch_device


class Step1xEditEmbedder(torch.nn.Module):
-def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
+def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
super().__init__()
self.max_length = max_length
self.dtype = dtype
@@ -77,13 +78,13 @@ def forward(self, caption, ref_images):
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
-device=torch.cuda.current_device(),
+device=get_torch_device().current_device(),
[Review comment from Contributor, severity: high]
Using get_torch_device().current_device() will cause a crash on CPU-only systems. The get_torch_device function falls back to torch.cuda when the device is 'cpu', and torch.cuda.current_device() will then fail if no CUDA device is available. A simpler and more robust approach is to use get_device_type() directly, as it returns a device string ('cpu', 'cuda', 'npu') that is accepted by PyTorch tensor creation functions. (A sketch of the helper semantics this comment assumes follows this file's diff.)
Suggested change:
-device=get_torch_device().current_device(),
+device=get_device_type(),
)
masks = torch.zeros(
len(text_list),
self.max_length,
dtype=torch.long,
-device=torch.cuda.current_device(),
+device=get_torch_device().current_device(),
[Review comment from Contributor, severity: high]
Same issue as the previous comment: get_torch_device().current_device() will crash on CPU-only systems; use get_device_type() instead.
Suggested change:
-device=get_torch_device().current_device(),
+device=get_device_type(),
)

def split_string(s):
@@ -158,7 +159,7 @@ def split_string(s):
else:
token_list.append(token_each)

-new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
+new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())

new_txt_ids = new_txt_ids.to(old_inputs_ids.device)

@@ -167,15 +168,15 @@ def split_string(s):
inputs.input_ids = (
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
.unsqueeze(0)
.to("cuda")
.to(get_device_type())
)
-inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
+inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
outputs = self.model_forward(
self.model,
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
pixel_values=inputs.pixel_values.to("cuda"),
image_grid_thw=inputs.image_grid_thw.to("cuda"),
pixel_values=inputs.pixel_values.to(get_device_type()),
image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
output_hidden_states=True,
)

@@ -188,7 +189,7 @@ def split_string(s):
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
(min(self.max_length, emb.shape[1] - 217)),
dtype=torch.long,
-device=torch.cuda.current_device(),
+device=get_torch_device().current_device(),
[Review comment from Contributor, severity: high]
Same issue as the previous comments: get_torch_device().current_device() will crash on CPU-only systems; use get_device_type() instead.
Suggested change:
-device=get_torch_device().current_device(),
+device=get_device_type(),
)

return embs, masks
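
The three review comments above hinge on how the two device helpers behave. Their implementation is not part of this diff, so what follows is a minimal sketch under the semantics the reviewer describes, an assumption rather than the repository's actual code: get_device_type() returns a plain device string, while get_torch_device() returns a backend module whose current_device() requires that backend to actually be available.

# Hypothetical sketch of diffsynth/core/device/npu_compatible_device.py;
# the real module is not shown in this diff.
import torch

try:
    import torch_npu  # Ascend backend; importing it patches in torch.npu
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False

def get_device_type() -> str:
    # A plain device string, accepted by tensor factories, .to(), and
    # amp.autocast on any host, including CPU-only machines.
    if IS_NPU_AVAILABLE:
        return "npu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def get_torch_device():
    # The backend module (torch.npu or torch.cuda). On a CPU-only host this
    # falls back to torch.cuda, so get_torch_device().current_device() fails,
    # which is exactly the crash the reviewer flags.
    return torch.npu if IS_NPU_AVAILABLE else torch.cuda
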
6 changes: 3 additions & 3 deletions diffsynth/models/z_image_dit.py
@@ -8,7 +8,7 @@

from torch.nn import RMSNorm
from ..core.attention import attention_forward
-from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
from ..core.gradient import gradient_checkpoint_forward


@@ -40,7 +40,7 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):

@staticmethod
def timestep_embedding(t, dim, max_period=10000):
with torch.amp.autocast("cuda", enabled=False):
with torch.amp.autocast(get_device_type(), enabled=False):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
@@ -105,7 +105,7 @@ def forward(self, hidden_states, freqs_cis, attention_mask):

# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
5 changes: 3 additions & 2 deletions diffsynth/pipelines/flux2_image.py
@@ -6,6 +6,7 @@
import numpy as np
from typing import Union, List, Optional, Tuple

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -18,7 +19,7 @@

class Flux2ImagePipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -42,7 +43,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
11 changes: 6 additions & 5 deletions diffsynth/pipelines/flux_image.py
@@ -6,6 +6,7 @@
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[Con

class FluxImagePipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ def enable_lora_merger(self):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ def encode_prompt(
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ def encode_prompt(
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controlle


class InfinitYou(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
5 changes: 3 additions & 2 deletions diffsynth/pipelines/qwen_image.py
@@ -6,6 +6,7 @@
import numpy as np
from math import prod

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -22,7 +23,7 @@

class QwenImagePipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -60,7 +61,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,
7 changes: 4 additions & 3 deletions diffsynth/pipelines/wan_video.py
@@ -11,6 +11,7 @@
from typing_extensions import Literal
from transformers import Wav2Vec2Processor

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@

class WanVideoPipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ def enable_usp(self):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
audio_processor_config: ModelConfig = None,
@@ -960,7 +961,7 @@ def __init__(self):
onload_model_names=("vae",)
)

-def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
+def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else:
5 changes: 3 additions & 2 deletions diffsynth/pipelines/z_image.py
@@ -6,6 +6,7 @@
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
Expand All @@ -25,7 +26,7 @@

class ZImagePipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -58,7 +59,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
3 changes: 2 additions & 1 deletion diffsynth/utils/controlnet/annotator.py
@@ -1,12 +1,13 @@
from typing_extensions import Literal, TypeAlias

+from diffsynth.core.device.npu_compatible_device import get_device_type

Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]

class Annotator:
-def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
+def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
if not skip_processor:
if processor_id == "canny":
from controlnet_aux.processor import CannyDetector
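
Taken together, these changes replace every hardcoded "cuda" default with get_device_type(), so the pipelines resolve CUDA, Ascend NPU, or CPU without the caller naming a backend. A minimal usage sketch, using one pipeline from this diff as an example:

import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline

# No device argument: the default is now get_device_type() rather than a
# hardcoded "cuda", so construction no longer assumes an NVIDIA GPU.
pipe = FluxImagePipeline(torch_dtype=torch.bfloat16)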