Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/gimbench/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,15 @@ def __init__(self, args: Namespace, dataset: Dataset):
self.args = args

@staticmethod
def _safe_average(items: list, attr: str) -> float:
values = [getattr(item, attr) for item in items if getattr(item, attr) != -1]
def _filter_non_error_items(items: list) -> list:
    """Return the items whose ``error_msg`` is falsy, keeping input order."""
    kept = []
    for entry in items:
        # A truthy error message marks the entry as failed; leave it out.
        if entry.error_msg:
            continue
        kept.append(entry)
    return kept

@staticmethod
def _safe_average(items: list, attr: str, exclude_errors: bool = True) -> float:
    """Return the mean of ``attr`` across ``items``, or 0.0 if nothing qualifies.

    When ``exclude_errors`` is true, items are first passed through
    ``BaseEvaluator._filter_non_error_items``. Values equal to the sentinel
    -1 are always left out of the average.
    """
    candidates = items
    if exclude_errors:
        candidates = BaseEvaluator._filter_non_error_items(items)

    usable = []
    for entry in candidates:
        value = getattr(entry, attr)
        # -1 is the sentinel for "no measurement"; skip it.
        if value != -1:
            usable.append(value)

    if not usable:
        return 0.0
    return sum(usable) / len(usable)

def _log_progress(self, total: int, curr_idx: int, log_interval: int = 10) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/gimbench/cv/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def evaluate(self) -> EvalResult:
evaled_items = list(tqdm(results, total=total, desc=f"Evaluating {self.args.model_name}"))
# TODO: Add progress logging for multi-threaded evaluation

non_error_items = [item for item in evaled_items if not item.error_msg]
non_error_items = self._filter_non_error_items(evaled_items)
errors = sum(1 for item in evaled_items if item.error_msg)
total_fields = sum(item.num_fields for item in non_error_items)
total_correct = sum(item.num_correct for item in non_error_items)
Expand Down
9 changes: 5 additions & 4 deletions src/gimbench/match/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,14 @@ def evaluate(self) -> EvalResult:
self.end_time = datetime.now()
logger.info(f"Evaluation completed at {self.end_time}")

non_error_items = self._filter_non_error_items(evaled_items)
total_tags = sum(item.num_tags for item in evaled_items)
valid_tags = sum(item.num_tags for item in evaled_items if not item.error_msg)
total_has_prediction = sum(item.num_has_prediction for item in evaled_items if not item.error_msg)
valid_tags = sum(item.num_tags for item in non_error_items)
total_has_prediction = sum(item.num_has_prediction for item in non_error_items)

total_regex = sum(item.num_regex for item in evaled_items)
valid_regex = sum(item.num_regex for item in evaled_items if not item.error_msg)
total_regex_match = sum(item.num_regex_match for item in evaled_items if not item.error_msg)
valid_regex = sum(item.num_regex for item in non_error_items)
total_regex_match = sum(item.num_regex_match for item in non_error_items)
return EvalResult(
args=self.args,
start_time=self.start_time,
Expand Down
3 changes: 2 additions & 1 deletion src/gimbench/mcqa/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,12 @@ def evaluate(self) -> EvalResult:
evaled_items = list(tqdm(results, total=total, desc=f"Evaluating {self.args.model_name}"))
# TODO: Add progress logging for multi-threaded evaluation

non_error_items = self._filter_non_error_items(evaled_items)
errors = sum(1 for item in evaled_items if item.error_msg)
corrects = sum(1 for item in evaled_items if item.conclusion)
evaluates = len(evaled_items)
accuracy = corrects / evaluates if evaluates > 0 else 0.0
calibrated_accuracy = corrects / (evaluates - errors) if (evaluates - errors) > 0 else 0.0
calibrated_accuracy = corrects / len(non_error_items) if non_error_items else 0.0
logger.info(f"Final accuracy over {total} examples: {corrects}/{total} = {accuracy:.4f}")
self.end_time = datetime.now()
logger.info(f"Evaluation completed at {self.end_time}")
Expand Down
Loading