Skip to content

Commit 86dbb89

Browse files
Resolution bucketing and Trainer implementation refactoring (#11117)
1 parent ba6080b commit 86dbb89

File tree

4 files changed

+743
-237
lines changed

4 files changed

+743
-237
lines changed

comfy/sampler_helpers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,21 @@ def estimate_memory(model, noise_shape, conds):
122122
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
123123
return memory_required, minimum_memory_required
124124

125-
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
    """Prepare a model (and its additional models) for sampling.

    Runs _prepare_sampling through any PREPARE_SAMPLING wrappers registered in
    model_options so extensions can hook the preparation step.

    Args:
        model: the ModelPatcher to prepare.
        noise_shape: shape of the noise tensor the sampler will use; only the
            shape is needed for the memory estimate.
        conds: conditioning dict/lists used both for memory estimation and for
            collecting additional models.
        model_options: optional model options dict (also searched for wrappers
            and additional models).
        skip_load_model: when True, the main model is NOT passed to
            load_models_gpu — only the additional models are loaded. Defaults
            to False, preserving the previous behavior for existing callers.

    Returns:
        (real_model, conds, models) as produced by _prepare_sampling.
    """
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(
        _prepare_sampling,
        comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
    )
    return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model)


def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
    """Collect additional models, estimate memory, and load everything to GPU.

    See prepare_sampling for parameter semantics; this is the unwrapped
    implementation that the wrapper executor ultimately calls.
    """
    real_model: BaseModel = None
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    models += get_additional_models_from_model_options(model_options)
    models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
    # With skip_load_model the caller manages the main model's device placement
    # itself (e.g. a trainer that already holds it on GPU); only the additional
    # models go through load_models_gpu in that case.
    models_list = [model] if not skip_load_model else []
    comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
    real_model = model.model

    return real_model, conds, models

comfy_extras/nodes_dataset.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,99 @@ def _group_process(cls, texts):
11251125
# ========== Training Dataset Nodes ==========
11261126

11271127

1128+
class ResolutionBucket(io.ComfyNode):
    """Bucket latents and conditions by resolution for efficient batch training.

    Incoming per-sample latents are flattened, grouped by their spatial size
    (H, W), and each group is stacked into one batched latent so a trainer can
    process same-resolution samples together.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ResolutionBucket",
            display_name="Resolution Bucket",
            category="dataset",
            is_experimental=True,
            is_input_list=True,
            inputs=[
                io.Latent.Input(
                    "latents",
                    tooltip="List of latent dicts to bucket by resolution.",
                ),
                io.Conditioning.Input(
                    "conditioning",
                    tooltip="List of conditioning lists (must match latents length).",
                ),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="List of batched latent dicts, one per resolution bucket.",
                ),
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="List of condition lists, one per resolution bucket.",
                ),
            ],
        )

    @classmethod
    def execute(cls, latents, conditioning):
        # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
        # conditioning: list[list[cond]]

        # Validate lengths match
        if len(latents) != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
            )

        # Flatten latents and conditions to individual samples
        flat_latents = []  # list of (C, H, W) tensors
        flat_conditions = []  # list of condition lists

        for latent_dict, cond in zip(latents, conditioning):
            samples = latent_dict["samples"]  # (B, C, H, W)
            batch_size = samples.shape[0]

            # Fail fast with a clear message instead of a bare IndexError when a
            # conditioning list does not cover the latent batch it belongs to.
            if len(cond) < batch_size:
                raise ValueError(
                    f"Conditioning list has {len(cond)} entries but the matching latent batch has {batch_size} samples."
                )

            # cond is a list of conditions with length == batch_size
            for i in range(batch_size):
                flat_latents.append(samples[i])  # (C, H, W)
                flat_conditions.append(cond[i])  # single condition

        # Group by resolution (H, W); latent shape is (..., H, W),
        # e.g. (C, H, W) for images or (T, C, H, W) for video latents.
        buckets = {}  # (H, W) -> {"latents": list, "conditions": list}

        for latent, cond in zip(flat_latents, flat_conditions):
            key = (latent.shape[-2], latent.shape[-1])
            bucket = buckets.setdefault(key, {"latents": [], "conditions": []})
            bucket["latents"].append(latent)
            bucket["conditions"].append(cond)

        # Convert buckets to output format
        output_latents = []  # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
        output_conditions = []  # list[list[cond]] where each inner list has Bi conditions

        for (h, w), bucket_data in buckets.items():
            # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W).
            # NOTE(review): stacking assumes every latent in a (H, W) bucket also
            # shares its leading dims (C / T) — holds when all samples come from
            # the same VAE; verify if mixed latent types are ever fed in.
            stacked_latents = torch.stack(bucket_data["latents"], dim=0)
            output_latents.append({"samples": stacked_latents})

            # Conditions stay as list of condition lists
            output_conditions.append(bucket_data["conditions"])

            logging.info(
                f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
            )

        logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
        return io.NodeOutput(output_latents, output_conditions)
1219+
1220+
11281221
class MakeTrainingDataset(io.ComfyNode):
11291222
"""Encode images with VAE and texts with CLIP to create a training dataset."""
11301223

@@ -1373,7 +1466,7 @@ def execute(cls, folder_name):
13731466
shard_path = os.path.join(dataset_dir, shard_file)
13741467

13751468
with open(shard_path, "rb") as f:
1376-
shard_data = torch.load(f, weights_only=True)
1469+
shard_data = torch.load(f)
13771470

13781471
all_latents.extend(shard_data["latents"])
13791472
all_conditioning.extend(shard_data["conditioning"])
@@ -1425,6 +1518,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
14251518
MakeTrainingDataset,
14261519
SaveTrainingDataset,
14271520
LoadTrainingDataset,
1521+
ResolutionBucket,
14281522
]
14291523

14301524

comfy_extras/nodes_post_processing.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,22 +221,23 @@ def define_schema(cls):
221221
io.Image.Input("image"),
222222
io.Combo.Input("upscale_method", options=cls.upscale_methods),
223223
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
224+
io.Int.Input("resolution_steps", default=1, min=1, max=256),
224225
],
225226
outputs=[
226227
io.Image.Output(),
227228
],
228229
)
229230

230231
@classmethod
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
    """Scale an image so its total pixel count is ~megapixels, snapping both
    dimensions to multiples of resolution_steps.

    Args:
        image: image tensor in (B, H, W, C) layout (moved to (B, C, H, W) for
            the upscale helper).
        upscale_method: interpolation method passed to common_upscale.
        megapixels: target size in megapixels; kept as a float so sub-0.5 MP
            targets are not truncated before the sqrt below.
        resolution_steps: both output dimensions are rounded to the nearest
            multiple of this value (1 preserves the old exact-round behavior).

    Returns:
        io.NodeOutput with the rescaled image in (B, H, W, C) layout.
    """
    samples = image.movedim(-1, 1)
    total = megapixels * 1024 * 1024

    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
    # Snap to the nearest multiple of resolution_steps, but never below one
    # step: for small targets round(dim * scale / steps) can reach 0, which
    # would request a zero-size upscale and crash.
    width = max(resolution_steps, round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps)
    height = max(resolution_steps, round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps)

    s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
    s = s.movedim(1, -1)
    return io.NodeOutput(s)
242243

0 commit comments

Comments
 (0)