Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions lightx2v/models/runners/longcat_image/longcat_image_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,10 @@ def run_pipeline(self, input_info):
images = self.run_vae_decoder(latents)
self.end_run()

image = images[0]
image.save(f"{input_info.save_result_path}")
logger.info(f"Image saved: {input_info.save_result_path}")
if not input_info.return_result_tensor:
image = images[0]
image.save(f"{input_info.save_result_path}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The f-string formatting is redundant here as input_info.save_result_path is already a string. You can pass the variable directly to image.save for cleaner code.

Suggested change
image.save(f"{input_info.save_result_path}")
image.save(input_info.save_result_path)

logger.info(f"Image saved: {input_info.save_result_path}")

del latents, generator
torch_device_module.empty_cache()
Expand Down
19 changes: 10 additions & 9 deletions lightx2v/models/runners/qwen_image/qwen_image_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,15 +378,16 @@ def run_pipeline(self, input_info):
images = self.run_vae_decoder(latents)
self.end_run()

if isinstance(images[0], list) and len(images[0]) > 1:
image_prefix = f"{input_info.save_result_path}".split(".")[0]
for idx, image in enumerate(images[0]):
image.save(f"{image_prefix}_{idx}.png")
logger.info(f"Image saved: {image_prefix}_{idx}.png")
else:
image = images[0]
image.save(f"{input_info.save_result_path}")
logger.info(f"Image saved: {input_info.save_result_path}")
if not input_info.return_result_tensor:
if isinstance(images[0], list) and len(images[0]) > 1:
image_prefix = f"{input_info.save_result_path}".split(".")[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using .split('.') to remove a file extension is not robust as it will fail for file paths that contain dots in their directory names (e.g., /path.with.dots/image.png). It's safer to use rsplit('.', 1) to split only on the last dot. Also, the f-string is redundant.

A more robust way would be to use os.path.splitext, but that would require an import. Using rsplit is a good improvement without adding imports.

Suggested change
image_prefix = f"{input_info.save_result_path}".split(".")[0]
image_prefix = input_info.save_result_path.rsplit(".", 1)[0]

for idx, image in enumerate(images[0]):
image.save(f"{image_prefix}_{idx}.png")
logger.info(f"Image saved: {image_prefix}_{idx}.png")
else:
image = images[0]
image.save(f"{input_info.save_result_path}")
logger.info(f"Image saved: {input_info.save_result_path}")

del latents, generator
torch_device_module.empty_cache()
Expand Down
7 changes: 4 additions & 3 deletions lightx2v/models/runners/z_image/z_image_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,10 @@ def run_pipeline(self, input_info):
images = self.run_vae_decoder(latents)
self.end_run()

image = images[0]
image.save(f"{input_info.save_result_path}")
logger.info(f"Image saved: {input_info.save_result_path}")
if not input_info.return_result_tensor:
image = images[0]
image.save(f"{input_info.save_result_path}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The f-string formatting is redundant here since input_info.save_result_path is already a string. You can simplify the code by passing the variable directly to image.save.

Suggested change
image.save(f"{input_info.save_result_path}")
image.save(input_info.save_result_path)

logger.info(f"Image saved: {input_info.save_result_path}")

del latents, generator
torch_device_module.empty_cache()
Expand Down
2 changes: 1 addition & 1 deletion lightx2v/models/video_encoders/hf/longcat_image/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def decode(self, latents, input_info):

# Decode - latents is now [B, 16, H, W]
images = self.model.decode(latents, return_dict=False)[0]
images = self.image_processor.postprocess(images, output_type="pil")
images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")

if self.cpu_offload:
self.model.to(torch.device("cpu"))
Expand Down
4 changes: 2 additions & 2 deletions lightx2v/models/video_encoders/hf/qwen_image/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ def decode(self, latents, input_info):
latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
image = self.model.decode(latents, return_dict=False)[0] # (b f) c 1 h w
image = image.squeeze(2)
image = self.image_processor.postprocess(image, output_type="pil")
image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil")
images = []
for bidx in range(b):
images.append(image[bidx * f : (bidx + 1) * f])
else:
images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
images = self.image_processor.postprocess(images, output_type="pil")
images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
if self.cpu_offload:
self.model.to(torch.device("cpu"))
torch_device_module.empty_cache()
Expand Down
3 changes: 1 addition & 2 deletions lightx2v/models/video_encoders/hf/z_image/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def decode(self, latents, input_info):
latents = (latents / scaling_factor) + shift_factor
images = self.model.decode(latents, return_dict=False)[0]

images_postprocessed = self.image_processor.postprocess(images, output_type="pil")
images = images_postprocessed
images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
if self.cpu_offload:
self.model.to(torch.device("cpu"))
torch.cuda.empty_cache()
Expand Down
2 changes: 2 additions & 0 deletions lightx2v/utils/input_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class T2IInputInfo:
prompt: str = field(default_factory=str)
negative_prompt: str = field(default_factory=str)
save_result_path: str = field(default_factory=str)
return_result_tensor: bool = field(default_factory=lambda: False)
# shape related
resize_mode: str = field(default_factory=str)
target_shape: list = field(default_factory=list)
Expand All @@ -144,6 +145,7 @@ class I2IInputInfo:
negative_prompt: str = field(default_factory=str)
image_path: str = field(default_factory=str)
save_result_path: str = field(default_factory=str)
return_result_tensor: bool = field(default_factory=lambda: False)
# shape related
resize_mode: str = field(default_factory=str)
target_shape: list = field(default_factory=list)
Expand Down