@@ -1125,6 +1125,99 @@ def _group_process(cls, texts):
11251125# ========== Training Dataset Nodes ==========
11261126
11271127
class ResolutionBucket(io.ComfyNode):
    """Bucket latents and conditions by resolution for efficient batch training."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ResolutionBucket",
            display_name="Resolution Bucket",
            category="dataset",
            is_experimental=True,
            is_input_list=True,
            inputs=[
                io.Latent.Input(
                    "latents",
                    tooltip="List of latent dicts to bucket by resolution.",
                ),
                io.Conditioning.Input(
                    "conditioning",
                    tooltip="List of conditioning lists (must match latents length).",
                ),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="List of batched latent dicts, one per resolution bucket.",
                ),
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="List of condition lists, one per resolution bucket.",
                ),
            ],
        )

    @classmethod
    def execute(cls, latents, conditioning):
        """Group per-sample latents by spatial size and stack each group into a batch.

        Args:
            latents: list of ``{"samples": tensor}`` dicts; each tensor is
                ``(B, ..., H, W)``, typically ``B == 1``.
                NOTE(review): the last two dims are assumed spatial (H, W) —
                confirm this holds for video latents ``(B, T, C, H, W)``.
            conditioning: list of condition lists, one per latent dict; each
                inner list must supply one condition per latent batch item.

        Returns:
            io.NodeOutput of (latents_per_bucket, conditions_per_bucket):
            one batched latent dict and one condition list per distinct
            (H, W), in first-seen order.

        Raises:
            ValueError: if the outer list lengths differ, or a conditioning
                list has fewer entries than its latent batch has samples.
        """
        if len(latents) != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
            )

        # (H, W) -> {"latents": [per-sample tensors], "conditions": [conds]}.
        # A plain dict preserves insertion order, so bucket output order is
        # the order each resolution was first encountered.
        buckets = {}
        total_samples = 0

        # Single pass: split each batch into individual samples and bucket
        # them immediately (the original flattened into temporary lists and
        # then re-iterated; one pass avoids the intermediate copies).
        for latent_dict, cond in zip(latents, conditioning):
            samples = latent_dict["samples"]  # (B, ..., H, W)
            batch_size = samples.shape[0]
            if len(cond) < batch_size:
                # Fail loudly instead of an opaque IndexError at cond[i].
                raise ValueError(
                    f"Conditioning list has {len(cond)} entries but the latent batch has {batch_size} samples."
                )
            for i in range(batch_size):
                sample = samples[i]  # (..., H, W)
                key = (sample.shape[-2], sample.shape[-1])
                bucket = buckets.setdefault(key, {"latents": [], "conditions": []})
                bucket["latents"].append(sample)
                bucket["conditions"].append(cond[i])
                total_samples += 1

        output_latents = []  # list[{"samples": (Bi, ..., H, W) tensor}]
        output_conditions = []  # list[list[cond]], parallel to output_latents

        for (h, w), bucket_data in buckets.items():
            # Re-stack the per-sample tensors of this bucket into one batch.
            stacked = torch.stack(bucket_data["latents"], dim=0)
            output_latents.append({"samples": stacked})
            output_conditions.append(bucket_data["conditions"])
            # Lazy %-style args: no string formatting unless INFO is enabled.
            logging.info(
                "Resolution bucket (%sx%s): %s samples", h, w, len(bucket_data["latents"])
            )

        logging.info("Created %s resolution buckets from %s samples", len(buckets), total_samples)
        return io.NodeOutput(output_latents, output_conditions)
1219+
1220+
11281221class MakeTrainingDataset (io .ComfyNode ):
11291222 """Encode images with VAE and texts with CLIP to create a training dataset."""
11301223
@@ -1373,7 +1466,7 @@ def execute(cls, folder_name):
13731466 shard_path = os .path .join (dataset_dir , shard_file )
13741467
13751468 with open (shard_path , "rb" ) as f :
1376- shard_data = torch .load (f , weights_only = True )
1469+ shard_data = torch .load (f )
13771470
13781471 all_latents .extend (shard_data ["latents" ])
13791472 all_conditioning .extend (shard_data ["conditioning" ])
@@ -1425,6 +1518,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
14251518 MakeTrainingDataset ,
14261519 SaveTrainingDataset ,
14271520 LoadTrainingDataset ,
1521+ ResolutionBucket ,
14281522 ]
14291523
14301524
0 commit comments