diff --git a/lightx2v/models/runners/longcat_image/longcat_image_runner.py b/lightx2v/models/runners/longcat_image/longcat_image_runner.py
index 2103f81c..55462159 100755
--- a/lightx2v/models/runners/longcat_image/longcat_image_runner.py
+++ b/lightx2v/models/runners/longcat_image/longcat_image_runner.py
@@ -397,9 +397,10 @@ def run_pipeline(self, input_info):
         images = self.run_vae_decoder(latents)
         self.end_run()

-        image = images[0]
-        image.save(f"{input_info.save_result_path}")
-        logger.info(f"Image saved: {input_info.save_result_path}")
+        if not input_info.return_result_tensor:
+            image = images[0]
+            image.save(f"{input_info.save_result_path}")
+            logger.info(f"Image saved: {input_info.save_result_path}")

         del latents, generator
         torch_device_module.empty_cache()
diff --git a/lightx2v/models/runners/qwen_image/qwen_image_runner.py b/lightx2v/models/runners/qwen_image/qwen_image_runner.py
index b677c9fa..8246b8b9 100755
--- a/lightx2v/models/runners/qwen_image/qwen_image_runner.py
+++ b/lightx2v/models/runners/qwen_image/qwen_image_runner.py
@@ -378,15 +378,16 @@ def run_pipeline(self, input_info):
         images = self.run_vae_decoder(latents)
         self.end_run()

-        if isinstance(images[0], list) and len(images[0]) > 1:
-            image_prefix = f"{input_info.save_result_path}".split(".")[0]
-            for idx, image in enumerate(images[0]):
-                image.save(f"{image_prefix}_{idx}.png")
-                logger.info(f"Image saved: {image_prefix}_{idx}.png")
-        else:
-            image = images[0]
-            image.save(f"{input_info.save_result_path}")
-            logger.info(f"Image saved: {input_info.save_result_path}")
+        if not input_info.return_result_tensor:
+            if isinstance(images[0], list) and len(images[0]) > 1:
+                image_prefix = f"{input_info.save_result_path}".split(".")[0]
+                for idx, image in enumerate(images[0]):
+                    image.save(f"{image_prefix}_{idx}.png")
+                    logger.info(f"Image saved: {image_prefix}_{idx}.png")
+            else:
+                image = images[0]
+                image.save(f"{input_info.save_result_path}")
+                logger.info(f"Image saved: {input_info.save_result_path}")

         del latents, generator
         torch_device_module.empty_cache()
diff --git a/lightx2v/models/runners/z_image/z_image_runner.py b/lightx2v/models/runners/z_image/z_image_runner.py
index c9d7d9e6..8959486d 100755
--- a/lightx2v/models/runners/z_image/z_image_runner.py
+++ b/lightx2v/models/runners/z_image/z_image_runner.py
@@ -342,9 +342,10 @@ def run_pipeline(self, input_info):
         images = self.run_vae_decoder(latents)
         self.end_run()

-        image = images[0]
-        image.save(f"{input_info.save_result_path}")
-        logger.info(f"Image saved: {input_info.save_result_path}")
+        if not input_info.return_result_tensor:
+            image = images[0]
+            image.save(f"{input_info.save_result_path}")
+            logger.info(f"Image saved: {input_info.save_result_path}")

         del latents, generator
         torch_device_module.empty_cache()
diff --git a/lightx2v/models/video_encoders/hf/longcat_image/vae.py b/lightx2v/models/video_encoders/hf/longcat_image/vae.py
index 5c4da0fa..7bbe9f50 100755
--- a/lightx2v/models/video_encoders/hf/longcat_image/vae.py
+++ b/lightx2v/models/video_encoders/hf/longcat_image/vae.py
@@ -106,7 +106,7 @@ def decode(self, latents, input_info):

         # Decode - latents is now [B, 16, H, W]
         images = self.model.decode(latents, return_dict=False)[0]
-        images = self.image_processor.postprocess(images, output_type="pil")
+        images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")

         if self.cpu_offload:
             self.model.to(torch.device("cpu"))
diff --git a/lightx2v/models/video_encoders/hf/qwen_image/vae.py b/lightx2v/models/video_encoders/hf/qwen_image/vae.py
index 6ef7ccc4..13a64bf3 100755
--- a/lightx2v/models/video_encoders/hf/qwen_image/vae.py
+++ b/lightx2v/models/video_encoders/hf/qwen_image/vae.py
@@ -80,13 +80,13 @@ def decode(self, latents, input_info):
             latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
             image = self.model.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
             image = image.squeeze(2)
-            image = self.image_processor.postprocess(image, output_type="pil")
+            image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil")
             images = []
             for bidx in range(b):
                 images.append(image[bidx * f : (bidx + 1) * f])
         else:
             images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
-            images = self.image_processor.postprocess(images, output_type="pil")
+            images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
         if self.cpu_offload:
             self.model.to(torch.device("cpu"))
             torch_device_module.empty_cache()
diff --git a/lightx2v/models/video_encoders/hf/z_image/vae.py b/lightx2v/models/video_encoders/hf/z_image/vae.py
index fd74bb85..a4d871b3 100755
--- a/lightx2v/models/video_encoders/hf/z_image/vae.py
+++ b/lightx2v/models/video_encoders/hf/z_image/vae.py
@@ -67,8 +67,7 @@ def decode(self, latents, input_info):
         latents = (latents / scaling_factor) + shift_factor
         images = self.model.decode(latents, return_dict=False)[0]
-        images_postprocessed = self.image_processor.postprocess(images, output_type="pil")
-        images = images_postprocessed
+        images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")

         if self.cpu_offload:
             self.model.to(torch.device("cpu"))
             torch.cuda.empty_cache()
diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py
index b796dfb3..8a146457 100755
--- a/lightx2v/utils/input_info.py
+++ b/lightx2v/utils/input_info.py
@@ -129,6 +129,7 @@ class T2IInputInfo:
     prompt: str = field(default_factory=str)
     negative_prompt: str = field(default_factory=str)
     save_result_path: str = field(default_factory=str)
+    return_result_tensor: bool = field(default_factory=lambda: False)
     # shape related
     resize_mode: str = field(default_factory=str)
     target_shape: list = field(default_factory=list)
@@ -144,6 +145,7 @@ class I2IInputInfo:
     negative_prompt: str = field(default_factory=str)
     image_path: str = field(default_factory=str)
     save_result_path: str = field(default_factory=str)
+    return_result_tensor: bool = field(default_factory=lambda: False)
    # shape related
     resize_mode: str = field(default_factory=str)
     target_shape: list = field(default_factory=list)
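Usage note: the new return_result_tensor flag on T2IInputInfo / I2IInputInfo makes the VAE decoders hand back tensors (output_type="pt") instead of PIL images, and the runners then skip writing to save_result_path. Assuming the image_processor in these wrappers is diffusers' VaeImageProcessor (the hf/ paths suggest it, but this is an assumption), the dispatch reduces to the sketch below; DummyInputInfo and the random "decoded" tensor are illustrative stand-ins, not part of this patch.

import torch
from diffusers.image_processor import VaeImageProcessor

class DummyInputInfo:
    # Stand-in for T2IInputInfo / I2IInputInfo; only the field used below is modeled.
    return_result_tensor = True

input_info = DummyInputInfo()
processor = VaeImageProcessor(vae_scale_factor=8)

# A decoded VAE output is a float tensor in [-1, 1] with shape [B, C, H, W];
# random data stands in for self.model.decode(latents, return_dict=False)[0].
decoded = torch.rand(1, 3, 64, 64) * 2 - 1

# Same dispatch as the patch: hand back a torch tensor when the caller asked for one,
# otherwise produce PIL images the runner can .save() to save_result_path.
images = processor.postprocess(decoded, output_type="pt" if input_info.return_result_tensor else "pil")

print(type(images))  # torch.Tensor (values in [0, 1]) here; a list of PIL.Image.Image with "pil"

With return_result_tensor left at its default (False), behavior is unchanged from before the patch: postprocess returns PIL images and the runner saves them to save_result_path.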