Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions comfy/sampler_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,20 +122,21 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required

def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model)

def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
models_list = [model] if not skip_load_model else []
comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
real_model = model.model

return real_model, conds, models
Expand Down
96 changes: 95 additions & 1 deletion comfy_extras/nodes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,99 @@ def _group_process(cls, texts):
# ========== Training Dataset Nodes ==========


class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training."""

@classmethod
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
display_name="Resolution Bucket",
category="dataset",
is_experimental=True,
is_input_list=True,
inputs=[
io.Latent.Input(
"latents",
tooltip="List of latent dicts to bucket by resolution.",
),
io.Conditioning.Input(
"conditioning",
tooltip="List of conditioning lists (must match latents length).",
),
],
outputs=[
io.Latent.Output(
display_name="latents",
is_output_list=True,
tooltip="List of batched latent dicts, one per resolution bucket.",
),
io.Conditioning.Output(
display_name="conditioning",
is_output_list=True,
tooltip="List of condition lists, one per resolution bucket.",
),
],
)

@classmethod
def execute(cls, latents, conditioning):
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
# conditioning: list[list[cond]]

# Validate lengths match
if len(latents) != len(conditioning):
raise ValueError(
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
)

# Flatten latents and conditions to individual samples
flat_latents = [] # list of (C, H, W) tensors
flat_conditions = [] # list of condition lists

for latent_dict, cond in zip(latents, conditioning):
samples = latent_dict["samples"] # (B, C, H, W)
batch_size = samples.shape[0]

# cond is a list of conditions with length == batch_size
for i in range(batch_size):
flat_latents.append(samples[i]) # (C, H, W)
flat_conditions.append(cond[i]) # single condition

# Group by resolution (H, W)
buckets = {} # (H, W) -> {"latents": list, "conditions": list}

for latent, cond in zip(flat_latents, flat_conditions):
# latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
h, w = latent.shape[-2], latent.shape[-1]
key = (h, w)

if key not in buckets:
buckets[key] = {"latents": [], "conditions": []}

buckets[key]["latents"].append(latent)
buckets[key]["conditions"].append(cond)

# Convert buckets to output format
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions

for (h, w), bucket_data in buckets.items():
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
output_latents.append({"samples": stacked_latents})

# Conditions stay as list of condition lists
output_conditions.append(bucket_data["conditions"])

logging.info(
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
)

logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
return io.NodeOutput(output_latents, output_conditions)


class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""

Expand Down Expand Up @@ -1373,7 +1466,7 @@ def execute(cls, folder_name):
shard_path = os.path.join(dataset_dir, shard_file)

with open(shard_path, "rb") as f:
shard_data = torch.load(f, weights_only=True)
shard_data = torch.load(f)

all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])
Expand Down Expand Up @@ -1425,6 +1518,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
MakeTrainingDataset,
SaveTrainingDataset,
LoadTrainingDataset,
ResolutionBucket,
]


Expand Down
11 changes: 6 additions & 5 deletions comfy_extras/nodes_post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,23 @@ def define_schema(cls):
io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
io.Int.Input("resolution_steps", default=1, min=1, max=256),
],
outputs=[
io.Image.Output(),
],
)

@classmethod
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
samples = image.movedim(-1,1)
total = int(megapixels * 1024 * 1024)
total = megapixels * 1024 * 1024

scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps

s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
s = s.movedim(1,-1)
return io.NodeOutput(s)

Expand Down
Loading