Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 160 additions & 74 deletions ace/ace_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

class ACEBatch:
"""
Batched ACE: parallel generator+reflector per mini-batch, chunked curator (cbs), then parallel post-curate.
Batched ACE: parallel generator+reflector per mini-batch, ComBEE-style curator reducer, then parallel post-curate.
"""

def __init__(
Expand Down Expand Up @@ -119,12 +119,17 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]:
Dictionary with extracted parameters
"""
batch_size = int(config.get("batch_size", 1))
# Retained for compatibility with existing configs; batched curator updates
# now use curator_num_groups or default to floor(sqrt(n)) ComBEE grouping.
cbs = config.get("curator_batch_size")
if cbs is None:
cbs = batch_size
else:
cbs = int(cbs)
cbs = max(1, cbs)
curator_num_groups = config.get("curator_num_groups")
if curator_num_groups is not None:
curator_num_groups = max(1, int(curator_num_groups))
use_aug = config.get("augmented_shuffling", True)
aug_factor = int(config.get("augmented_shuffling_factor", 2))
if not use_aug:
Expand All @@ -147,6 +152,7 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]:
'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90),
'batch_size': batch_size,
'curator_batch_size': cbs,
'curator_num_groups': curator_num_groups,
'augmented_shuffling_factor': aug_factor,
'continue_on_llm_error': config.get('continue_on_llm_error', False),
}
Expand Down Expand Up @@ -620,11 +626,11 @@ def _train_batch(
step_id_prefix: str = "train"
) -> List[Tuple[str, str, Dict[str, Any]]]:
"""
Train on a batch with async parallel generator+reflector, then sync for curator.
Train on a batch with async parallel generator+reflector, then sync for curator reduction.

Architecture:
Phase 1 (PARALLEL): Run generator + reflector for each sample in separate threads
Phase 2 (SYNC): Aggregate all bullet tags and reflections, run curator once
Phase 2 (SYNC): Propose ADD operations per reflection group, reduce them, apply once
Phase 3 (PARALLEL): Run post-curator generation for each sample in separate threads

Args:
Expand Down Expand Up @@ -717,40 +723,44 @@ def _train_batch(
print(f"Phase 1 complete: All {len(batch)} samples processed")

# ================================================================
# PHASE 2: Aggregate bullet tags + Run Curator
# PHASE 2: Aggregate bullet tags + ComBEE-style Curator reducer
# ================================================================
curator_batch_size = config_params.get('curator_batch_size', 10)

print(f"\n{'='*40}")
print(f"PHASE 2: Aggregation + Curator (curator_batch_size={curator_batch_size})")
print("PHASE 2: Aggregation + ComBEE-style Curator Reducer")
print(f"{'='*40}")

# Aggregate bullet tags and reflections only from samples that did NOT fail with API errors.
# API-error samples get score 0 and are excluded from playbook updates.
all_bullet_tags = []
all_reflections = []
all_contexts = []
reflection_records = []
api_error_count = 0
for result in sample_results:
for idx, result in enumerate(sample_results):
if result.get("pre_train_answer") == INCORRECT_DUE_TO_API_ERROR:
api_error_count += 1
continue
all_bullet_tags.extend(result["all_bullet_tags"])
all_reflections.append(result["reflection_content"])
all_contexts.append(result["context"])
reflection_records.append({
"sample_number": batch_step_start + idx,
"reflection": result["reflection_content"],
"context": result["context"],
})
if api_error_count:
print(f" Excluded {api_error_count} API-error sample(s) from Phase 2 aggregation")

original_reflection_count = len(reflection_records)

# Augmented Shuffling (Hive): duplicate each reflection p times, shuffle.
# Gives each reflection more opportunities to contribute under large batch sizes.
augmented_factor = config_params.get('augmented_shuffling_factor', 1)
if augmented_factor > 1 and all_reflections:
pairs = list(zip(all_reflections, all_contexts))
augmented = [p for p in pairs for _ in range(augmented_factor)]
random.shuffle(augmented)
all_reflections, all_contexts = map(list, zip(*augmented))
if augmented_factor > 1 and reflection_records:
augmented_records = [
record for record in reflection_records for _ in range(augmented_factor)
]
random.shuffle(augmented_records)
reflection_records = augmented_records
print(f" [Augmented Shuffling] factor={augmented_factor} | "
f"{len(pairs)} reflections -> {len(augmented)} after augmentation")
f"{original_reflection_count} reflections -> "
f"{len(reflection_records)} after augmentation")

# Save playbook and next_global_id before Phase 2 updates (for rollback on Phase 3 API errors)
playbook_before_phase2 = self.playbook
Expand All @@ -761,88 +771,163 @@ def _train_batch(
self.playbook = update_bullet_counts(self.playbook, all_bullet_tags)
print(f" Applied {len(all_bullet_tags)} bullet tag updates from {len(batch)} samples")

# All curator proposal calls below read from the same base playbook. Only the
# reducer output is applied, so group completion order cannot affect playbook updates.
base_playbook = self.playbook
last_batch_step = batch_step_start + len(batch) - 1

def _run_one_curator_call(
combined_reflection: str,
combined_context: str,
last_step: int,
def _combine_group_records(group_records: List[Dict[str, Any]]) -> Tuple[str, str]:
combined_reflection = "\n\n---\n\n".join(
f"[Sample {record['sample_number']}] {record['reflection']}"
for record in group_records
if record["reflection"] != "(empty)"
)
if not combined_reflection:
combined_reflection = "(empty)"
combined_context = "\n\n---\n\n".join(
f"[Sample {record['sample_number']}] {record['context']}"
for record in group_records
if record["context"]
)
return combined_reflection, combined_context

def _run_one_curator_proposal(
group_records: List[Dict[str, Any]],
call_id: str,
diag_chunk_size: int,
) -> None:
group_idx: int,
) -> Dict[str, Any]:
combined_reflection, combined_context = _combine_group_records(group_records)
try:
cr_tokens = count_tokens(combined_reflection)
cc_tokens = count_tokens(combined_context)
pb_tokens = count_tokens(self.playbook)
pb_tokens = count_tokens(base_playbook)
print(
f" [DIAG] curator_chunk_size={diag_chunk_size} | "
f" [DIAG] curator_group={group_idx} | "
f"group_size={len(group_records)} | "
f"reflection={cr_tokens} tok | context={cc_tokens} tok | "
f"playbook={pb_tokens} tok | total~{cr_tokens + cc_tokens + pb_tokens} tok"
)
except Exception:
pass
stats = get_playbook_stats(self.playbook)
self.playbook, self.next_global_id, operations, _ = self.curator.curate(
current_playbook=self.playbook,
stats = get_playbook_stats(base_playbook)
operations, _ = self.curator.propose_operations(
current_playbook=base_playbook,
recent_reflection=combined_reflection,
question_context=combined_context,
current_step=last_step,
current_step=last_batch_step,
total_samples=total_samples,
token_budget=token_budget,
playbook_stats=stats,
use_ground_truth=not no_ground_truth,
use_json_mode=use_json_mode,
call_id=call_id,
log_dir=log_dir,
next_global_id=self.next_global_id,
)
if self.use_bulletpoint_analyzer and self.bulletpoint_analyzer:
print(f" Running BulletpointAnalyzer (threshold={self.bulletpoint_analyzer_threshold})...")
self.playbook = self.bulletpoint_analyzer.analyze(
playbook=self.playbook,
threshold=self.bulletpoint_analyzer_threshold,
merge=True,
)
return {
"group_id": group_idx,
"sample_numbers": [record["sample_number"] for record in group_records],
"operations": operations,
}

print(
f" Chunk by curator_batch_size={curator_batch_size} "
f"({len(all_reflections)} reflections)"
)
num_chunks = (len(all_reflections) + curator_batch_size - 1) // curator_batch_size
print(f" Running Curator {num_chunks} times (each with up to {curator_batch_size} samples)")
for chunk_idx in range(num_chunks):
start_idx = chunk_idx * curator_batch_size
end_idx = min(start_idx + curator_batch_size, len(all_reflections))
chunk_reflections = all_reflections[start_idx:end_idx]
chunk_contexts = all_contexts[start_idx:end_idx]
combined_reflection = "\n\n---\n\n".join(
f"[Sample {start_idx + i + 1}] {r}"
for i, r in enumerate(chunk_reflections)
if r != "(empty)"
operation_groups = []
if reflection_records:
configured_num_groups = config_params.get('curator_num_groups')
if configured_num_groups is None:
num_groups = max(1, int(original_reflection_count ** 0.5))
else:
num_groups = configured_num_groups
num_groups = max(1, min(num_groups, len(reflection_records)))

print(
f" ComBEE grouping: {len(reflection_records)} records into "
f"{num_groups} groups (original reflections={original_reflection_count})"
)
if not combined_reflection:
combined_reflection = "(empty)"
combined_context = "\n\n---\n\n".join(
f"[Sample {start_idx + i + 1}] {c}"
for i, c in enumerate(chunk_contexts)
if c

base_group_size = len(reflection_records) // num_groups
remainder = len(reflection_records) % num_groups
offset = 0
group_tasks = []
for group_idx in range(num_groups):
group_size = base_group_size + (1 if group_idx < remainder else 0)
group_records = reflection_records[offset: offset + group_size]
offset += group_size
if not group_records:
continue

group_tasks.append((group_idx + 1, group_records))

print(f" Running {len(group_tasks)} curator proposal groups in parallel")
with ThreadPoolExecutor(max_workers=len(group_tasks)) as executor:
future_to_group = {}
for group_id, group_records in group_tasks:
print(
f"\n--- Curator proposal group {group_id}/{num_groups} "
f"({len(group_records)} records) ---"
)
future = executor.submit(
_run_one_curator_proposal,
group_records,
f"{step_id_prefix}_s_{last_batch_step}_group_{group_id}",
group_id,
)
future_to_group[future] = group_id

for future in as_completed(future_to_group):
group_id = future_to_group[future]
try:
group_result = future.result()
except Exception as e:
print(f" ERROR in curator proposal group {group_id}: {e}")
raise
if group_result["operations"]:
operation_groups.append(group_result)

operation_groups.sort(key=lambda group: group["group_id"])
else:
print(" No valid reflections available for curator proposals")

proposed_operation_count = sum(len(group["operations"]) for group in operation_groups)
final_operations = []
if proposed_operation_count == 0:
print(" No proposed ADD operations from curator groups")
else:
print(
f"\n--- Curator reducer ({len(operation_groups)} groups, "
f"{proposed_operation_count} proposed ADD operations) ---"
)
final_operations, _ = self.curator.aggregate_operations(
current_playbook=base_playbook,
operation_groups=operation_groups,
current_step=last_batch_step,
total_samples=total_samples,
token_budget=token_budget,
playbook_stats=get_playbook_stats(base_playbook),
use_json_mode=use_json_mode,
call_id=f"{step_id_prefix}_s_{last_batch_step}_reducer",
log_dir=log_dir,
)

if final_operations:
self.playbook, self.next_global_id = apply_curator_operations(
base_playbook,
final_operations,
self.next_global_id,
)
last_step_in_chunk = batch_step_start + end_idx - 1
print(
f"\n--- Curator chunk {chunk_idx + 1}/{num_chunks} "
f"(samples {start_idx + 1}-{end_idx}, step {last_step_in_chunk}) ---"
f"\n Playbook updated after reducer: {len(final_operations)} ADD operations | "
f"{count_tokens(self.playbook)} tokens"
)
_run_one_curator_call(
combined_reflection,
combined_context,
last_step_in_chunk,
f"{step_id_prefix}_s_{last_step_in_chunk}_chunk_{chunk_idx + 1}",
len(chunk_reflections),
else:
self.playbook = base_playbook
print(f"\n Playbook unchanged by curator reducer: {count_tokens(self.playbook)} tokens")

if self.use_bulletpoint_analyzer and self.bulletpoint_analyzer:
print(f" Running BulletpointAnalyzer (threshold={self.bulletpoint_analyzer_threshold})...")
self.playbook = self.bulletpoint_analyzer.analyze(
playbook=self.playbook,
threshold=self.bulletpoint_analyzer_threshold,
merge=True,
)
print(
f"\n Playbook updated after {num_chunks} Curator calls: "
f"{count_tokens(self.playbook)} tokens"
)

# ================================================================
# PHASE 3: Parallel Post-Curator Generation
Expand Down Expand Up @@ -980,7 +1065,8 @@ def _offline_train(

print(f"Total epochs: {num_epochs}")
print(f"Train samples per epoch: {len(train_samples)}")
print(f"Gen batch size: {batch_size} | Curator batch size: {config_params.get('curator_batch_size', 10)}")
curator_grouping = config_params.get('curator_num_groups') or "sqrt(n)"
print(f"Gen batch size: {batch_size} | Curator reducer groups: {curator_grouping}")
print(f"Batches per epoch: {num_batches}")
print(f"Val samples: {len(val_samples)}")
print(f"Evaluation frequency: every {eval_steps} steps\n")
Expand Down Expand Up @@ -1463,4 +1549,4 @@ def _online_train_and_test(
"accuracy": final_test_accuracy,
"correct": correct_count_sample_based,
"total": total_count,
}
}
Loading