From 7869802a2f15e5f4634725c123f86c63be62e11b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sita=20B=C3%A9r=C3%A9t=C3%A9?=
Date: Thu, 15 Jan 2026 04:09:20 +0000
Subject: [PATCH 1/2] Return the Tensor when input_info.return_result_tensor is true

---
 .../longcat_image/longcat_image_runner.py     |  7 ++++---
 .../runners/qwen_image/qwen_image_runner.py   | 19 ++++++++++---------
 .../models/runners/z_image/z_image_runner.py  |  7 ++++---
 .../video_encoders/hf/longcat_image/vae.py    |  2 +-
 .../video_encoders/hf/qwen_image/vae.py       |  4 ++--
 .../models/video_encoders/hf/z_image/vae.py   |  3 +--
 6 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/lightx2v/models/runners/longcat_image/longcat_image_runner.py b/lightx2v/models/runners/longcat_image/longcat_image_runner.py
index 7bda2bff..654e8b29 100755
--- a/lightx2v/models/runners/longcat_image/longcat_image_runner.py
+++ b/lightx2v/models/runners/longcat_image/longcat_image_runner.py
@@ -390,9 +390,10 @@ def run_pipeline(self, input_info):
         images = self.run_vae_decoder(latents)
         self.end_run()
 
-        image = images[0]
-        image.save(f"{input_info.save_result_path}")
-        logger.info(f"Image saved: {input_info.save_result_path}")
+        if not input_info.return_result_tensor:
+            image = images[0]
+            image.save(f"{input_info.save_result_path}")
+            logger.info(f"Image saved: {input_info.save_result_path}")
 
         del latents, generator
         torch_device_module.empty_cache()
diff --git a/lightx2v/models/runners/qwen_image/qwen_image_runner.py b/lightx2v/models/runners/qwen_image/qwen_image_runner.py
index c5b3e701..bffc9668 100755
--- a/lightx2v/models/runners/qwen_image/qwen_image_runner.py
+++ b/lightx2v/models/runners/qwen_image/qwen_image_runner.py
@@ -349,15 +349,16 @@ def run_pipeline(self, input_info):
         images = self.run_vae_decoder(latents)
         self.end_run()
 
-        if isinstance(images[0], list) and len(images[0]) > 1:
-            image_prefix = f"{input_info.save_result_path}".split(".")[0]
-            for idx, image in enumerate(images[0]):
-                image.save(f"{image_prefix}_{idx}.png")
-                logger.info(f"Image saved: {image_prefix}_{idx}.png")
-        else:
-            image = images[0]
-            image.save(f"{input_info.save_result_path}")
-            logger.info(f"Image saved: {input_info.save_result_path}")
+        if not input_info.return_result_tensor:
+            if isinstance(images[0], list) and len(images[0]) > 1:
+                image_prefix = f"{input_info.save_result_path}".split(".")[0]
+                for idx, image in enumerate(images[0]):
+                    image.save(f"{image_prefix}_{idx}.png")
+                    logger.info(f"Image saved: {image_prefix}_{idx}.png")
+            else:
+                image = images[0]
+                image.save(f"{input_info.save_result_path}")
+                logger.info(f"Image saved: {input_info.save_result_path}")
 
         del latents, generator
         torch_device_module.empty_cache()
diff --git a/lightx2v/models/runners/z_image/z_image_runner.py b/lightx2v/models/runners/z_image/z_image_runner.py
index 69568981..8e11e39d 100755
--- a/lightx2v/models/runners/z_image/z_image_runner.py
+++ b/lightx2v/models/runners/z_image/z_image_runner.py
@@ -325,9 +325,10 @@ def run_pipeline(self, input_info):
         images = self.run_vae_decoder(latents)
         self.end_run()
 
-        image = images[0]
-        image.save(f"{input_info.save_result_path}")
-        logger.info(f"Image saved: {input_info.save_result_path}")
+        if not input_info.return_result_tensor:
+            image = images[0]
+            image.save(f"{input_info.save_result_path}")
+            logger.info(f"Image saved: {input_info.save_result_path}")
 
         del latents, generator
         torch_device_module.empty_cache()
diff --git a/lightx2v/models/video_encoders/hf/longcat_image/vae.py b/lightx2v/models/video_encoders/hf/longcat_image/vae.py
index 5c4da0fa..7bbe9f50 100755
--- a/lightx2v/models/video_encoders/hf/longcat_image/vae.py
+++ b/lightx2v/models/video_encoders/hf/longcat_image/vae.py
@@ -106,7 +106,7 @@ def decode(self, latents, input_info):
         # Decode - latents is now [B, 16, H, W]
         images = self.model.decode(latents, return_dict=False)[0]
 
-        images = self.image_processor.postprocess(images, output_type="pil")
+        images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
 
         if self.cpu_offload:
             self.model.to(torch.device("cpu"))
diff --git a/lightx2v/models/video_encoders/hf/qwen_image/vae.py b/lightx2v/models/video_encoders/hf/qwen_image/vae.py
index 6ef7ccc4..13a64bf3 100755
--- a/lightx2v/models/video_encoders/hf/qwen_image/vae.py
+++ b/lightx2v/models/video_encoders/hf/qwen_image/vae.py
@@ -80,13 +80,13 @@ def decode(self, latents, input_info):
             latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
             image = self.model.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
             image = image.squeeze(2)
-            image = self.image_processor.postprocess(image, output_type="pil")
+            image = self.image_processor.postprocess(image, output_type="pt" if input_info.return_result_tensor else "pil")
             images = []
             for bidx in range(b):
                 images.append(image[bidx * f : (bidx + 1) * f])
         else:
             images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
-            images = self.image_processor.postprocess(images, output_type="pil")
+            images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
         if self.cpu_offload:
             self.model.to(torch.device("cpu"))
             torch_device_module.empty_cache()
diff --git a/lightx2v/models/video_encoders/hf/z_image/vae.py b/lightx2v/models/video_encoders/hf/z_image/vae.py
index fd74bb85..a4d871b3 100755
--- a/lightx2v/models/video_encoders/hf/z_image/vae.py
+++ b/lightx2v/models/video_encoders/hf/z_image/vae.py
@@ -67,8 +67,7 @@ def decode(self, latents, input_info):
         latents = (latents / scaling_factor) + shift_factor
         images = self.model.decode(latents, return_dict=False)[0]
 
-        images_postprocessed = self.image_processor.postprocess(images, output_type="pil")
-        images = images_postprocessed
+        images = self.image_processor.postprocess(images, output_type="pt" if input_info.return_result_tensor else "pil")
         if self.cpu_offload:
             self.model.to(torch.device("cpu"))
             torch.cuda.empty_cache()

From 9bbd317096f25c27126fec6178ab12e8c1b9b70d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sita=20B=C3%A9r=C3%A9t=C3%A9?=
Date: Thu, 15 Jan 2026 17:07:39 +0000
Subject: [PATCH 2/2] add return_result_tensor to T2IInputInfo and I2IInputInfo

---
 lightx2v/utils/input_info.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py
index 5f003142..3b98e7fc 100755
--- a/lightx2v/utils/input_info.py
+++ b/lightx2v/utils/input_info.py
@@ -123,6 +123,7 @@ class T2IInputInfo:
     prompt: str = field(default_factory=str)
     negative_prompt: str = field(default_factory=str)
     save_result_path: str = field(default_factory=str)
+    return_result_tensor: bool = field(default_factory=lambda: False)
     # shape related
     target_shape: list = field(default_factory=list)
     image_shapes: list = field(default_factory=list)
@@ -137,6 +138,7 @@ class I2IInputInfo:
     negative_prompt: str = field(default_factory=str)
     image_path: str = field(default_factory=str)
     save_result_path: str = field(default_factory=str)
+    return_result_tensor: bool = field(default_factory=lambda: False)
     # shape related
     target_shape: list = field(default_factory=list)
     image_shapes: list = field(default_factory=list)
@@ -215,10 +217,23 @@ def set_input_info(args):
             return_result_tensor=args.return_result_tensor,
         )
     elif args.task == "t2i":
-        input_info = T2IInputInfo(seed=args.seed, prompt=args.prompt, negative_prompt=args.negative_prompt, save_result_path=args.save_result_path, aspect_ratio=args.aspect_ratio)
+        input_info = T2IInputInfo(
+            seed=args.seed,
+            prompt=args.prompt,
+            negative_prompt=args.negative_prompt,
+            save_result_path=args.save_result_path,
+            aspect_ratio=args.aspect_ratio,
+            return_result_tensor=args.return_result_tensor,
+        )
     elif args.task == "i2i":
         input_info = I2IInputInfo(
-            seed=args.seed, prompt=args.prompt, negative_prompt=args.negative_prompt, image_path=args.image_path, save_result_path=args.save_result_path, aspect_ratio=args.aspect_ratio
+            seed=args.seed,
+            prompt=args.prompt,
+            negative_prompt=args.negative_prompt,
+            image_path=args.image_path,
+            save_result_path=args.save_result_path,
+            aspect_ratio=args.aspect_ratio,
+            return_result_tensor=args.return_result_tensor,
         )
     else:
         raise ValueError(f"Unsupported task: {args.task}")
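
Usage note (not part of the patches): PATCH 1/2 gates every postprocess call on input_info.return_result_tensor, selecting output_type="pt" instead of "pil". Assuming image_processor is a diffusers-style VaeImageProcessor (the postprocess signature in these files matches it), the gating behaves as in the sketch below; the processor construction and the dummy decoder output are illustrative only.

    # Minimal sketch of the output_type gating from PATCH 1/2. Assumes
    # diffusers' VaeImageProcessor; only the one-line conditional is taken
    # from the patch itself.
    import torch
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor(vae_scale_factor=8)   # assumed configuration
    decoded = torch.rand(1, 3, 512, 512) * 2 - 1        # stand-in for VAE output in [-1, 1]

    for return_result_tensor in (True, False):
        out = processor.postprocess(decoded, output_type="pt" if return_result_tensor else "pil")
        # "pt"  -> torch.Tensor, float in [0, 1], shape (B, C, H, W)
        # "pil" -> list of PIL.Image.Image, ready for .save()
        print(return_result_tensor, type(out))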
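Caller-side, PATCH 2/2 plumbs the flag from args into the new dataclass fields, so a t2i invocation can request tensors instead of saved files. A hypothetical construction follows; the field values are illustrative, and exactly how run_pipeline surfaces the decoded tensors is outside this diff.

    # Hypothetical caller-side sketch; runner construction is omitted and the
    # return path of run_pipeline is not shown in these patches.
    from lightx2v.utils.input_info import T2IInputInfo

    input_info = T2IInputInfo(
        seed=42,
        prompt="a lighthouse at dusk",   # illustrative values
        negative_prompt="",
        save_result_path="",             # unused when return_result_tensor is True
        return_result_tensor=True,
    )
    # runner.run_pipeline(input_info) now skips the image.save()/logger.info()
    # branch, and the VAE decoders return torch tensors rather than PIL images.

Gating the save in the runner while switching output_type in the VAE means the tensor path never materializes PIL images at all, rather than converting to PIL and back.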