diff --git a/ace/ace_batch.py b/ace/ace_batch.py
index 0293ee58..9513c9c9 100644
--- a/ace/ace_batch.py
+++ b/ace/ace_batch.py
@@ -26,7 +26,7 @@
 
 class ACEBatch:
     """
-    Batched ACE: parallel generator+reflector per mini-batch, chunked curator (cbs), then parallel post-curate.
+    Batched ACE: parallel generator+reflector per mini-batch, ComBEE-style curator reducer, then parallel post-curate.
     """
 
     def __init__(
@@ -119,12 +119,17 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]:
             Dictionary with extracted parameters
         """
         batch_size = int(config.get("batch_size", 1))
+        # Retained for compatibility with existing configs; batched curator updates
+        # now use curator_num_groups or default to floor(sqrt(n)) ComBEE grouping.
         cbs = config.get("curator_batch_size")
         if cbs is None:
             cbs = batch_size
         else:
             cbs = int(cbs)
         cbs = max(1, cbs)
+        curator_num_groups = config.get("curator_num_groups")
+        if curator_num_groups is not None:
+            curator_num_groups = max(1, int(curator_num_groups))
         use_aug = config.get("augmented_shuffling", True)
         aug_factor = int(config.get("augmented_shuffling_factor", 2))
         if not use_aug:
@@ -147,6 +152,7 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]:
             'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90),
             'batch_size': batch_size,
             'curator_batch_size': cbs,
+            'curator_num_groups': curator_num_groups,
             'augmented_shuffling_factor': aug_factor,
             'continue_on_llm_error': config.get('continue_on_llm_error', False),
         }
@@ -620,11 +626,11 @@ def _train_batch(
         step_id_prefix: str = "train"
     ) -> List[Tuple[str, str, Dict[str, Any]]]:
         """
-        Train on a batch with async parallel generator+reflector, then sync for curator.
+        Train on a batch with async parallel generator+reflector, then sync for curator reduction.
 
         Architecture:
             Phase 1 (PARALLEL): Run generator + reflector for each sample in separate threads
-            Phase 2 (SYNC): Aggregate all bullet tags and reflections, run curator once
+            Phase 2 (SYNC): Propose ADD operations per reflection group, reduce them, apply once
             Phase 3 (PARALLEL): Run post-curator generation for each sample in separate threads
 
         Args:
@@ -717,40 +723,44 @@ def _train_batch(
         print(f"Phase 1 complete: All {len(batch)} samples processed")
 
         # ================================================================
-        # PHASE 2: Aggregate bullet tags + Run Curator
+        # PHASE 2: Aggregate bullet tags + ComBEE-style Curator reducer
         # ================================================================
-        curator_batch_size = config_params.get('curator_batch_size', 10)
-
         print(f"\n{'='*40}")
-        print(f"PHASE 2: Aggregation + Curator (curator_batch_size={curator_batch_size})")
+        print("PHASE 2: Aggregation + ComBEE-style Curator Reducer")
         print(f"{'='*40}")
 
         # Aggregate bullet tags and reflections only from samples that did NOT fail with API errors.
         # API-error samples get score 0 and are excluded from playbook updates.
         all_bullet_tags = []
-        all_reflections = []
-        all_contexts = []
+        reflection_records = []
         api_error_count = 0
-        for result in sample_results:
+        for idx, result in enumerate(sample_results):
             if result.get("pre_train_answer") == INCORRECT_DUE_TO_API_ERROR:
                 api_error_count += 1
                 continue
             all_bullet_tags.extend(result["all_bullet_tags"])
-            all_reflections.append(result["reflection_content"])
-            all_contexts.append(result["context"])
+            reflection_records.append({
+                "sample_number": batch_step_start + idx,
+                "reflection": result["reflection_content"],
+                "context": result["context"],
+            })
         if api_error_count:
             print(f" Excluded {api_error_count} API-error sample(s) from Phase 2 aggregation")
 
+        original_reflection_count = len(reflection_records)
+
         # Augmented Shuffling (Hive): duplicate each reflection p times, shuffle.
         # Gives each reflection more opportunities to contribute under large batch sizes.
         augmented_factor = config_params.get('augmented_shuffling_factor', 1)
-        if augmented_factor > 1 and all_reflections:
-            pairs = list(zip(all_reflections, all_contexts))
-            augmented = [p for p in pairs for _ in range(augmented_factor)]
-            random.shuffle(augmented)
-            all_reflections, all_contexts = map(list, zip(*augmented))
+        if augmented_factor > 1 and reflection_records:
+            augmented_records = [
+                record for record in reflection_records for _ in range(augmented_factor)
+            ]
+            random.shuffle(augmented_records)
+            reflection_records = augmented_records
             print(f" [Augmented Shuffling] factor={augmented_factor} | "
-                  f"{len(pairs)} reflections -> {len(augmented)} after augmentation")
+                  f"{original_reflection_count} reflections -> "
+                  f"{len(reflection_records)} after augmentation")
 
         # Save playbook and next_global_id before Phase 2 updates (for rollback on Phase 3 API errors)
         playbook_before_phase2 = self.playbook
@@ -761,32 +771,50 @@ def _train_batch(
         self.playbook = update_bullet_counts(self.playbook, all_bullet_tags)
         print(f" Applied {len(all_bullet_tags)} bullet tag updates from {len(batch)} samples")
 
+        # All curator proposal calls below read from the same base playbook. Only the
+        # reducer output is applied, so chunk order cannot affect playbook updates.
+        base_playbook = self.playbook
         last_batch_step = batch_step_start + len(batch) - 1
 
-        def _run_one_curator_call(
-            combined_reflection: str,
-            combined_context: str,
-            last_step: int,
+        def _combine_group_records(group_records: List[Dict[str, Any]]) -> Tuple[str, str]:
+            combined_reflection = "\n\n---\n\n".join(
+                f"[Sample {record['sample_number']}] {record['reflection']}"
+                for record in group_records
+                if record["reflection"] != "(empty)"
+            )
+            if not combined_reflection:
+                combined_reflection = "(empty)"
+            combined_context = "\n\n---\n\n".join(
+                f"[Sample {record['sample_number']}] {record['context']}"
+                for record in group_records
+                if record["context"]
+            )
+            return combined_reflection, combined_context
+
+        def _run_one_curator_proposal(
+            group_records: List[Dict[str, Any]],
             call_id: str,
-            diag_chunk_size: int,
-        ) -> None:
+            group_idx: int,
+        ) -> Dict[str, Any]:
+            combined_reflection, combined_context = _combine_group_records(group_records)
             try:
                 cr_tokens = count_tokens(combined_reflection)
                 cc_tokens = count_tokens(combined_context)
-                pb_tokens = count_tokens(self.playbook)
+                pb_tokens = count_tokens(base_playbook)
                 print(
-                    f" [DIAG] curator_chunk_size={diag_chunk_size} | "
+                    f" [DIAG] curator_group={group_idx} | "
+                    f"group_size={len(group_records)} | "
                     f"reflection={cr_tokens} tok | context={cc_tokens} tok | "
                     f"playbook={pb_tokens} tok | total~{cr_tokens + cc_tokens + pb_tokens} tok"
                 )
             except Exception:
                 pass
-            stats = get_playbook_stats(self.playbook)
-            self.playbook, self.next_global_id, operations, _ = self.curator.curate(
-                current_playbook=self.playbook,
+            stats = get_playbook_stats(base_playbook)
+            operations, _ = self.curator.propose_operations(
+                current_playbook=base_playbook,
                 recent_reflection=combined_reflection,
                 question_context=combined_context,
-                current_step=last_step,
+                current_step=last_batch_step,
                 total_samples=total_samples,
                 token_budget=token_budget,
                 playbook_stats=stats,
@@ -794,55 +822,112 @@ def _run_one_curator_call(
                 use_json_mode=use_json_mode,
                 call_id=call_id,
                 log_dir=log_dir,
-                next_global_id=self.next_global_id,
             )
-            if self.use_bulletpoint_analyzer and self.bulletpoint_analyzer:
-                print(f" Running BulletpointAnalyzer (threshold={self.bulletpoint_analyzer_threshold})...")
-                self.playbook = self.bulletpoint_analyzer.analyze(
-                    playbook=self.playbook,
-                    threshold=self.bulletpoint_analyzer_threshold,
-                    merge=True,
-                )
+            return {
+                "group_id": group_idx,
+                "sample_numbers": [record["sample_number"] for record in group_records],
+                "operations": operations,
+            }
 
-        print(
-            f" Chunk by curator_batch_size={curator_batch_size} "
-            f"({len(all_reflections)} reflections)"
-        )
-        num_chunks = (len(all_reflections) + curator_batch_size - 1) // curator_batch_size
-        print(f" Running Curator {num_chunks} times (each with up to {curator_batch_size} samples)")
-        for chunk_idx in range(num_chunks):
-            start_idx = chunk_idx * curator_batch_size
-            end_idx = min(start_idx + curator_batch_size, len(all_reflections))
-            chunk_reflections = all_reflections[start_idx:end_idx]
-            chunk_contexts = all_contexts[start_idx:end_idx]
-            combined_reflection = "\n\n---\n\n".join(
-                f"[Sample {start_idx + i + 1}] {r}"
-                for i, r in enumerate(chunk_reflections)
-                if r != "(empty)"
+        operation_groups = []
+        if reflection_records:
+            configured_num_groups = config_params.get('curator_num_groups')
+            if configured_num_groups is None:
+                num_groups = max(1, int(original_reflection_count ** 0.5))
+            else:
+                num_groups = configured_num_groups
+            num_groups = max(1, min(num_groups, len(reflection_records)))
+
+            print(
+                f" ComBEE grouping: {len(reflection_records)} records into "
+                f"{num_groups} groups (original reflections={original_reflection_count})"
             )
-            if not combined_reflection:
-                combined_reflection = "(empty)"
-            combined_context = "\n\n---\n\n".join(
-                f"[Sample {start_idx + i + 1}] {c}"
-                for i, c in enumerate(chunk_contexts)
-                if c
+
+            base_group_size = len(reflection_records) // num_groups
+            remainder = len(reflection_records) % num_groups
+            offset = 0
+            group_tasks = []
+            for group_idx in range(num_groups):
+                group_size = base_group_size + (1 if group_idx < remainder else 0)
+                group_records = reflection_records[offset: offset + group_size]
+                offset += group_size
+                if not group_records:
+                    continue
+
+                group_tasks.append((group_idx + 1, group_records))
+
+            print(f" Running {len(group_tasks)} curator proposal groups in parallel")
+            with ThreadPoolExecutor(max_workers=len(group_tasks)) as executor:
+                future_to_group = {}
+                for group_id, group_records in group_tasks:
+                    print(
+                        f"\n--- Curator proposal group {group_id}/{num_groups} "
+                        f"({len(group_records)} records) ---"
+                    )
+                    future = executor.submit(
+                        _run_one_curator_proposal,
+                        group_records,
+                        f"{step_id_prefix}_s_{last_batch_step}_group_{group_id}",
+                        group_id,
+                    )
+                    future_to_group[future] = group_id
+
+                for future in as_completed(future_to_group):
+                    group_id = future_to_group[future]
+                    try:
+                        group_result = future.result()
+                    except Exception as e:
+                        print(f" ERROR in curator proposal group {group_id}: {e}")
+                        raise
+                    if group_result["operations"]:
+                        operation_groups.append(group_result)
+
+            operation_groups.sort(key=lambda group: group["group_id"])
+        else:
+            print(" No valid reflections available for curator proposals")
+
+        proposed_operation_count = sum(len(group["operations"]) for group in operation_groups)
+        final_operations = []
+        if proposed_operation_count == 0:
+            print(" No proposed ADD operations from curator groups")
+        else:
+            print(
+                f"\n--- Curator reducer ({len(operation_groups)} groups, "
+                f"{proposed_operation_count} proposed ADD operations) ---"
+            )
+            final_operations, _ = self.curator.aggregate_operations(
+                current_playbook=base_playbook,
+                operation_groups=operation_groups,
+                current_step=last_batch_step,
+                total_samples=total_samples,
+                token_budget=token_budget,
+                playbook_stats=get_playbook_stats(base_playbook),
+                use_json_mode=use_json_mode,
+                call_id=f"{step_id_prefix}_s_{last_batch_step}_reducer",
+                log_dir=log_dir,
+            )
+
+        if final_operations:
+            self.playbook, self.next_global_id = apply_curator_operations(
+                base_playbook,
+                final_operations,
+                self.next_global_id,
             )
-            last_step_in_chunk = batch_step_start + end_idx - 1
             print(
-                f"\n--- Curator chunk {chunk_idx + 1}/{num_chunks} "
-                f"(samples {start_idx + 1}-{end_idx}, step {last_step_in_chunk}) ---"
+                f"\n Playbook updated after reducer: {len(final_operations)} ADD operations | "
+                f"{count_tokens(self.playbook)} tokens"
             )
-            _run_one_curator_call(
-                combined_reflection,
-                combined_context,
-                last_step_in_chunk,
-                f"{step_id_prefix}_s_{last_step_in_chunk}_chunk_{chunk_idx + 1}",
-                len(chunk_reflections),
+        else:
+            self.playbook = base_playbook
+            print(f"\n Playbook unchanged by curator reducer: {count_tokens(self.playbook)} tokens")
+
+        if self.use_bulletpoint_analyzer and self.bulletpoint_analyzer:
+            print(f" Running BulletpointAnalyzer (threshold={self.bulletpoint_analyzer_threshold})...")
+            self.playbook = self.bulletpoint_analyzer.analyze(
+                playbook=self.playbook,
+                threshold=self.bulletpoint_analyzer_threshold,
+                merge=True,
             )
-        print(
-            f"\n Playbook updated after {num_chunks} Curator calls: "
-            f"{count_tokens(self.playbook)} tokens"
-        )
 
         # ================================================================
         # PHASE 3: Parallel Post-Curator Generation
@@ -980,7 +1065,8 @@ def _offline_train(
 
         print(f"Total epochs: {num_epochs}")
         print(f"Train samples per epoch: {len(train_samples)}")
-        print(f"Gen batch size: {batch_size} | Curator batch size: {config_params.get('curator_batch_size', 10)}")
+        curator_grouping = config_params.get('curator_num_groups') or "sqrt(n)"
+        print(f"Gen batch size: {batch_size} | Curator reducer groups: {curator_grouping}")
         print(f"Batches per epoch: {num_batches}")
         print(f"Val samples: {len(val_samples)}")
         print(f"Evaluation frequency: every {eval_steps} steps\n")
@@ -1463,4 +1549,4 @@ def _online_train_and_test(
             "accuracy": final_test_accuracy,
             "correct": correct_count_sample_based,
             "total": total_count,
-        }
\ No newline at end of file
+        }
diff --git a/ace/core/curator.py b/ace/core/curator.py
index d1e4cf7d..fcadcfbc 100644
--- a/ace/core/curator.py
+++ b/ace/core/curator.py
@@ -6,7 +6,11 @@
 import json
 from pathlib import Path
 from typing import Dict, List, Tuple, Optional, Any
-from ..prompts.curator import CURATOR_PROMPT, CURATOR_PROMPT_NO_GT
+from ..prompts.curator import (
+    CURATOR_OPERATIONS_AGGREGATION_PROMPT,
+    CURATOR_PROMPT,
+    CURATOR_PROMPT_NO_GT,
+)
 from playbook_utils import extract_json_from_text, apply_curator_operations
 from logger import log_curator_operation_diff, log_curator_failure
 from llm import timed_llm_call
@@ -161,6 +165,172 @@ def curate(
 
             print("⏭️ Skipping curator operation and continuing training")
             return current_playbook, next_global_id, [], call_info
+
+    def propose_operations(
+        self,
+        current_playbook: str,
+        recent_reflection: str,
+        question_context: str,
+        current_step: int,
+        total_samples: int,
+        token_budget: int,
+        playbook_stats: Dict[str, Any],
+        use_ground_truth: bool = True,
+        use_json_mode: bool = False,
+        call_id: str = "curator_propose",
+        log_dir: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        """
+        Generate curator operations without applying them to the playbook.
+
+        This is used by batched ComBEE-style aggregation: each reflection group
+        proposes local ADD operations from the same playbook snapshot, and a
+        reducer later deduplicates those proposals before they are applied once.
+        """
+        stats_str = json.dumps(playbook_stats, indent=2)
+
+        # Both templates take the same fields; only the template itself differs.
+        prompt_template = CURATOR_PROMPT if use_ground_truth else CURATOR_PROMPT_NO_GT
+        prompt = prompt_template.format(
+            current_step=current_step,
+            total_samples=total_samples,
+            token_budget=token_budget,
+            playbook_stats=stats_str,
+            recent_reflection=recent_reflection,
+            current_playbook=current_playbook,
+            question_context=question_context,
+        )
+
+        response, call_info = timed_llm_call(
+            self.api_client,
+            self.api_provider,
+            self.model,
+            prompt,
+            role="curator",
+            call_id=call_id,
+            max_tokens=self.max_tokens,
+            log_dir=log_dir,
+            use_json_mode=use_json_mode,
+        )
+
+        if response.startswith("INCORRECT_DUE_TO_EMPTY_RESPONSE"):
+            print("⏭️ Skipping curator proposal due to empty response")
+            if log_dir:
+                log_curator_failure(log_dir, current_step, "empty_response", response[:200], 0)
+            return [], call_info
+
+        try:
+            operations_info = self._extract_and_validate_operations(response)
+            operations = self._add_operations_only(operations_info["operations"], call_id)
+            print(f"✅ Curator proposal JSON schema validated successfully: {len(operations)} ADD operations")
+            return operations, call_info
+
+        except (ValueError, KeyError, TypeError, json.JSONDecodeError) as e:
+            print(f"❌ Curator proposal JSON parsing failed: {e}")
+            print(f"📄 Raw curator proposal preview: {response[:300]}...")
+            if log_dir:
+                log_curator_failure(log_dir, current_step, "json_parse_error", response, 0, str(e))
+            print("⏭️ Skipping curator proposal due to invalid JSON format")
+            return [], call_info
+
+        except Exception as e:
+            print(f"❌ Curator proposal failed: {e}")
+            print(f"📄 Raw curator proposal preview: {response[:300]}...")
+            if log_dir:
+                log_curator_failure(log_dir, current_step, "operation_error", response, 0, str(e))
+            print("⏭️ Skipping curator proposal and continuing training")
+            return [], call_info
+
+    def aggregate_operations(
+        self,
+        current_playbook: str,
+        operation_groups: List[Dict[str, Any]],
+        current_step: int,
+        total_samples: int,
+        token_budget: int,
+        playbook_stats: Dict[str, Any],
+        use_json_mode: bool = False,
+        call_id: str = "curator_reduce",
+        log_dir: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        """
+        Reduce independently proposed ADD operations into one final operation set.
+        """
+        proposed_operations = json.dumps(operation_groups, indent=2, ensure_ascii=False)
+        stats_str = json.dumps(playbook_stats, indent=2)
+        prompt = CURATOR_OPERATIONS_AGGREGATION_PROMPT.format(
+            current_step=current_step,
+            total_samples=total_samples,
+            token_budget=token_budget,
+            playbook_stats=stats_str,
+            current_playbook=current_playbook,
+            proposed_operations=proposed_operations,
+        )
+
+        response, call_info = timed_llm_call(
+            self.api_client,
+            self.api_provider,
+            self.model,
+            prompt,
+            role="curator",
+            call_id=call_id,
+            max_tokens=self.max_tokens,
+            log_dir=log_dir,
+            use_json_mode=use_json_mode,
+        )
+
+        if response.startswith("INCORRECT_DUE_TO_EMPTY_RESPONSE"):
+            print("⏭️ Skipping curator reducer due to empty response")
+            if log_dir:
+                log_curator_failure(log_dir, current_step, "empty_response", response[:200], 0)
+            return [], call_info
+
+        try:
+            operations_info = self._extract_and_validate_operations(response)
+            operations = self._add_operations_only(operations_info["operations"], call_id)
+            print(f"✅ Curator reducer JSON schema validated successfully: {len(operations)} final ADD operations")
+
+            for op in operations:
+                try:
+                    diff_log_dir = Path(log_dir).parent if log_dir else None
+                    log_curator_operation_diff(diff_log_dir, op, current_playbook, call_id)
+                except Exception as e:
+                    print(f"Warning: Failed to log reducer operation diff: {e}")
+
+            return operations, call_info
+
+        except (ValueError, KeyError, TypeError, json.JSONDecodeError) as e:
+            print(f"❌ Curator reducer JSON parsing failed: {e}")
+            print(f"📄 Raw curator reducer preview: {response[:300]}...")
+            if log_dir:
+                log_curator_failure(log_dir, current_step, "json_parse_error", response, 0, str(e))
+            print("⏭️ Skipping curator reducer due to invalid JSON format")
+            return [], call_info
+
+        except Exception as e:
+            print(f"❌ Curator reducer failed: {e}")
+            print(f"📄 Raw curator reducer preview: {response[:300]}...")
+            if log_dir:
+                log_curator_failure(log_dir, current_step, "operation_error", response, 0, str(e))
+            print("⏭️ Skipping curator reducer and continuing training")
+            return [], call_info
+
+    def _add_operations_only(
+        self,
+        operations: List[Dict[str, Any]],
+        source: str,
+    ) -> List[Dict[str, Any]]:
+        """ACE currently applies only ADD operations; ignore unsupported proposal types."""
+        add_operations = []
+        for op in operations:
+            if op.get("type") == "ADD":
+                add_operations.append(op)
+            else:
+                print(
+                    f"Warning: Ignoring unsupported curator operation "
+                    f"'{op.get('type', 'UNKNOWN')}' from {source}"
+                )
+        return add_operations
 
     def _extract_and_validate_operations(
         self,
@@ -221,4 +391,4 @@ def _extract_and_validate_operations(
             if missing_fields:
                 raise ValueError(f"ADD operation {i} missing fields: {list(missing_fields)}")
 
-        return operations_info
\ No newline at end of file
+        return operations_info
diff --git a/ace/prompts/curator.py b/ace/prompts/curator.py
index dadb4543..11c53a37 100644
--- a/ace/prompts/curator.py
+++ b/ace/prompts/curator.py
@@ -127,4 +127,59 @@
 }}
 
 ---
-"""
\ No newline at end of file
+"""
+
+
+CURATOR_OPERATIONS_AGGREGATION_PROMPT = """You are aggregating playbook update proposals generated independently from different subsets of a training batch.
+
+**CRITICAL: You MUST respond with valid JSON only. Do not use markdown formatting or code blocks.**
+
+**Context:**
+- Each proposal below was produced by a Curator after reviewing a different reflection group.
+- The current ACE implementation only supports ADD operations. Do not output UPDATE, MERGE, DELETE, or CREATE_META operations.
+- Bullet IDs are assigned by the system. Do not include bullet IDs in operation content.
+
+**Instructions:**
+- Review the current playbook and all proposed ADD operations.
+- Synthesize overlapping proposals into a single more specific ADD operation.
+- Drop proposals that are already covered by the current playbook.
+- Drop vague, redundant, or low-value proposals.
+- Preserve complementary insights as separate ADD operations when they teach genuinely different behavior.
+- Keep each operation concise, actionable, and suitable for future tasks.
+- If no new content should be added, return an empty operations list.
+
+**Training Context:**
+- Total token budget: {token_budget} tokens
+- Training progress: Sample {current_step} out of {total_samples}
+
+**Current Playbook Stats:**
+{playbook_stats}
+
+**Current Playbook:**
+{current_playbook}
+
+**Independent Proposed Operations:**
+{proposed_operations}
+
+**Your Task:**
+Output ONLY a valid JSON object with these exact fields:
+- reasoning: a concise rationale for the aggregation decisions
+- operations: a list of final ADD operations to apply to the playbook
+  - type: must be "ADD"
+  - section: the section to add the bullet to
+  - content: the final bullet content; do not include a bullet ID
+
+**RESPONSE FORMAT - Output ONLY this JSON structure (no markdown, no code blocks):**
+{{
+  "reasoning": "[Concise rationale here]",
+  "operations": [
+    {{
+      "type": "ADD",
+      "section": "formulas_and_calculations",
+      "content": "[Deduplicated, actionable insight...]"
+    }}
+  ]
+}}
+
+---
+"""
diff --git a/eval/finance/run.py b/eval/finance/run.py
index 5c94df7d..0aa808b3 100644
--- a/eval/finance/run.py
+++ b/eval/finance/run.py
@@ -57,7 +57,9 @@ def parse_args():
     parser.add_argument("--batch_size", type=int, default=1,
                         help="Generator mini-batch size; >1 uses batched ACE (parallel phase1/3)")
     parser.add_argument("--curator_batch_size", type=int, default=None,
-                        help="Curator chunk size (default: same as --batch_size)")
+                        help="Legacy option retained for compatibility; batched ACE now uses curator_num_groups")
+    parser.add_argument("--curator_num_groups", type=int, default=None,
+                        help="Number of ComBEE curator proposal groups (default: floor(sqrt(batch reflections)))")
     parser.add_argument(
         "--augmented_shuffling",
         action=argparse.BooleanOptionalAction,
@@ -237,6 +239,7 @@ def main():
         'api_provider': args.api_provider,
         'batch_size': args.batch_size,
         'curator_batch_size': args.curator_batch_size,
+        'curator_num_groups': args.curator_num_groups,
         'augmented_shuffling': args.augmented_shuffling,
     }
 
diff --git a/eval/mind2web/run.py b/eval/mind2web/run.py
index fcb2f180..75a60559 100644
--- a/eval/mind2web/run.py
+++ b/eval/mind2web/run.py
@@ -50,7 +50,9 @@ def parse_args():
     parser.add_argument("--batch_size", type=int, default=1,
                         help="Generator mini-batch size; >1 uses batched ACE")
     parser.add_argument("--curator_batch_size", type=int, default=None,
-                        help="Curator chunk size (default: same as --batch_size)")
+                        help="Legacy option retained for compatibility; batched ACE now uses curator_num_groups")
+    parser.add_argument("--curator_num_groups", type=int, default=None,
+                        help="Number of ComBEE curator proposal groups (default: floor(sqrt(batch reflections)))")
     parser.add_argument(
         "--augmented_shuffling",
         action=argparse.BooleanOptionalAction,
@@ -212,6 +214,7 @@ def main():
         'api_provider': args.api_provider,
         'batch_size': args.batch_size,
         'curator_batch_size': args.curator_batch_size,
+        'curator_num_groups': args.curator_num_groups,
         'augmented_shuffling': args.augmented_shuffling,
     }
 
diff --git a/eval/mind2web2/run.py b/eval/mind2web2/run.py
index 73d03f0c..ba563743 100644
--- a/eval/mind2web2/run.py
+++ b/eval/mind2web2/run.py
@@ -50,7 +50,9 @@ def parse_args():
     parser.add_argument("--batch_size", type=int, default=1,
                         help="Generator mini-batch size; >1 uses batched ACE")
     parser.add_argument("--curator_batch_size", type=int, default=None,
-                        help="Curator chunk size (default: same as --batch_size)")
+                        help="Legacy option retained for compatibility; batched ACE now uses curator_num_groups")
+    parser.add_argument("--curator_num_groups", type=int, default=None,
+                        help="Number of ComBEE curator proposal groups (default: floor(sqrt(batch reflections)))")
     parser.add_argument(
         "--augmented_shuffling",
         action=argparse.BooleanOptionalAction,
@@ -212,6 +214,7 @@ def main():
         'api_provider': args.api_provider,
         'batch_size': args.batch_size,
         'curator_batch_size': args.curator_batch_size,
+        'curator_num_groups': args.curator_num_groups,
         'augmented_shuffling': args.augmented_shuffling,
     }
 