diff --git a/src/gimbench/base.py b/src/gimbench/base.py
index 9b9e5b4..af5b36d 100644
--- a/src/gimbench/base.py
+++ b/src/gimbench/base.py
@@ -95,8 +95,15 @@ def __init__(self, args: Namespace, dataset: Dataset):
         self.args = args
 
     @staticmethod
-    def _safe_average(items: list, attr: str) -> float:
-        values = [getattr(item, attr) for item in items if getattr(item, attr) != -1]
+    def _filter_non_error_items(items: list) -> list:
+        """Filter out items that have error messages."""
+        return [item for item in items if not item.error_msg]
+
+    @staticmethod
+    def _safe_average(items: list, attr: str, exclude_errors: bool = True) -> float:
+        """Calculate average of attribute, optionally excluding errored items and sentinel values (-1)."""
+        filtered_items = BaseEvaluator._filter_non_error_items(items) if exclude_errors else items
+        values = [getattr(item, attr) for item in filtered_items if getattr(item, attr) != -1]
         return sum(values) / len(values) if values else 0.0
 
     def _log_progress(self, total: int, curr_idx: int, log_interval: int = 10) -> None:
diff --git a/src/gimbench/cv/evaluators.py b/src/gimbench/cv/evaluators.py
index 6ec9f6c..b169c93 100644
--- a/src/gimbench/cv/evaluators.py
+++ b/src/gimbench/cv/evaluators.py
@@ -139,7 +139,7 @@ def evaluate(self) -> EvalResult:
         evaled_items = list(tqdm(results, total=total, desc=f"Evaluating {self.args.model_name}"))
 
         # TODO: Add progress logging for multi-threaded evaluation
-        non_error_items = [item for item in evaled_items if not item.error_msg]
+        non_error_items = self._filter_non_error_items(evaled_items)
         errors = sum(1 for item in evaled_items if item.error_msg)
         total_fields = sum(item.num_fields for item in non_error_items)
         total_correct = sum(item.num_correct for item in non_error_items)
diff --git a/src/gimbench/match/evaluators.py b/src/gimbench/match/evaluators.py
index a87ae27..9040ddb 100644
--- a/src/gimbench/match/evaluators.py
+++ b/src/gimbench/match/evaluators.py
@@ -109,13 +109,14 @@ def evaluate(self) -> EvalResult:
         self.end_time = datetime.now()
         logger.info(f"Evaluation completed at {self.end_time}")
 
+        non_error_items = self._filter_non_error_items(evaled_items)
         total_tags = sum(item.num_tags for item in evaled_items)
-        valid_tags = sum(item.num_tags for item in evaled_items if not item.error_msg)
-        total_has_prediction = sum(item.num_has_prediction for item in evaled_items if not item.error_msg)
+        valid_tags = sum(item.num_tags for item in non_error_items)
+        total_has_prediction = sum(item.num_has_prediction for item in non_error_items)
         total_regex = sum(item.num_regex for item in evaled_items)
-        valid_regex = sum(item.num_regex for item in evaled_items if not item.error_msg)
-        total_regex_match = sum(item.num_regex_match for item in evaled_items if not item.error_msg)
+        valid_regex = sum(item.num_regex for item in non_error_items)
+        total_regex_match = sum(item.num_regex_match for item in non_error_items)
 
         return EvalResult(
             args=self.args,
             start_time=self.start_time,
diff --git a/src/gimbench/mcqa/evaluators.py b/src/gimbench/mcqa/evaluators.py
index 78c95d3..431c09d 100644
--- a/src/gimbench/mcqa/evaluators.py
+++ b/src/gimbench/mcqa/evaluators.py
@@ -135,11 +135,12 @@ def evaluate(self) -> EvalResult:
         evaled_items = list(tqdm(results, total=total, desc=f"Evaluating {self.args.model_name}"))
 
         # TODO: Add progress logging for multi-threaded evaluation
+        non_error_items = self._filter_non_error_items(evaled_items)
         errors = sum(1 for item in evaled_items if item.error_msg)
         corrects = sum(1 for item in evaled_items if item.conclusion)
         evaluates = len(evaled_items)
         accuracy = corrects / evaluates if evaluates > 0 else 0.0
-        calibrated_accuracy = corrects / (evaluates - errors) if (evaluates - errors) > 0 else 0.0
+        calibrated_accuracy = corrects / len(non_error_items) if non_error_items else 0.0
         logger.info(f"Final accuracy over {total} examples: {corrects}/{total} = {accuracy:.4f}")
         self.end_time = datetime.now()
         logger.info(f"Evaluation completed at {self.end_time}")
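For reviewers, a minimal standalone sketch of the averaging semantics this patch introduces. The Item dataclass and the two module-level functions below are hypothetical stand-ins for gimbench's evaluation items and the patched BaseEvaluator methods, not the real API; the error_msg convention and the -1 sentinel are taken directly from the diff above.

from dataclasses import dataclass

@dataclass
class Item:
    score: float          # metric attribute; -1 marks "not computed"
    error_msg: str = ""   # non-empty means the item errored

def filter_non_error_items(items: list) -> list:
    # Mirrors _filter_non_error_items: drop items that carry an error message.
    return [item for item in items if not item.error_msg]

def safe_average(items: list, attr: str, exclude_errors: bool = True) -> float:
    # Mirrors the patched _safe_average: optionally drop errored items first,
    # always skip the -1 sentinel, and fall back to 0.0 when nothing remains.
    filtered = filter_non_error_items(items) if exclude_errors else items
    values = [getattr(item, attr) for item in filtered if getattr(item, attr) != -1]
    return sum(values) / len(values) if values else 0.0

items = [Item(0.5), Item(-1), Item(1.0, error_msg="timeout")]
print(safe_average(items, "score"))                        # 0.5  (errored item excluded)
print(safe_average(items, "score", exclude_errors=False))  # 0.75 (errored item's score counted)

Note that since every item either carries an error_msg or does not, len(non_error_items) equals evaluates - errors, so the mcqa change to calibrated_accuracy is behavior-preserving; the value of the refactor is that the error-exclusion convention now lives in one helper instead of being repeated at each call site.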