10 changes: 5 additions & 5 deletions CONTRIBUTING.md
@@ -181,10 +181,10 @@ class MyCustomRankingTask(RankingTask):
"""Override default metrics if needed"""
return ["map", "mrr", "recall@5", "recall@10"]

def load_monolingual_data(self, split: DatasetSplit, language: Language) -> RankingDataset:
def load_dataset(self, dataset_id: str, split: DatasetSplit) -> RankingDataset:
"""
Load dataset for a specific language and split.
Load dataset for a specific dataset ID and split.

Returns:
RankingDataset with query_texts, target_indices, and target_space
"""
@@ -196,12 +196,12 @@ class MyCustomRankingTask(RankingTask):
[0, 2], # Software Engineer -> Python, SQL
[0, 1], # Data Scientist -> Python, Machine Learning
]

return RankingDataset(
query_texts=query_texts,
target_indices=target_indices,
target_space=target_space,
language=language,
dataset_id=dataset_id,
)
```
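Once defined, the task can be registered and run. A minimal sketch, assuming `register_task` works as a bare class decorator and `workrb.evaluate` is the top-level entry point (check `workrb.registry` and `workrb.run` for the exact signatures):

```python
import workrb
from workrb.registry import register_task


@register_task  # assumption: usable as a bare class decorator
class MyCustomRankingTask(RankingTask):
    ...


# Hypothetical runner call; `my_model` is any ModelInterface implementation.
workrb.evaluate(
    model=my_model,
    tasks=[MyCustomRankingTask()],
    output_folder="results/custom_task",  # enables checkpointing & resuming
)
```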

4 changes: 3 additions & 1 deletion README.md
@@ -112,7 +112,7 @@ Feel free to make a PR to add your models & tasks to the official package! See [

### Checkpointing & Resuming

WorkRB automatically saves result checkpoints after each task completion in a specific language.
WorkRB automatically saves result checkpoints after each dataset evaluation within a task.

**Automatic Resuming** - Simply rerun with the same `output_folder`:
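A minimal sketch of such a rerun, assuming the `workrb.evaluate` entry point; every (task, dataset_id) pair already present in the checkpoint is skipped:

```python
import workrb

# Identical call as before; completed evaluations are loaded from the
# checkpoint in `output_folder` instead of being recomputed.
results = workrb.evaluate(model=model, tasks=tasks, output_folder="results/my_run")
```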

@@ -207,12 +207,14 @@ lang_result_ci = summary["mean_per_language/en/f1_macro/ci_margin"]
| Job to Skills WorkBench | multi_label | 3039 queries x 13939 targets | 28 |
| Job Title Similarity | multi_label | 105 queries x 2619 targets | 11 |
| Job Normalization | single_label | 15463 queries x 2942 targets | 28 |
| Job Normalization MELO | multi_label | 633 queries x 33813 targets | 21 |
| Skill to Job WorkBench | multi_label | 13492 queries x 3039 targets | 28 |
| Skill Extraction House | multi_label | 262 queries x 13891 targets | 28 |
| Skill Extraction Tech | multi_label | 338 queries x 13891 targets | 28 |
| Skill Extraction SkillSkape | multi_label | 1191 queries x 13891 targets | 28 |
| Skill Similarity SkillMatch-1K | single_label | 900 queries x 2648 targets | 1 |
| Skill Normalization ESCO | multi_label | 72008 queries x 13939 targets | 28 |
| Skill Normalization MELS | multi_label | 1722 queries x 19466 targets | 5 |
| **Classification** | | | |
| Job-Skill Classification | multi_label | 3039 samples, 13939 classes | 28 |

3 changes: 3 additions & 0 deletions examples/custom_model_example.py
@@ -9,6 +9,7 @@
import torch
from sentence_transformers import SentenceTransformer

import workrb
from workrb.models.base import ModelInterface
from workrb.registry import register_model
from workrb.types import ModelInputType
@@ -47,10 +48,12 @@ def __init__(
self.encoder.to(device)
self.encoder.eval()

@property
def name(self) -> str:
"""Return the unique name of this model."""
return f"MyCustomModel-{self.base_model_name.split('/')[-1]}"

@property
def description(self) -> str:
"""Return the description of this model."""
return "A custom model that demonstrates WorkRB extensibility"
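With the added `@property` decorators, `name` and `description` read as attributes rather than methods. A small usage sketch (the constructor arguments are assumptions based on this excerpt):

```python
# Hypothetical instantiation; `base_model_name` mirrors the attribute
# referenced in the `name` property above.
model = MyCustomModel(base_model_name="sentence-transformers/all-MiniLM-L6-v2")
print(model.name)         # "MyCustomModel-all-MiniLM-L6-v2" -- no call parentheses
print(model.description)  # "A custom model that demonstrates WorkRB extensibility"
```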
7 changes: 4 additions & 3 deletions examples/custom_task_example.py
@@ -6,6 +6,7 @@
and implement the required abstract methods.
"""

import workrb
from workrb.registry import register_task
from workrb.tasks.abstract.base import DatasetSplit, LabelType, Language
from workrb.tasks.abstract.ranking_base import RankingDataset, RankingTaskGroup
@@ -78,14 +79,14 @@ def supported_target_languages(self) -> list[Language]:
"""Supported target languages are English."""
return [Language.EN]

def load_monolingual_data(self, language: Language, split: DatasetSplit) -> RankingDataset:
def load_dataset(self, dataset_id: str, split: DatasetSplit) -> RankingDataset:
"""
Load data for evaluation.

This method must return a RankingDataset.

Args:
language: Language code (e.g., "en", "de", "fr")
dataset_id: Dataset identifier (e.g., "en", "de", "fr" for language-based tasks)
split: Data split ("test", "validation", "train")

Returns
@@ -121,7 +122,7 @@ def load_monolingual_data(self, language: Language, split: DatasetSplit) -> RankingDataset:
query_texts=queries,
target_indices=labels,
target_space=targets,
language=language,
dataset_id=dataset_id,
)

# Note: The evaluate() method is inherited from RankingTask and doesn't need
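The runner iterates over a task's dataset IDs, so the example assumes the base class exposes them. For language-based tasks the IDs are typically language codes; a sketch of what such a property might look like (an assumption, not necessarily the shipped implementation):

```python
@property
def dataset_ids(self) -> list[str]:
    # Hypothetical: language-based tasks derive dataset IDs from
    # their supported target languages (e.g., Language.EN -> "en").
    return [lang.value for lang in self.supported_target_languages()]
```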
2 changes: 2 additions & 0 deletions examples/run_multiple_models.py
@@ -2,6 +2,8 @@
Reproduce benchmark results.
"""

import workrb

if __name__ == "__main__":
# 1. Setup model and tasks
models = [
12 changes: 6 additions & 6 deletions src/workrb/config.py
@@ -205,24 +205,24 @@ def get_pending_work(
self,
results: BenchmarkResults | None,
tasks: Sequence[Task],
) -> list[tuple]:
) -> list[tuple[Task, str]]:
"""Determine what work still needs to be done.

Work is defined as a (task, language) combination that is not completed.
Work is defined as a (task, dataset_id) combination that is not completed.
"""
pending_work = []
for task in tasks:
for language in task.languages:
# Successful completed (task, language) combination
for dataset_id in task.dataset_ids:
# Successfully completed (task, dataset_id) combination
if (
results is not None
and task.name in results.task_results
and language in results.task_results[task.name].language_results
and dataset_id in results.task_results[task.name].language_results
):
continue

# Add to pending work
pending_work.append((task, language))
pending_work.append((task, dataset_id))

return pending_work

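To illustrate the resume semantics, a short sketch (assuming `config` is the object exposing `get_pending_work` and `results` is a loaded checkpoint):

```python
pending = config.get_pending_work(results, tasks)
for task, dataset_id in pending:
    # Only (task, dataset_id) pairs without stored results are returned,
    # so a rerun picks up exactly where the previous run stopped.
    print(f"todo: {task.name} / {dataset_id}")
```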
32 changes: 22 additions & 10 deletions src/workrb/results.py
@@ -22,20 +22,24 @@ class TaskResultMetadata(BaseModel):
class MetricsResult(BaseModel):
"""Metric results for a single evaluation run.

In the becnhmark, this is a single evaluation run for a single language.
In the benchmark, this is a single evaluation run for a single dataset.
"""

evaluation_time: float = Field(ge=0)
metrics_dict: dict[str, Any] = Field(default_factory=dict)
""" Dictionary of metric names to their computed values. """
language: str | None = Field(
default=None,
description="Language code if this is a monolingual dataset, None for cross-language datasets.",
)


class TaskResults(BaseModel):
"""Results for a task."""

metadata: TaskResultMetadata
language_results: dict[str, MetricsResult] # language -> results
""" Dictionary of language codes to their computed results. """
language_results: dict[str, MetricsResult] # dataset_id -> results
"""Dictionary of dataset IDs to their computed results."""


class BenchmarkMetadata(BaseModel):
@@ -292,15 +296,20 @@ def _aggregate_per_language(
) -> dict[ResultTagString, float]:
"""Aggregate results per language.

Collects language-specific results over all tasks, and aggregates all availble results.
Collects results for monolingual datasets and aggregates by language across all tasks.
Cross-language datasets (where language is None) are excluded from this aggregation.
Results may be imbalanced if tasks support different languages.
"""
# Collect metric values per task
# Collect metric values per language
raw_results = defaultdict(list)
for task_result in self.task_results.values():
for language, metrics_result in task_result.language_results.items():
for metrics_result in task_result.language_results.values():
# Skip cross-language datasets
if metrics_result.language is None:
continue

for metric_name, metric_value in metrics_result.metrics_dict.items():
raw_results[(language, metric_name)].append(metric_value)
raw_results[(metrics_result.language, metric_name)].append(metric_value)

# Compute stats
results = {}
@@ -309,7 +318,10 @@
for agg in aggregations:
assert agg in stats, f"Aggregation {agg} not found in stats: {stats.keys()}"
tag = ResultTagString(
name=tag_name, metric_name=metric_name, aggregation=agg, grouping_name=language
name=tag_name,
metric_name=metric_name,
aggregation=agg,
grouping_name=language,
)
results[tag] = stats[agg]
return results
@@ -340,7 +352,7 @@ def _get_flat_dataframe(self) -> pd.DataFrame:
"""Get flat dataframe of the benchmark results with each metric value as a separate row."""
data = []
for task_name, task_result in self.task_results.items():
for language, metrics_result in task_result.language_results.items():
for dataset_id, metrics_result in task_result.language_results.items():
for metric_name, metric_value in metrics_result.metrics_dict.items():
data.append(
{
Expand All @@ -349,7 +361,7 @@ def _get_flat_dataframe(self) -> pd.DataFrame:
"task_type": str(task_result.metadata.task_type),
# "task_label_type": str(task_result.metadata.label_type),
# "task_split": str(task_result.metadata.split),
"task_language": str(language),
"dataset_id": str(dataset_id),
"metric_name": str(metric_name),
"metric_value": float(metric_value),
}
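To make the new `language` field concrete, a minimal sketch (metric values are illustrative):

```python
mono = MetricsResult(
    evaluation_time=1.2,
    metrics_dict={"mrr": 0.41},
    language="en",  # monolingual dataset: included in per-language aggregates
)
cross = MetricsResult(
    evaluation_time=2.3,
    metrics_dict={"mrr": 0.37},
    language=None,  # cross-language dataset: skipped by _aggregate_per_language
)
```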
44 changes: 22 additions & 22 deletions src/workrb/run.py
@@ -22,7 +22,7 @@
TaskResultMetadata,
TaskResults,
)
from workrb.tasks.abstract.base import Language, Task
from workrb.tasks.abstract.base import Task

logger = logging.getLogger(__name__)
setup_logger(__name__, verbose=False)
@@ -80,10 +80,10 @@ def evaluate(

# Group pending work by task for better organization
work_by_task = {}
for task, language in pending_work:
for task, dataset_id in pending_work:
if task.name not in work_by_task:
work_by_task[task.name] = {"task": task, "languages": []}
work_by_task[task.name]["languages"].append(language)
work_by_task[task.name] = {"task": task, "dataset_ids": []}
work_by_task[task.name]["dataset_ids"].append(dataset_id)

# Run pending work
start_time_benchmark = time.time()
@@ -101,7 +101,7 @@
# Update metadata
results.metadata.total_evaluation_time = time.time() - start_time_benchmark
results.metadata.resumed_from_checkpoint = len(pending_work) < sum(
len(task.languages) for task in tasks
len(task.dataset_ids) for task in tasks
)

# Save config and results
@@ -206,11 +206,11 @@ def get_tasks_overview(tasks: Sequence[Task]) -> str:

lines.append(f"{task_name:<40} {group:<20} {task_languages:<20}")

# Add size one-liner for each language
for lang in task.languages:
size_info = task.get_size_oneliner(lang)
# Add size one-liner for each dataset
for dataset_id in task.dataset_ids:
size_info = task.get_size_oneliner(dataset_id)
if size_info:
lines.append(f" └─ {lang}: {size_info}")
lines.append(f" └─ {dataset_id}: {size_info}")

lines.append("-" * 80)

@@ -227,7 +227,7 @@ def _get_all_languages(tasks: Sequence[Task]) -> list[str]:

def _get_total_evaluations(tasks: Sequence[Task]) -> int:
"""Get the total number of evaluations."""
return sum(len(task.languages) for task in tasks)
return sum(len(task.dataset_ids) for task in tasks)


def _init_checkpointing(
@@ -307,12 +307,12 @@ def _run_pending_work(
run_idx = results.get_num_evaluation_results() # Already completed evaluations
for work_info in work_by_task.values():
task: Task = work_info["task"]
pending_languages: list[str] = work_info["languages"]
pending_dataset_ids: list[str] = work_info["dataset_ids"]

logger.info(f"{'=' * 60}")
logger.info(f"Evaluating task: {task.name}")
logger.info(f"Completed {run_idx} / {_get_total_evaluations(tasks)} evaluations. ")
logger.info(f"Pending languages for this task: {len(pending_languages)}")
logger.info(f"Pending datasets for this task: {len(pending_dataset_ids)}")

# Initialize task results if not exists
if task.name not in results.task_results:
@@ -327,11 +327,9 @@
language_results={},
)

# Evaluate pending languages
for language in pending_languages:
logger.info(
f"* Running language: {language} ({task.get_size_oneliner(Language(language))})"
)
# Evaluate pending datasets
for dataset_id in pending_dataset_ids:
logger.info(f"* Running dataset: {dataset_id} ({task.get_size_oneliner(dataset_id)})")

# Get metrics for this task
task_metrics = None
@@ -340,15 +338,17 @@

try:
start_time_eval = time.time()
lang_results: dict[str, float] = task.evaluate(
model=model, metrics=task_metrics, language=Language(language)
dataset_results: dict[str, float] = task.evaluate(
model=model, metrics=task_metrics, dataset_id=dataset_id
)
evaluation_time = time.time() - start_time_eval

# Store results
results.task_results[task.name].language_results[language] = MetricsResult(
dataset_language = task.get_dataset_language(dataset_id)
results.task_results[task.name].language_results[dataset_id] = MetricsResult(
evaluation_time=evaluation_time,
metrics_dict=lang_results,
metrics_dict=dataset_results,
language=dataset_language.value if dataset_language else None,
)

# Save incremental results to checkpoint
@@ -357,7 +357,7 @@

# Show key metrics
key_metric = task.default_metrics[0]
logger.info(f"\t{key_metric}: {lang_results[key_metric]:.3f}")
logger.info(f"\t{key_metric}: {dataset_results[key_metric]:.3f}")
run_idx += 1
except Exception as e:
logger.error(f"Error: {e}")
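After this change, checkpointed results are keyed by dataset ID rather than language (the `language_results` field keeps its legacy name). A sketch of reading back a stored metric (task and metric names are illustrative):

```python
task_results = results.task_results["Job Normalization"]
# language_results is keyed by dataset_id despite its name.
mrr = task_results.language_results["en"].metrics_dict["mrr"]
lang = task_results.language_results["en"].language  # "en" here, None for cross-language
```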
4 changes: 4 additions & 0 deletions src/workrb/tasks/__init__.py
@@ -11,6 +11,8 @@
from .ranking.job2skill import ESCOJob2SkillRanking
from .ranking.job_similarity import JobTitleSimilarityRanking
from .ranking.jobnorm import JobBERTJobNormRanking
from .ranking.melo import MELORanking
from .ranking.mels import MELSRanking
from .ranking.skill2job import ESCOSkill2JobRanking
from .ranking.skill_extraction import (
HouseSkillExtractRanking,
@@ -35,6 +37,8 @@
"ESCOSkillNormRanking",
"JobBERTJobNormRanking",
"JobTitleSimilarityRanking",
"MELORanking",
"MELSRanking",
"HouseSkillExtractRanking",
"TechSkillExtractRanking",
"SkillSkapeExtractRanking",