10 changes: 3 additions & 7 deletions src/workrb/__init__.py
@@ -3,13 +3,9 @@
"""

from workrb import data, metrics, models, tasks
from workrb.evaluate import (
evaluate,
evaluate_multiple_models,
get_tasks_overview,
list_available_tasks,
load_results,
)
from workrb.registry import list_available_tasks
from workrb.results import load_results
from workrb.run import evaluate, evaluate_multiple_models, get_tasks_overview

__all__ = [
"data",
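
For context on the re-exports: the package-level names are unchanged, only their source modules differ. A minimal sketch of the public surface after this change (illustrative, not part of the diff):

```python
# The top-level API stays the same; only the backing modules moved.
import workrb

workrb.list_available_tasks       # now re-exported from workrb.registry
workrb.load_results               # now re-exported from workrb.results
workrb.evaluate                   # now re-exported from workrb.run
workrb.evaluate_multiple_models   # now re-exported from workrb.run
workrb.get_tasks_overview         # now re-exported from workrb.run
```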
5 changes: 5 additions & 0 deletions src/workrb/registry.py
@@ -85,6 +85,11 @@ def auto_discover(cls):
importlib.import_module(modname)


def list_available_tasks() -> dict[str, str]:
"""List all available task classes that can be used in configs."""
return TaskRegistry.list_available()


def register_task(name: str | None = None):
"""
Decorator registering a task class.
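
A small usage sketch of the relocated helper (not part of the diff), assuming `TaskRegistry.list_available()` maps registered task names to their descriptions; that method sits outside this hunk:

```python
from workrb.registry import list_available_tasks

# Assumed shape: {task_name: description}, as returned by TaskRegistry.list_available().
for name, description in list_available_tasks().items():
    print(f"{name}: {description}")
```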
12 changes: 12 additions & 0 deletions src/workrb/results.py
@@ -1,3 +1,4 @@
import json
import pprint
from collections import defaultdict
from typing import Any
@@ -355,3 +356,14 @@ def _get_flat_dataframe(self) -> pd.DataFrame:
)

return pd.DataFrame(data)


def load_results(results_path: str = "./results.json") -> BenchmarkResults:
"""
Load results from specified folder.

Useful for external usage of the results, when only the folder is available.
"""
with open(results_path) as f:
data = json.load(f)
return BenchmarkResults.model_validate(data)
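
A quick sketch of the relocated loader in isolation (illustrative; the path below is the default, any previously saved results file would do):

```python
from workrb.results import load_results

# load_results() defaults to "./results.json".
results = load_results("./results.json")
# The JSON payload is validated into a BenchmarkResults model
# (presumably pydantic, given the model_validate call above).
print(type(results).__name__)
```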
28 changes: 0 additions & 28 deletions src/workrb/evaluate.py → src/workrb/run.py
@@ -5,7 +5,6 @@
checkpointing, resuming, and efficient multi-model evaluation.
"""

import json
import logging
import time
from collections import Counter
@@ -16,7 +15,6 @@
from workrb.logging import setup_logger
from workrb.metrics.reporting import format_results
from workrb.models.base import ModelInterface
from workrb.registry import TaskRegistry
from workrb.results import (
BenchmarkMetadata,
BenchmarkResults,
@@ -219,22 +217,6 @@ def get_tasks_overview(tasks: Sequence[Task]) -> str:
return "\n".join(lines)


def load_results(results_path: str = "./results.json") -> BenchmarkResults:
"""
Load results from specified folder.

Useful for external usage of the results, when only the folder is available.
"""
with open(results_path) as f:
data = json.load(f)
return BenchmarkResults.model_validate(data)


def list_available_tasks() -> dict[str, str]:
"""List all available task classes that can be used in configs."""
return TaskRegistry.list_available()


def _get_all_languages(tasks: Sequence[Task]) -> list[str]:
"""Get all unique languages across tasks."""
languages = set()
@@ -248,16 +230,6 @@ def _get_total_evaluations(tasks: Sequence[Task]) -> int:
return sum(len(task.languages) for task in tasks)


def _validate_tasks(tasks: Sequence[Task]):
"""Validate that all tasks are properly configured."""
if not tasks:
raise ValueError("At least one task must be provided")

for task in tasks:
if not isinstance(task, Task):
raise TypeError(f"All tasks must inherit from Task, got {type(task)}")


def _init_checkpointing(
tasks: Sequence[Task],
config: BenchmarkConfig,
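
For downstream imports, the rename amounts to the following mapping (a sketch derived from the hunks above, not an exhaustive list):

```python
# Before this PR (removed):
# from workrb.evaluate import evaluate, evaluate_multiple_models, get_tasks_overview
# from workrb.evaluate import load_results, list_available_tasks

# After this PR:
from workrb.run import evaluate, evaluate_multiple_models, get_tasks_overview
from workrb.results import load_results
from workrb.registry import list_available_tasks
```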
14 changes: 7 additions & 7 deletions tests/test_evaluate_multiple_models.py
@@ -15,7 +15,6 @@
import torch

from tests.test_utils import create_toy_task_class
from workrb.evaluate import evaluate_multiple_models
from workrb.models.base import ModelInterface
from workrb.results import (
BenchmarkMetadata,
@@ -24,6 +23,7 @@
TaskResultMetadata,
TaskResults,
)
from workrb.run import evaluate_multiple_models
from workrb.tasks import SkillMatch1kSkillSimilarityRanking
from workrb.tasks.abstract.base import DatasetSplit, Language
from workrb.types import ModelInputType
@@ -118,7 +118,7 @@ def test_evaluate_multiple_models_basic():
task_name = task.name

# Mock the evaluate function
with patch("workrb.evaluate.evaluate") as mock_evaluate:
with patch("workrb.run.evaluate") as mock_evaluate:
# Set up return values for each model
mock_evaluate.side_effect = [
create_mock_results("model1", task_name),
@@ -165,7 +165,7 @@ def test_evaluate_multiple_models_with_additional_kwargs():
task = ToyTask(split=DatasetSplit.VAL, languages=[Language.EN])
task_name = task.name

with patch("workrb.evaluate.evaluate") as mock_evaluate:
with patch("workrb.run.evaluate") as mock_evaluate:
mock_evaluate.return_value = create_mock_results("test_model", task_name)

results = evaluate_multiple_models(
@@ -227,7 +227,7 @@ def test_evaluate_multiple_models_error_handling():
task = ToyTask(split=DatasetSplit.VAL, languages=[Language.EN])
task_name = task.name

with patch("workrb.evaluate.evaluate") as mock_evaluate:
with patch("workrb.run.evaluate") as mock_evaluate:
# First model succeeds, second fails
mock_evaluate.side_effect = [
create_mock_results("model1", task_name),
@@ -256,7 +256,7 @@ def test_evaluate_multiple_models_output_folder_overrides_kwargs():
task = ToyTask(split=DatasetSplit.VAL, languages=[Language.EN])
task_name = task.name

with patch("workrb.evaluate.evaluate") as mock_evaluate:
with patch("workrb.run.evaluate") as mock_evaluate:
mock_evaluate.side_effect = [
create_mock_results("model1", task_name),
create_mock_results("model2", task_name),
@@ -287,7 +287,7 @@ def test_evaluate_multiple_models_single_model():
task = ToyTask(split=DatasetSplit.VAL, languages=[Language.EN])
task_name = task.name

with patch("workrb.evaluate.evaluate") as mock_evaluate:
with patch("workrb.run.evaluate") as mock_evaluate:
mock_evaluate.return_value = create_mock_results("single_model", task_name)

results = evaluate_multiple_models(
@@ -307,7 +307,7 @@ def test_evaluate_multiple_models_empty_models_list():
ToyTask = create_toy_task_class(SkillMatch1kSkillSimilarityRanking)
task = ToyTask(split=DatasetSplit.VAL, languages=[Language.EN])

with patch("workrb.evaluate.evaluate") as mock_evaluate:
with patch("workrb.run.evaluate") as mock_evaluate:
with pytest.raises(AssertionError) as excinfo:
evaluate_multiple_models(
models=[],
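
The test updates above follow the usual unittest.mock rule of patching the name where it is looked up. A minimal sketch of the pattern after the move (hypothetical test body, not part of the diff):

```python
from unittest.mock import patch

# evaluate is now resolved inside workrb.run, so that module is the patch target.
with patch("workrb.run.evaluate") as mock_evaluate:
    mock_evaluate.return_value = None  # stand-in for a BenchmarkResults instance
    ...  # call evaluate_multiple_models(...) as in the tests above
```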