diff --git a/.github/workflows/build-and-push-apptainer.yml b/.github/workflows/build-and-push-apptainer.yml index dd2369a..0d1c8dd 100644 --- a/.github/workflows/build-and-push-apptainer.yml +++ b/.github/workflows/build-and-push-apptainer.yml @@ -37,7 +37,7 @@ jobs: - name: Build SIF from definition file run: | - apptainer --verbose build --fakeroot eval_env-${{ matrix.image }}.sif apptainer/${{ matrix.image }}.def + apptainer --verbose build --mksquashfs-args="-comp gzip -Xcompression-level 1" --fakeroot eval_env-${{ matrix.image }}.sif apptainer/${{ matrix.image }}.def - name: Install Hugging Face Hub CLI run: pip install --upgrade "huggingface_hub" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6cf30a0..0d9fa48 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v3 + uses: astral-sh/setup-uv@v7 with: version: "latest" @@ -40,7 +40,7 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v3 + uses: astral-sh/setup-uv@v7 with: version: "latest" diff --git a/.gitignore b/.gitignore index 77fc697..b897fcc 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ **/*.egg-info **/*.csv **/uv.lock +**/task_map_cache.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d48d2d..f47629e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -18,10 +18,3 @@ repos: - id: ruff args: [--fix, --exit-non-zero-on-fix] - id: ruff-format - - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 - hooks: - - id: mypy - additional_dependencies: [types-PyYAML] - args: [--ignore-missing-imports] diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..7e64337 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,5 @@ +Rules: +- no try...Except unless absolutely necessary +- no unnecessary comments +- don't worry about tests +- if you need to run stuff, assume there is a .venv at the root of the project. you can also just use uv diff --git a/README.md b/README.md index 78c1649..f0f568a 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A package for running OELLM CLI workflows across multiple HPC clusters using SLU - Restart failed evaluations (e.g., due to node failures) ✅ `oellm collect-results ... 
--reschedule true` - Interactive eval job/csv builder ✅ `oellm build-csv` - Recursively resolve local paths: pass a directory containing models and their nested intermediate checkpoints, will eval all checkpoints - - Support default task groups (cf `oellm/task-groups.yaml`) + - Support default task groups (cf `oellm/resources/task-groups.yaml`) ## Planned workflows - Sync and download evaluation results from all clusters via a shared data layer @@ -21,7 +21,7 @@ A package for running OELLM CLI workflows across multiple HPC clusters using SLU ```bash # Install the package -uv tool install --python 3.12 git+https://github.com/OpenEuroLLM/oellm-cli.git +uv tool install git+https://github.com/OpenEuroLLM/oellm-cli.git # Run evaluations on multiple models and tasks oellm schedule-eval \ @@ -50,6 +50,10 @@ This will launch an interactive workflow where you can: - Configure n-shot settings - Preview and save your evaluation configuration +The resulting CSV includes an additional `eval_suite` column that records which +evaluation framework (e.g., `lm_eval` or `lighteval`) should be used for each +task. + Otherwise you can also directly schedule using a CSV file: ```bash oellm schedule-eval --eval_csv_path custom_evals.csv @@ -104,7 +108,7 @@ The `oellm` package orchestrates distributed LLM evaluations through the followi ### 1. **Cluster Auto-Detection** - Automatically detects the current HPC cluster based on hostname patterns -- Loads cluster-specific configurations from [`clusters.yaml`](oellm/clusters.yaml) including: +- Loads cluster-specific configurations from [`clusters.yaml`](oellm/resources/clusters.yaml) including: - SLURM partition and account settings - Shared storage paths for models, datasets, and results - GPU allocation and queue limits diff --git a/apptainer/jureca.def b/apptainer/jureca.def index bfdf18a..7f088ad 100644 --- a/apptainer/jureca.def +++ b/apptainer/jureca.def @@ -2,24 +2,37 @@ Bootstrap: docker From: nvcr.io/nvidia/pytorch:25.06-py3 %labels - Author multi-cluster-eval - Description Apptainer image for JURECA cluster (converted from dockerfile) + Author oellm-cli + Description Apptainer image for JURECA JSC cluster %post - # 1. Install uv package manager - curl -LsSf https://astral.sh/uv/install.sh | sh - echo 'export PATH=$HOME/.local/bin:$PATH' >> /etc/profile + # Install uv into a global bin + curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR=/usr/local/bin sh - # Make uv visible for subsequent commands during build - export PATH=/root/.local/bin:$PATH + # Put uv-installed tool shims in a global bin too + export UV_TOOL_BIN_DIR=/usr/local/bin + uv --version - # 2. 
Install Python dependencies uv pip install --system --break-system-packages lm-eval \ "transformers<=4.53.0" "datasets<4.0.0" wandb sentencepiece tiktoken accelerate + # Optional: keep tool envs under /opt to avoid $HOME + export UV_TOOL_DIR=/opt/uv-tools + uv tool install --python 3.12 "lighteval[multilingual] @ git+https://github.com/huggingface/lighteval.git@63424f4e795ecc577b90646381b374af3a627978" + uv pip install --system --break-system-packages nltk + mkdir -p /opt/nltk_data + python - <<'PY' +import nltk +nltk.download('punkt', download_dir='/opt/nltk_data') +nltk.download('punkt_tab', download_dir='/opt/nltk_data') +PY + %environment - # Ensure uv is present inside the container runtime as well - export PATH=/root/.local/bin:$PATH + export PATH=/usr/local/bin:$PATH + export UV_TOOL_BIN_DIR=/usr/local/bin + export UV_TOOL_DIR=/opt/uv-tools + export NLTK_DATA=/opt/nltk_data + %runscript exec bash "$@" diff --git a/apptainer/leonardo.def b/apptainer/leonardo.def index 570ae2a..f61f282 100644 --- a/apptainer/leonardo.def +++ b/apptainer/leonardo.def @@ -2,24 +2,36 @@ Bootstrap: docker From: nvcr.io/nvidia/pytorch:25.06-py3 %labels - Author multi-cluster-eval - Description Apptainer image for Leonardo cluster (converted from dockerfile) + Author oellm-cli + Description Apptainer image for Leonardo cluster %post - # 1. Install uv package manager - curl -LsSf https://astral.sh/uv/install.sh | sh - echo 'export PATH=$HOME/.local/bin:$PATH' >> /etc/profile + # Install uv into a global bin + curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR=/usr/local/bin sh - # Make uv visible for subsequent commands during build - export PATH=/root/.local/bin:$PATH + # Put uv-installed tool shims in a global bin too + export UV_TOOL_BIN_DIR=/usr/local/bin + uv --version - # 2. Install Python dependencies uv pip install --system --break-system-packages lm-eval \ "transformers<=4.53.0" "datasets<4.0.0" wandb sentencepiece tiktoken accelerate + # Optional: keep tool envs under /opt to avoid $HOME + export UV_TOOL_DIR=/opt/uv-tools + uv tool install --python 3.12 "lighteval[multilingual] @ git+https://github.com/huggingface/lighteval.git@63424f4e795ecc577b90646381b374af3a627978" + uv pip install --system --break-system-packages nltk + mkdir -p /opt/nltk_data + python - <<'PY' +import nltk +nltk.download('punkt', download_dir='/opt/nltk_data') +nltk.download('punkt_tab', download_dir='/opt/nltk_data') +PY + %environment - # Ensure uv is present inside the container runtime as well - export PATH=/root/.local/bin:$PATH + export PATH=/usr/local/bin:$PATH + export UV_TOOL_BIN_DIR=/usr/local/bin + export UV_TOOL_DIR=/opt/uv-tools + export NLTK_DATA=/opt/nltk_data %runscript exec bash "$@" diff --git a/apptainer/lumi.def b/apptainer/lumi.def index 52c042d..c19f85f 100644 --- a/apptainer/lumi.def +++ b/apptainer/lumi.def @@ -2,24 +2,36 @@ Bootstrap: docker From: rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1 %labels - Author multi-cluster-eval - Description Apptainer image for LUMI cluster (converted from dockerfile) + Author oellm-cli + Description Apptainer image for LUMI cluster %post - # 1. 
Install uv package manager - curl -LsSf https://astral.sh/uv/install.sh | sh - echo 'export PATH=$HOME/.local/bin:$PATH' >> /etc/profile + # Install uv into a global bin + curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR=/usr/local/bin sh - # Make uv visible for subsequent commands during build - export PATH=/root/.local/bin:$PATH + # Put uv-installed tool shims in a global bin too + export UV_TOOL_BIN_DIR=/usr/local/bin + uv --version - # 2. Install Python dependencies uv pip install --system --break-system-packages lm-eval \ "transformers<=4.53.0" "datasets<4.0.0" wandb sentencepiece tiktoken accelerate + # Optional: keep tool envs under /opt to avoid $HOME + export UV_TOOL_DIR=/opt/uv-tools + uv tool install --python 3.12 "lighteval[multilingual] @ git+https://github.com/huggingface/lighteval.git@63424f4e795ecc577b90646381b374af3a627978" + uv pip install --system --break-system-packages nltk + mkdir -p /opt/nltk_data + python - <<'PY' +import nltk +nltk.download('punkt', download_dir='/opt/nltk_data') +nltk.download('punkt_tab', download_dir='/opt/nltk_data') +PY + %environment - # Ensure uv is present inside the container runtime as well - export PATH=/root/.local/bin:$PATH + export PATH=/usr/local/bin:$PATH + export UV_TOOL_BIN_DIR=/usr/local/bin + export UV_TOOL_DIR=/opt/uv-tools + export NLTK_DATA=/opt/nltk_data %runscript exec bash "$@" diff --git a/oellm/interactive_csv_builder.py b/oellm/interactive_csv_builder.py index b27b097..61c99f1 100644 --- a/oellm/interactive_csv_builder.py +++ b/oellm/interactive_csv_builder.py @@ -1,5 +1,6 @@ import signal import sys +from importlib.resources import files from pathlib import Path import pandas as pd @@ -12,8 +13,6 @@ from rich.progress import Progress, SpinnerColumn, TextColumn from rich.table import Table -from .task_groups import resolve_task_group - def build_csv_interactive(output_path: str = "eval_config.csv") -> None: """ @@ -60,11 +59,9 @@ def signal_handler(sig, frame): # Step 1: Get models with enhanced input console.print("\n[bold cyan]📦 Step 1: Add Models[/bold cyan]") - models: list[str] = [] + models = [] add_more = True - existing_group_entries: set[tuple[str, tuple[int, ...]]] = set() - while add_more: try: action = questionary.select( @@ -119,19 +116,16 @@ def signal_handler(sig, frame): # Step 2: Configure tasks console.print("\n[bold cyan]📝 Step 2: Configure Tasks[/bold cyan]") - task_configs: list[tuple[str, list[int]]] = [] + task_configs: list[tuple[str, list[int], str]] = [] add_more = True - # Load task groups from YAML file - task_groups_file = Path(__file__).parent / "task-groups.yaml" + # Load task groups from packaged resources task_groups = {} - if task_groups_file.exists(): - try: - with open(task_groups_file) as f: - data = yaml.safe_load(f) - task_groups = data.get("task_groups", {}) - except Exception as e: - console.print(f"[yellow]Warning: Could not load task groups: {e}[/yellow]") + try: + data = yaml.safe_load((files("oellm.resources") / "task-groups.yaml").read_text()) + task_groups = data.get("task_groups", {}) + except Exception as e: + console.print(f"[yellow]Warning: Could not load task groups: {e}[/yellow]") while add_more: choices = [ @@ -190,34 +184,16 @@ def signal_handler(sig, frame): # Add tasks from selected groups for selection in selected_groups: group_name = selection.split(" - ")[0] - try: - group_tasks = resolve_task_group( - group_name, task_groups, console - ) - except ValueError as exc: - console.print(f"[red]{exc}[/red]") - continue - - if not group_tasks: - 
console.print( - f"[yellow]No tasks found for task group '{group_name}'.[/yellow]" - ) - continue + group_data = task_groups[group_name] console.print(f"\n[cyan]Adding tasks from '{group_name}':[/cyan]") - for task_item in group_tasks: + for task_item in group_data.get("tasks", []): task_name = task_item["task"] n_shots = task_item.get("n_shots", [0]) - entry_key = (task_name, tuple(n_shots)) - if entry_key in existing_group_entries: - console.print( - f" [yellow]• Skipping duplicate: {task_name} with n_shot={n_shots}[/yellow]" - ) - continue - existing_group_entries.add(entry_key) - task_configs.append((task_name, n_shots)) + suite = task_item.get("suite", group_data.get("suite", "lm_eval")) + task_configs.append((task_name, n_shots, suite)) console.print( - f" [green]✓ Added: {task_name} with n_shot={n_shots}[/green]" + f" [green]✓ Added: {task_name} (suite={suite}) with n_shot={n_shots}[/green]" ) # After adding task groups, ask if user wants to add more or proceed @@ -282,19 +258,53 @@ def signal_handler(sig, frame): try: n_shots = [int(x.strip()) for x in n_shots_str.split(",")] - entry_key = (task, tuple(n_shots)) - existing_group_entries.add(entry_key) - task_configs.append((task, n_shots)) + suite_choice = questionary.select( + f"Select evaluation suite for '{task}':", + choices=[ + questionary.Choice( + "lm_eval (lm-eval-harness)", value="lm_eval" + ), + questionary.Choice( + "lighteval (Hugging Face LightEval)", + value="lighteval", + ), + "📝 Custom suite", + ], + style=custom_style, + ).ask() + + if suite_choice is None: + console.print("\n[yellow]Cancelled by user.[/yellow]") + return + + if suite_choice == "📝 Custom suite": + suite = questionary.text( + "Enter suite identifier:", + instruction="(e.g., custom-eval-suite)", + style=custom_style, + ).ask() + if suite is None: + console.print("\n[yellow]Cancelled by user.[/yellow]") + return + suite = suite.strip() + if not suite: + suite = "lm_eval" + else: + suite = suite_choice + + task_configs.append((task, n_shots, suite)) console.print( - f"[green]✓ Added: {task} with n_shot={n_shots}[/green]" + f"[green]✓ Added: {task} (suite={suite}) with n_shot={n_shots}[/green]" ) except ValueError: console.print("[red]Invalid n_shot values. Skipping.[/red]") elif action == "📋 View current tasks": console.print("\n[bold]Current tasks:[/bold]") - for i, (task, n_shots) in enumerate(task_configs, 1): - console.print(f" {i}. [green]{task}[/green] → n_shot={n_shots}") + for i, (task, n_shots, suite) in enumerate(task_configs, 1): + console.print( + f" {i}. 
[green]{task}[/green] → n_shot={n_shots} (suite={suite})" + ) console.print() elif action == "✅ Continue to preview": @@ -310,10 +320,15 @@ def signal_handler(sig, frame): rows = [] for model in models: - for task_name, n_shots in task_configs: + for task_name, n_shots, suite in task_configs: for n_shot in n_shots: rows.append( - {"model_path": model, "task_path": task_name, "n_shot": n_shot} + { + "model_path": model, + "task_path": task_name, + "n_shot": n_shot, + "eval_suite": suite, + } ) df = pd.DataFrame(rows) @@ -327,11 +342,16 @@ def signal_handler(sig, frame): table.add_column("Model", style="cyan", no_wrap=True) table.add_column("Task", style="green") table.add_column("n_shot", justify="right", style="yellow") + table.add_column("Suite", style="magenta") # Show first 10 rows for idx, (_, row) in enumerate(df.head(10).iterrows(), 1): table.add_row( - str(idx), str(row["model_path"]), str(row["task_path"]), str(row["n_shot"]) + str(idx), + str(row["model_path"]), + str(row["task_path"]), + str(row["n_shot"]), + str(row["eval_suite"]), ) if len(df) > 10: diff --git a/oellm/main.py b/oellm/main.py index c718837..c1ffac2 100644 --- a/oellm/main.py +++ b/oellm/main.py @@ -1,409 +1,46 @@ import logging import os import re -import socket import subprocess +from dataclasses import dataclass from datetime import datetime -from itertools import product +from importlib.resources import files from pathlib import Path from string import Template -from typing import Iterable import numpy as np import pandas as pd -import yaml from jsonargparse import auto_cli -from rich import box -from rich.console import Console -from rich.logging import RichHandler - -from .task_groups import flatten_task_groups, load_task_groups - - -def ensure_singularity_image(image_name: str) -> None: - # TODO: switch to OELLM dataset repo once it is created - from huggingface_hub import hf_hub_download - - hf_repo = os.environ.get("HF_SIF_REPO", "openeurollm/evaluation_singularity_images") - eval_base_dir = os.getenv("EVAL_BASE_DIR") - if eval_base_dir is None: - raise ValueError( - "EVAL_BASE_DIR is not set. Please configure it in clusters.yaml or the environment." - ) - - image_path = Path(eval_base_dir) / image_name - - try: - hf_hub_download( - repo_id=hf_repo, - filename=image_name, - repo_type="dataset", - local_dir=eval_base_dir, - ) - logging.info( - "Successfully downloaded latest Singularity image from HuggingFace" - ) - except Exception as e: - logging.warning( - "Failed to fetch latest container image from HuggingFace: %s", str(e) - ) - if image_path.exists(): - logging.info("Using existing Singularity image at %s", image_path) - else: - raise RuntimeError( - f"No container image found at {image_path} and failed to download from HuggingFace. " - f"Cannot proceed with evaluation scheduling." 
- ) from e - - logging.info( - "Singularity image ready at %s", - Path(eval_base_dir) / image_name, - ) - - -def _setup_logging(verbose: bool = False): - rich_handler = RichHandler( - console=Console(), - show_time=True, - log_time_format="%H:%M:%S", - show_path=False, - markup=True, - rich_tracebacks=True, - ) - - class RichFormatter(logging.Formatter): - def format(self, record): - # Define colors for different log levels - record.msg = f"{record.getMessage()}" - return record.msg - - rich_handler.setFormatter(RichFormatter()) - - root_logger = logging.getLogger() - root_logger.handlers = [] # Remove any default handlers - root_logger.addHandler(rich_handler) - root_logger.setLevel(logging.DEBUG if verbose else logging.INFO) - - -def _load_cluster_env() -> None: - """ - Loads the correct cluster environment variables from `clusters.yaml` based on the hostname. - """ - with open(Path(__file__).parent / "clusters.yaml") as f: - clusters = yaml.safe_load(f) - hostname = socket.gethostname() - - # First load shared environment variables - shared_cfg = clusters.get("shared", {}) - - # match hostname to the regex in the clusters.yaml - for host in set(clusters.keys()) - {"shared"}: - pattern = clusters[host]["hostname_pattern"] - # Convert shell-style wildcards to regex - regex_pattern = pattern.replace(".", r"\.").replace("*", ".*") - if re.match(f"^{regex_pattern}$", hostname): - cluster_cfg = clusters[host] - break - else: - raise ValueError(f"No cluster found for hostname: {hostname}") - - # Combine shared and cluster-specific configs, with cluster-specific taking precedence - # Remove hostname_pattern from the final config - if "hostname_pattern" in cluster_cfg: - del cluster_cfg["hostname_pattern"] - - # Set environment variables, expanding any template variables - for k, v in cluster_cfg.items(): - # Expand template variables using existing environment variables - os.environ[k] = str(v) - - for k, v in shared_cfg.items(): - try: - os.environ[k] = str(v).format(**cluster_cfg) - except KeyError as e: - # when substituting env vars that are not in cluster_cfg but in the environment (e.g., $USER, $SHELL, etc...) - if len(e.args) > 1: - raise ValueError( - f"Env. variable substitution for {k} failed. Missing keys: {', '.join(e.args)}" - ) from e - - missing_key: str = e.args[0] - os.environ[k] = str(v).format( - **cluster_cfg, **{missing_key: os.environ[missing_key]} - ) - - -def _num_jobs_in_queue() -> int: - # TODO avoid running in shell mode which is not secure - result = subprocess.run( - "squeue -u $USER -h -t pending,running -r | wc -l", - shell=True, - capture_output=True, - text=True, - ) - - if result.stdout: - try: - return int(result.stdout.strip()) - except ValueError: - logging.warning(f"Could not parse squeue output: {result.stdout}") - return 0 - - if result.stderr: - logging.warning(f"squeue command produced an error: {result.stderr.strip()}") - - return 0 - - -def _expand_local_model_paths(model: str) -> list[Path]: - """ - Expands a local model path to include all checkpoints if it's a directory. - Recursively searches for models in subdirectories. 
- - Args: - model: Path to a model or directory containing models - - Returns: - List of paths to model directories containing safetensors files - """ - model_paths: list[Path] = [] - model_path = Path(model) - - if not model_path.exists() or not model_path.is_dir(): - return model_paths - - # First check if current directory contains safetensors files - if any(model_path.glob("*.safetensors")): - model_paths.append(model_path) - # If current dir has safetensors, don't recurse further - return model_paths - - # Check for hf subdirectory pattern (single model with checkpoints) - hf_path = model_path / "hf" - if hf_path.exists() and hf_path.is_dir(): - # This is a single model with checkpoints in hf/iter_* structure - for subdir in hf_path.glob("*"): - if subdir.is_dir() and any(subdir.glob("*.safetensors")): - model_paths.append(subdir) - if model_paths: - return model_paths - - # Check if subdirectories look like model directories - # (e.g., open-sci-ref_model-0.13b_data-c4_...) - subdirs = [d for d in model_path.iterdir() if d.is_dir()] - - # Process each subdirectory as a potential model - for subdir in subdirs: - # Check if this subdirectory directly contains safetensors - if any(subdir.glob("*.safetensors")): - model_paths.append(subdir) - else: - # Check for hf/iter_* pattern in this subdirectory - hf_subpath = subdir / "hf" - if hf_subpath.exists() and hf_subpath.is_dir(): - for checkpoint_dir in hf_subpath.glob("*"): - if checkpoint_dir.is_dir() and any( - checkpoint_dir.glob("*.safetensors") - ): - model_paths.append(checkpoint_dir) - - if len(model_paths) > 1: - logging.info(f"Expanded '{model}' to {len(model_paths)} model checkpoints") - - return model_paths - - -def _process_model_paths(models: Iterable[str]) -> dict[str, list[str]]: - """ - Processes model strings into a dict of model paths. - - Each model string can be a local path or a huggingface model identifier. - This function expands directory paths that contain multiple checkpoints. - """ - from huggingface_hub import snapshot_download - - processed_model_paths: dict[str, list[str]] = {} - for model in models: - model_paths: list[str] = [] - # First try to expand local paths - local_paths = _expand_local_model_paths(model) - if local_paths: - model_paths.extend(str(path) for path in local_paths) - else: - logging.info( - f"Model {model} not found locally, assuming it is a 🤗 hub model" - ) - logging.debug( - f"Downloading model {model} on the login node since the compute nodes may not have access to the internet" - ) - - if "," in model: - model_kwargs = dict( - [kv.split("=") for kv in model.split(",") if "=" in kv] - ) - - # The first element before the comma is the repository ID on the 🤗 Hub - repo_id = model.split(",")[0] - - # snapshot_download kwargs - snapshot_kwargs = {} - if "revision" in model_kwargs: - snapshot_kwargs["revision"] = model_kwargs["revision"] - - try: - # Pre-download (or reuse cache) for the whole repository so that - # compute nodes can load it offline. - hf_home = os.getenv("HF_HOME") - if hf_home is None: - raise ValueError( - "HF_HOME is not set. Please configure it before scheduling evals." - ) - cache_dir = Path(hf_home) / "hub" - - snapshot_download( - repo_id=repo_id, - cache_dir=cache_dir, - **snapshot_kwargs, - ) - model_paths.append(model) - except Exception as e: - logging.debug( - f"Failed to download model {model} from Hugging Face Hub. Continuing..." - ) - logging.debug(e) - else: - # Download the entire model repository to the local cache. 
The - # original identifier is kept in *model_paths* so downstream - # code can still reference it; at runtime the files will be - # read from cache, allowing offline execution. - hf_home = os.getenv("HF_HOME") - if hf_home is None: - raise ValueError( - "HF_HOME is not set. Please configure it before scheduling evals." - ) - cache_dir = Path(hf_home) / "hub" - - snapshot_download( - repo_id=model, - cache_dir=cache_dir, - ) - model_paths.append(model) - - if not model_paths: - logging.warning( - f"Could not find any valid model for '{model}'. It will be skipped." - ) - processed_model_paths[model] = model_paths - return processed_model_paths - - -def _count_task_subtasks(task_name: str, task_manager) -> int: - from lm_eval.evaluator_utils import get_subtask_list # type: ignore - - task_objects = task_manager.load_task_or_group(task_name) - subtask_dict = get_subtask_list(task_objects) - - total_subtasks = 0 - for _, subtask_list in subtask_dict.items(): - total_subtasks += len(subtask_list) - - return max(1, total_subtasks) # At least 1 subtask - - -def _calculate_task_minutes( - task_name: str, task_manager, base_minutes_per_subtask: int = 5 -) -> int: - """Calculate estimated minutes for a task based on its subtask count.""" - subtask_count = _count_task_subtasks(task_name, task_manager) - - # Special handling for known multi-language tasks that take longer per subtask - known_complex_tasks = { - "belebele": 8, # Multi-language reading comprehension, slower per subtask - "flores": 6, # Translation task, moderately complex - "xnli": 6, # Cross-lingual NLI - "xcopa": 6, # Cross-lingual COPA - "xstory_cloze": 6, # Cross-lingual story cloze - "paws-x": 6, # Cross-lingual paraphrase detection - "hellaswag": 20, # Hellaswag task, needs 20 minutes per subtask - } - - # Use task-specific timing if available, otherwise use default - minutes_per_subtask = known_complex_tasks.get( - task_name.lower(), base_minutes_per_subtask - ) - - # Calculate total time: (subtasks × time_per_subtask) + base_overhead - base_overhead = 3 # Base overhead for task setup/teardown - total_minutes = max(10, (subtask_count * minutes_per_subtask) + base_overhead) - - # Log for complex tasks (>5 subtasks) or any known complex task - if subtask_count > 5 or task_name.lower() in known_complex_tasks: - complexity_note = ( - f" (known complex task, {minutes_per_subtask} min/subtask)" - if task_name.lower() in known_complex_tasks - else "" - ) - logging.info( - f"📊 Task '{task_name}' has {subtask_count} subtasks{complexity_note}, " - f"estimated time: {total_minutes} minutes ({total_minutes / 60:.1f} hours)" - ) - - return total_minutes - - -def _pre_download_task_datasets( - tasks: Iterable[str], trust_remote_code: bool = True -) -> None: - """Ensure that all datasets required by the given `tasks` are present in the local 🤗 cache at $HF_HOME.""" - - from datasets import DownloadMode # type: ignore - from lm_eval.tasks import TaskManager # type: ignore - - processed: set[str] = set() - - tm = TaskManager() - - for task_name in tasks: - if not isinstance(task_name, str) or task_name in processed: - continue - processed.add(task_name) - - logging.info( - f"Preparing dataset for task '{task_name}' (download if not cached)…" - ) - - # Instantiating the task downloads the dataset (or reuses cache) - - task_config = { - "task": task_name, - "dataset_kwargs": {"trust_remote_code": trust_remote_code}, - } - - task_objects = tm.load_config(task_config) - - # Some entries might be nested dictionaries (e.g., groups) - stack = 
[task_objects] - while stack: - current = stack.pop() - if isinstance(current, dict): - stack.extend(current.values()) - continue - if hasattr(current, "download") and callable(current.download): - try: - current.download(download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) # type: ignore[arg-type] - except TypeError as e: - logging.error( - f"Failed to download dataset for task '{task_name}' with download_mode=REUSE_DATASET_IF_EXISTS: {e}" - ) - current.download() # type: ignore[misc] - - logging.debug(f"Finished dataset preparation for task '{task_name}'.") - +from oellm.task_cache import clear_task_cache +from oellm.task_groups import _expand_task_groups +from oellm.utils import ( + _ensure_singularity_image, + _expand_local_model_paths, + _filter_warnings, + _load_cluster_env, + _num_jobs_in_queue, + _pre_download_lighteval_datasets, + _pre_download_task_datasets, + _process_model_paths, + _setup_logging, + capture_third_party_output_from_kwarg, +) + + +@dataclass +class EvaluationJob: + model_path: Path | str + task_path: str + n_shot: int + eval_suite: str + + +@capture_third_party_output_from_kwarg("verbose") def schedule_evals( models: str | None = None, tasks: str | None = None, - task_group: str | None = None, + task_groups: str | None = None, n_shot: int | list[int] | None = None, eval_csv_path: str | None = None, *, @@ -427,13 +64,13 @@ def schedule_evals( all models in subdirectories will be automatically discovered - For each model directory, if it has an `hf/iter_XXXXX` structure, all checkpoints will be expanded - This allows passing a single directory containing multiple models to evaluate them all - tasks: A string of comma-separated task paths. - task_group: Name(s) of task groups defined in ``task-groups.yaml`` to expand. - Multiple groups can be provided as a comma-separated string. Cannot be used - together with ``tasks`` or ``n_shot``. - n_shot: An integer or list of integers specifying the number of shots for each task. + tasks: A string of comma-separated task names (lm_eval) or paths. + Requires `n_shot` to be provided. Tasks here are assumed to be lm_eval unless otherwise handled via CSV. + task_groups: A string of comma-separated task group names defined in `task-groups.yaml`. + Each group expands into concrete (task, n_shots, suite) entries; `n_shot` is ignored for groups. + n_shot: An integer or list of integers specifying the number of shots applied to `tasks`. eval_csv_path: A path to a CSV file containing evaluation data. - Warning: exclusive argument. Cannot specify `models`, `tasks`, or `n_shot` when `eval_csv_path` is provided. + Warning: exclusive argument. Cannot specify `models`, `tasks`, `task_groups`, or `n_shot` when `eval_csv_path` is provided. max_array_len: The maximum number of jobs to schedule to run concurrently. Warning: this is not the number of jobs in the array job. This is determined by the environment variable `QUEUE_LIMIT`. download_only: If True, only download the datasets and models and exit. @@ -447,20 +84,24 @@ def schedule_evals( _load_cluster_env() if not skip_checks: - image_name = os.environ.get("EVAL_CONTAINER_IMAGE") - if image_name is None: - raise ValueError( - "EVAL_CONTAINER_IMAGE is not set. Please set it in clusters.yaml." 
- ) - - ensure_singularity_image(image_name) + _ensure_singularity_image(os.environ.get("EVAL_CONTAINER_IMAGE")) # type: ignore else: logging.info("Skipping container image check (--skip-checks enabled)") + if isinstance(models, str) and models is not None: + models = [m.strip() for m in models.split(",") if m.strip()] # type: ignore + + if isinstance(tasks, str) and tasks is not None: + tasks = [t.strip() for t in tasks.split(",") if t.strip()] # type: ignore + + if isinstance(n_shot, int) and n_shot is not None: + n_shot = [n_shot] + + eval_jobs: list[EvaluationJob] = [] if eval_csv_path: - if models or tasks or task_group or n_shot: + if models or tasks or task_groups or n_shot: raise ValueError( - "Cannot specify `models`, `tasks`, `task_group`, or `n_shot` when `eval_csv_path` is provided." + "Cannot specify `models`, `tasks`, `task_groups`, or `n_shot` when `eval_csv_path` is provided." ) df = pd.read_csv(eval_csv_path) required_cols = {"model_path", "task_path", "n_shot"} @@ -469,172 +110,115 @@ def schedule_evals( f"CSV file must contain the columns: {', '.join(required_cols)}" ) - # Always expand local model paths, even with skip_checks - df["model_path"].unique() - expanded_rows = [] - for _, row in df.iterrows(): - original_model_path = row["model_path"] - local_paths = _expand_local_model_paths(original_model_path) - if local_paths: - # Use expanded local paths - for expanded_path in local_paths: - new_row = row.copy() - new_row["model_path"] = expanded_path - expanded_rows.append(new_row) - else: - # Keep original path (might be HF model) - expanded_rows.append(row) - df = pd.DataFrame(expanded_rows) - - # Download HF models only if skip_checks is False - if not skip_checks: - # Process any HF models that need downloading - hf_models = [m for m in df["model_path"].unique() if not Path(m).exists()] - if hf_models: - model_path_map = _process_model_paths(hf_models) - # Update the dataframe with processed HF models - for idx, row in df.iterrows(): - if row["model_path"] in model_path_map: - # This shouldn't expand further, just update the path - df.at[idx, "model_path"] = model_path_map[row["model_path"]][0] + if "eval_suite" not in df.columns: + df["eval_suite"] = "lm_eval" else: - logging.info( - "Skipping model path processing and validation (--skip-checks enabled)" - ) - - elif models and task_group: - if tasks or n_shot is not None: - raise ValueError( - "Cannot combine `task_group` with explicit `tasks` or `n_shot` arguments." 
- ) - - model_list = [m.strip() for m in models.split(",") if m.strip()] - if not model_list: - raise ValueError("No models specified.") + df["eval_suite"] = df["eval_suite"].fillna("lm_eval") - model_paths: list[str] = [] - - for model in model_list: - local_paths = _expand_local_model_paths(model) - if local_paths: - model_paths.extend(str(path) for path in local_paths) - else: - model_paths.append(model) - - if not skip_checks: - hf_models = [m for m in model_paths if not Path(m).exists()] - if hf_models: - model_path_map = _process_model_paths(hf_models) - model_paths = [ - model_path_map[m][0] if m in model_path_map else m - for m in model_paths - ] - else: - logging.info( - "Skipping model path processing and validation (--skip-checks enabled)" - ) - - group_names = [name.strip() for name in task_group.split(",") if name.strip()] - if not group_names: - raise ValueError("No task groups specified.") - - task_groups = load_task_groups() - - def _capture_duplicate(group_name: str, pair: tuple[str, int]) -> None: - task_name, nshot = pair - logging.info( - "Skipping duplicate task '%s' with n_shot=%s from group '%s'", - task_name, - nshot, - group_name, - ) - - flattened = flatten_task_groups( - group_names, - task_groups, - on_duplicate=_capture_duplicate, - ) - - if not flattened: - raise ValueError( - "Resolved task groups did not yield any tasks. Check task-groups.yaml." - ) - - records = [] - for model_path in model_paths: - for task_name, nshot in flattened: - records.append( - { - "model_path": model_path, - "task_path": task_name, - "n_shot": nshot, - } + # Always expand local model paths, even with skip_checks + df["model_path"].unique() + eval_jobs.extend( + [ + EvaluationJob( + model_path=row["model_path"], + task_path=row["task_path"], + n_shot=row["n_shot"], + eval_suite=row["eval_suite"], ) + for _, row in df.iterrows() + ] + ) - df = pd.DataFrame(records) - - elif models and tasks and n_shot is not None: - model_list = models.split(",") - expanded_model_paths: list[str] = [] - - # Always expand local paths - for model in model_list: - local_paths = _expand_local_model_paths(model) - if local_paths: - expanded_model_paths.extend(str(path) for path in local_paths) - else: - expanded_model_paths.append(model) - - # Download HF models only if skip_checks is False - if not skip_checks: - hf_models = [m for m in expanded_model_paths if not Path(m).exists()] - if hf_models: - model_path_map = _process_model_paths(hf_models) - # Replace HF model identifiers with processed paths - expanded_model_paths = [ - model_path_map[m][0] if m in model_path_map else m - for m in expanded_model_paths + elif models: + if task_groups is None: + eval_jobs.extend( + [ + EvaluationJob( + model_path=model, + task_path=task, + n_shot=shot, + eval_suite="lm_eval", + ) + for model in models + for task in tasks + for shot in n_shot ] + ) else: - logging.info( - "Skipping model path processing and validation (--skip-checks enabled)" + expanded = _expand_task_groups([g.strip() for g in task_groups.split(",")]) + eval_jobs.extend( + [ + EvaluationJob( + model_path=model, + task_path=result.task, + n_shot=result.n_shot, + eval_suite=result.suite, + ) + for model in models + for result in expanded + ] ) - tasks_list = tasks.split(",") + expanded_eval_jobs = [] + for job in eval_jobs: + local_model_paths = _expand_local_model_paths(job.model_path) + if not local_model_paths: + expanded_eval_jobs.append(job) + else: + for path in local_model_paths: + expanded_eval_jobs.append( + EvaluationJob( + 
model_path=path, + task_path=job.task_path, + n_shot=job.n_shot, + eval_suite=job.eval_suite, + ) ) - # cross product of model_paths and tasks into a dataframe - df = pd.DataFrame( - product( - expanded_model_paths, - tasks_list, - n_shot if isinstance(n_shot, list) else [n_shot], - ), - columns=["model_path", "task_path", "n_shot"], - ) + if not skip_checks: + hub_models: set[str | Path] = { + job.model_path + for job in expanded_eval_jobs + if not Path(job.model_path).exists() + } + _process_model_paths(hub_models) else: - raise ValueError( - "Either provide `eval_csv_path`, or specify `models` with `tasks` and `n_shot`," - " or `models` with `task_group`." + logging.info( + "Skipping model path processing and validation (--skip-checks enabled)" ) + # create csv + df = pd.DataFrame(expanded_eval_jobs) + if df.empty: logging.warning("No evaluation jobs to schedule.") return None + df["eval_suite"] = df["eval_suite"].str.lower() + # Ensure that all datasets required by the tasks are cached locally to avoid # network access on compute nodes. if not skip_checks: - _pre_download_task_datasets( - df["task_path"].unique(), trust_remote_code=trust_remote_code - ) + lm_eval_tasks = df[df["eval_suite"].isin({"lm_eval", "lm-eval-harness"})][ + "task_path" + ].unique() + if len(lm_eval_tasks) > 0: + _pre_download_task_datasets( + lm_eval_tasks, trust_remote_code=trust_remote_code + ) + # Pre-download LightEval datasets (best-effort, incremental support) + light_eval_tasks = df[df["eval_suite"].isin({"lighteval", "light-eval"})]["task_path"].unique() + if len(light_eval_tasks) > 0: + _pre_download_lighteval_datasets(light_eval_tasks) else: logging.info("Skipping dataset pre-download (--skip-checks enabled)") if download_only: return None - queue_limit = int(os.environ.get("QUEUE_LIMIT", 250)) - remaining_queue_capacity = queue_limit - _num_jobs_in_queue() + remaining_queue_capacity = ( + int(os.environ.get("QUEUE_LIMIT", 250)) - _num_jobs_in_queue() + ) if remaining_queue_capacity <= 0: logging.warning("No remaining queue capacity. 
Not scheduling any jobs.") @@ -663,74 +247,24 @@ def _capture_duplicate(group_name: str, pair: tuple[str, int]) -> None: df.to_csv(csv_path, index=False) - logging.debug(f"Saved evaluation dataframe to temporary CSV: {csv_path}") - - with open(Path(__file__).parent / "template.sbatch") as f: - sbatch_template = f.read() + sbatch_template = (files("oellm.resources") / "template.sbatch").read_text() # Calculate dynamic array size and time limits total_evals = len(df) - - # Calculate time based on actual task complexity (subtask count) - if not skip_checks: - from lm_eval.tasks import TaskManager # type: ignore - - shared_task_manager = TaskManager() - - # Calculate total minutes by considering each unique task's complexity - total_minutes = 0 - task_time_cache = {} # Cache to avoid recalculating for same tasks - - for _, row in df.iterrows(): - task_name = row["task_path"] - if task_name not in task_time_cache: - task_time_cache[task_name] = _calculate_task_minutes( - task_name, task_manager=shared_task_manager - ) - total_minutes += task_time_cache[task_name] - - # Calculate average minutes per eval for logging purposes - minutes_per_eval = total_minutes / total_evals if total_evals > 0 else 10 - - logging.info("📊 Dynamic time calculation:") - for task_name, task_minutes in task_time_cache.items(): - task_count = (df["task_path"] == task_name).sum() - logging.info( - f" Task '{task_name}': {task_minutes} min/eval × {task_count} evals = {task_minutes * task_count} total minutes" - ) - else: - # Fallback to fixed timing when checks are skipped - minutes_per_eval = 10 # Budget 10 minutes per eval - total_minutes = total_evals * minutes_per_eval - logging.info( - "⚠️ Using fixed 10 min/eval (task complexity detection skipped with --skip-checks)" - ) - - # Maximum runtime per job (18 hours with safety margin) + minutes_per_eval = 10 # Budget 10 minutes per eval + total_minutes = total_evals * minutes_per_eval max_minutes_per_job = 18 * 60 # 18 hours min_array_size_for_time = max(1, int(np.ceil(total_minutes / max_minutes_per_job))) desired_array_size = min(128, total_evals) if total_evals >= 128 else total_evals if desired_array_size < min_array_size_for_time: desired_array_size = min_array_size_for_time - - # The actual array size is limited by queue capacity and total evals actual_array_size = min(remaining_queue_capacity, desired_array_size, total_evals) - - # Calculate actual time per job evals_per_job = max(1, int(np.ceil(total_evals / actual_array_size))) minutes_per_job = evals_per_job * minutes_per_eval - - # Add 20% safety margin and round up to nearest hour minutes_with_margin = int(minutes_per_job * 1.2) hours_with_margin = max(1, int(np.ceil(minutes_with_margin / 60))) - - # Apply 3-hour safety minimum for array jobs hours_with_margin = max(hours_with_margin, 3) - - # Cap at 24 hours hours_with_margin = min(hours_with_margin, 23) - - # Format time limit for SLURM (HH:MM:SS) time_limit = f"{hours_with_margin:02d}:59:00" # Log the calculated values @@ -750,8 +284,6 @@ def _capture_duplicate(group_name: str, pair: tuple[str, int]) -> None: ) logging.info(f" Time limit with safety margin: {time_limit}") - # replace the placeholders in the template with the actual values - # First, replace python-style placeholders sbatch_script = sbatch_template.format( csv_path=csv_path, max_array_len=max_array_len, @@ -763,13 +295,10 @@ def _capture_duplicate(group_name: str, pair: tuple[str, int]) -> None: time_limit=time_limit, # Dynamic time limit ) - # substitute any $ENV_VAR occurrences (e.g., 
$TIME_LIMIT) since env vars are not - # expanded in the #SBATCH directives + # substitute any $ENV_VAR occurrences sbatch_script = Template(sbatch_script).safe_substitute(os.environ) - # Save the sbatch script to the evals directory sbatch_script_path = evals_dir / "submit_evals.sbatch" - logging.debug(f"Saving sbatch script to {sbatch_script_path}") with open(sbatch_script_path, "w") as f: f.write(sbatch_script) @@ -846,7 +375,6 @@ def collect_results( output_csv: str = "eval_results.csv", *, check: bool = False, - reschedule: bool = False, verbose: bool = False, ) -> None: """ @@ -855,16 +383,12 @@ def collect_results( Args: results_dir: Path to the directory containing result JSON files output_csv: Output CSV filename (default: eval_results.csv) - check: Check for crashed or pending evaluations - reschedule: Show overview table and prompt to reschedule failed/pending jobs + check: Check for missing evaluations and create a missing jobs CSV verbose: Enable verbose logging """ import json - from rich.table import Table - _setup_logging(verbose) - console = Console() results_path = Path(results_dir) if not results_path.exists(): @@ -885,13 +409,12 @@ def collect_results( logging.info(f"Found {len(json_files)} result files") - # If check or reschedule mode, also load the jobs.csv to compare - if check or reschedule: + # If check mode, also load the jobs.csv to compare + if check: jobs_csv_path = results_path / "jobs.csv" if not jobs_csv_path.exists(): logging.warning(f"No jobs.csv found in {results_dir}, cannot perform check") check = False - reschedule = False else: jobs_df = pd.read_csv(jobs_csv_path) logging.info(f"Found {len(jobs_df)} scheduled jobs in jobs.csv") @@ -899,72 +422,148 @@ def collect_results( # Collect results rows = [] completed_jobs = set() # Track (model, task, n_shot) tuples - results_with_performance = ( - 0 # Track how many results actually have performance data - ) for json_file in json_files: + with open(json_file) as f: + data = json.load(f) + + # Extract model name/path + model_name = data.get("model_name", "unknown") + + # Extract results for each task + results = data.get("results", {}) + n_shot_data = data.get("n-shot", {}) + + # Infer a global n_shot if exactly one unique value exists in this JSON + global_n_shot = None try: - with open(json_file) as f: - data = json.load(f) - - # Extract model name/path - model_name = data.get("model_name", "unknown") - - # Extract results for each task - results = data.get("results", {}) - n_shot_data = data.get("n-shot", {}) - - for task_name, task_results in results.items(): - # Skip MMLU subtasks - only keep the aggregate score - if task_name.startswith("mmlu_") and task_name != "mmlu": - continue - - # Get n_shot for this task - n_shot = n_shot_data.get(task_name, "unknown") - - # Special handling for MMLU aggregate - get n_shot from any MMLU subtask - if task_name == "mmlu" and n_shot == "unknown": - for key, value in n_shot_data.items(): - if key.startswith("mmlu_"): - n_shot = value - break - - # Get the primary metric (usually acc,none) - performance = task_results.get("acc,none") - if performance is None: - # Try other common metric names - for metric in ["acc", "accuracy", "f1", "exact_match"]: - if metric in task_results: - performance = task_results[metric] - break - - if performance is not None: - results_with_performance += 1 - - # Track completed job for check/reschedule mode (only if we have a result) - if check or reschedule: - completed_jobs.add((model_name, task_name, n_shot)) - - rows.append( - { - 
"model_name": model_name, - "task": task_name, - "n_shot": n_shot, - "performance": performance, - } - ) - else: - # Debug: log cases where we have a task but no performance metric - if verbose: - logging.debug( - f"No performance metric found for {model_name} | {task_name} | n_shot={n_shot} in {json_file.name}" - ) + candidate_values = [] + for _v in n_shot_data.values(): + if isinstance(_v, (int | float)): + candidate_values.append(int(_v)) + elif isinstance(_v, str) and _v.isdigit(): + candidate_values.append(int(_v)) + unique_values = set(candidate_values) + if len(unique_values) == 1: + global_n_shot = next(iter(unique_values)) + except Exception: + pass + + # Aggregate groups (lm-eval harness) + groups_map = data.get("groups", {}) + group_subtasks_map = data.get("group_subtasks", {}) + group_aggregate_names = set(groups_map.keys()) | set(group_subtasks_map.keys()) + group_subtask_names: set[str] = set() + for _agg, _subs in group_subtasks_map.items(): + for _s in _subs: + group_subtask_names.add(_s) + + # Prefer only the first aggregate metric from groups (simplified) + if groups_map: + group_name, group_results = next(iter(groups_map.items())) + n_shot = n_shot_data.get(group_name, "unknown") + if n_shot == "unknown": + for subtask_name in group_subtasks_map.get(group_name, []): + if subtask_name in n_shot_data: + n_shot = n_shot_data[subtask_name] + break + if n_shot == "unknown" and global_n_shot is not None: + n_shot = global_n_shot + performance = group_results.get("acc,none") + if performance is None: + for metric in ["acc", "accuracy", "f1", "exact_match"]: + if metric in group_results: + performance = group_results[metric] + break + if performance is not None: + if check: + completed_jobs.add((model_name, group_name, n_shot)) + rows.append( + { + "model_name": model_name, + "task": group_name, + "n_shot": n_shot, + "performance": performance, + } + ) + # Skip per-task iteration when groups are present + continue + + for task_name, task_results in results.items(): + # Skip entries already added from groups + if groups_map and task_name in group_aggregate_names: + continue + # Skip any lm-eval group subtasks; keep only aggregates + if task_name in group_subtask_names: + continue + + # Skip MMLU subtasks - only keep the aggregate score + if task_name.startswith("mmlu_") and task_name != "mmlu": + continue + + # Skip Global MMLU subtasks - keep only aggregates like global_mmlu_full_pt + if task_name.startswith("global_mmlu_") and task_name.count("_") >= 4: + continue + + # Get n_shot for this task + n_shot = n_shot_data.get(task_name, "unknown") + + # If this is a group aggregate and n_shot is missing, derive from any subtask + if task_name in group_aggregate_names and n_shot == "unknown": + for subtask_name in group_subtasks_map.get(task_name, []): + if subtask_name in n_shot_data: + n_shot = n_shot_data[subtask_name] + break + if n_shot == "unknown" and global_n_shot is not None: + n_shot = global_n_shot + + # Special handling for MMLU aggregate - get n_shot from any MMLU subtask + if task_name == "mmlu" and n_shot == "unknown": + for key, value in n_shot_data.items(): + if key.startswith("mmlu_"): + n_shot = value + break + if n_shot == "unknown" and global_n_shot is not None: + n_shot = global_n_shot + + # Special handling for Global MMLU aggregates - get n_shot from subtasks + if task_name.startswith("global_mmlu_") and n_shot == "unknown": + prefix = f"{task_name}_" + for key, value in n_shot_data.items(): + if key.startswith(prefix): + n_shot = value + break + if n_shot == 
"unknown" and global_n_shot is not None: + n_shot = global_n_shot + + # Get the primary metric (usually acc,none) + performance = task_results.get("acc,none") + if performance is None: + # Try other common metric names + for metric in ["acc", "accuracy", "f1", "exact_match"]: + if metric in task_results: + performance = task_results[metric] + break + + if performance is not None: + # Track completed job for check mode + if check: + completed_jobs.add((model_name, task_name, n_shot)) - except Exception as e: - logging.warning(f"Failed to process {json_file}: {e}") - if verbose: - logging.exception(e) + rows.append( + { + "model_name": model_name, + "task": task_name, + "n_shot": n_shot, + "performance": performance, + } + ) + else: + # Debug: log cases where we have a task but no performance metric + if verbose: + logging.debug( + f"No performance metric found for {model_name} | {task_name} | n_shot={n_shot} in {json_file.name}" + ) if not rows and not check: logging.warning("No results extracted from JSON files") @@ -979,7 +578,7 @@ def collect_results( # Print summary statistics if verbose: - logging.info("\nSummary:") + logging.info("Summary:") logging.info(f"Unique models: {df['model_name'].nunique()}") logging.info(f"Unique tasks: {df['task'].nunique()}") logging.info( @@ -987,101 +586,23 @@ def collect_results( ) # Perform check analysis if requested - if check or reschedule: - logging.info("\n=== Evaluation Status Check ===") - - # Parse SLURM logs to get more detailed status - slurm_logs_dir = results_path / "slurm_logs" - attempted_jobs = set() # Jobs that were attempted (started) - failed_jobs = set() # Jobs that crashed/failed - - if slurm_logs_dir.exists(): - # Parse .out files to find attempted jobs - for out_file in slurm_logs_dir.glob("*.out"): - try: - with open(out_file) as f: - content = f.read() - # Look for "Starting evaluation for:" patterns - import re - - pattern = r"Starting evaluation for:\s*\n\s*Model: (.+)\s*\n\s*Task: (.+)\s*\n\s*N-shot: (\d+)" - matches = re.findall(pattern, content) - for model, task, n_shot in matches: - attempted_jobs.add( - (model.strip(), task.strip(), int(n_shot.strip())) - ) - - # Check if job finished successfully - if "Job" in content and "finished." 
in content: - # This array job completed successfully - pass - else: - # Job might have crashed - check for specific patterns - if ( - "Traceback" in content - or "Error" in content - or "Exception" in content - ): - for model, task, n_shot in matches: - failed_jobs.add( - ( - model.strip(), - task.strip(), - int(n_shot.strip()), - ) - ) - except Exception as e: - logging.debug(f"Error parsing {out_file}: {e}") - - # Parse .err files for errors - for err_file in slurm_logs_dir.glob("*.err"): - try: - file_size = err_file.stat().st_size - if file_size > 0: # Non-empty error file - # Extract array task ID from filename - array_id_match = re.search(r"-(\d+)\.err$", err_file.name) - if array_id_match: - int(array_id_match.group(1)) - # Find corresponding .out file to get job details - out_file = err_file.with_suffix(".out") - if out_file.exists(): - with open(out_file) as f: - content = f.read() - pattern = r"Starting evaluation for:\s*\n\s*Model: (.+)\s*\n\s*Task: (.+)\s*\n\s*N-shot: (\d+)" - matches = re.findall(pattern, content) - for model, task, n_shot in matches: - failed_jobs.add( - ( - model.strip(), - task.strip(), - int(n_shot.strip()), - ) - ) - except Exception as e: - logging.debug(f"Error parsing {err_file}: {e}") - - # Categorize incomplete jobs - still_running_jobs = [] # Jobs that are likely still executing - never_attempted_jobs = [] - crashed_jobs = [] - needs_rerun_jobs = [] # Jobs that definitely need to be rescheduled - - # We know we have exactly len(completed_jobs) completed jobs with actual results - # The rest need to be categorized - len(completed_jobs) + if check: + logging.info("=== Evaluation Status Check ===") + + # Find missing jobs + missing_jobs = [] for _, job in jobs_df.iterrows(): job_tuple = (job["model_path"], job["task_path"], job["n_shot"]) # Check if this job corresponds to one of our completed results - # Use the same matching logic as before but don't over-count is_completed = False - # Try to find a matching completed job + # Try exact matching first if job_tuple in completed_jobs: is_completed = True else: - # Try fuzzy matching + # Try fuzzy matching for model names for completed_job in completed_jobs: completed_model, completed_task, completed_n_shot = completed_job @@ -1096,214 +617,43 @@ def collect_results( is_completed = True break - if is_completed: - continue # Skip completed jobs - - # Job is not completed, categorize it - if job_tuple in failed_jobs: - crashed_jobs.append(job) - needs_rerun_jobs.append(job) - elif job_tuple not in attempted_jobs: - never_attempted_jobs.append(job) - needs_rerun_jobs.append(job) # These likely need rescheduling too - else: - # Job was attempted but not completed and didn't crash - likely still running - still_running_jobs.append(job) - - needs_rerun_df = pd.DataFrame(needs_rerun_jobs) + if not is_completed: + missing_jobs.append(job) - # Calculate completed jobs based on the jobs.csv perspective - actual_completed_from_jobs = ( - len(jobs_df) - - len(still_running_jobs) - - len(crashed_jobs) - - len(never_attempted_jobs) - ) + completed_count = len(jobs_df) - len(missing_jobs) - logging.info(f"\nTotal scheduled jobs: {len(jobs_df)}") - logging.info( - f"Completed jobs (from scheduled jobs): {actual_completed_from_jobs}" - ) - logging.info(f"Still running/pending: {len(still_running_jobs)}") - logging.info(f"Failed/Crashed jobs: {len(crashed_jobs)}") - logging.info(f"Never attempted: {len(never_attempted_jobs)}") - logging.info(f"Jobs needing reschedule: {len(needs_rerun_jobs)}") + logging.info(f"Total 
scheduled jobs: {len(jobs_df)}") + logging.info(f"Completed jobs: {completed_count}") + logging.info(f"Missing jobs: {len(missing_jobs)}") - if verbose: - logging.info(f"Total CSV rows (results with performance data): {len(rows)}") + if len(missing_jobs) > 0: + missing_df = pd.DataFrame(missing_jobs) + missing_csv = output_csv.replace(".csv", "_missing.csv") + missing_df.to_csv(missing_csv, index=False) + logging.info(f"Missing jobs saved to: {missing_csv}") logging.info( - f"Unique completed jobs found in JSON files: {len(completed_jobs)}" + f"You can run these with: oellm schedule-eval --eval_csv_path {missing_csv}" ) - if len(completed_jobs) != actual_completed_from_jobs: - logging.info( - f"Note: {len(completed_jobs)} results found vs {actual_completed_from_jobs} jobs matched from schedule" - ) - - if len(needs_rerun_jobs) > 0: - if reschedule: - # Show overview table in reschedule mode - console.print("\n[bold cyan]🔄 Jobs Needing Reschedule[/bold cyan]") - - # Create summary table - summary_table = Table( - show_header=True, header_style="bold magenta", box=box.ROUNDED - ) - summary_table.add_column("Status", style="bold") - summary_table.add_column("Count", justify="right", style="cyan") - - summary_table.add_row("✅ Completed", str(actual_completed_from_jobs)) - summary_table.add_row("🏃 Still Running", str(len(still_running_jobs))) - summary_table.add_row("❌ Crashed", str(len(crashed_jobs))) - summary_table.add_row( - "⏭️ Never Attempted", str(len(never_attempted_jobs)) - ) - summary_table.add_row( - "[bold yellow]🔄 Need Reschedule[/bold yellow]", - f"[bold yellow]{len(needs_rerun_jobs)}[/bold yellow]", - ) - - console.print(summary_table) - - # Show detailed table of jobs to reschedule - console.print("\n[bold cyan]📋 Detailed Job List[/bold cyan]") - - detail_table = Table( - show_header=True, header_style="bold magenta", box=box.ROUNDED - ) - detail_table.add_column("#", style="dim", width=4) - detail_table.add_column("Status", style="bold", width=15) - detail_table.add_column( - "Model", style="cyan", no_wrap=True, max_width=40 - ) - detail_table.add_column("Task", style="green", max_width=20) - detail_table.add_column("n_shot", justify="right", style="yellow") - - # Show first 20 rows - for idx, (_, job) in enumerate(needs_rerun_df.head(20).iterrows(), 1): - if ( - job["model_path"], - job["task_path"], - job["n_shot"], - ) in failed_jobs: - status = "[red]❌ CRASHED[/red]" - else: - status = "[yellow]⏭️ NOT ATTEMPTED[/yellow]" - - # Truncate long model paths for display - model_display = str(job["model_path"]) - if len(model_display) > 40: - model_display = "..." 
+ model_display[-37:] - - detail_table.add_row( - str(idx), - status, - model_display, - str(job["task_path"]), - str(job["n_shot"]), - ) - if len(needs_rerun_jobs) > 20: - detail_table.add_row("...", "...", "...", "...", "...") - console.print(detail_table) - console.print( - f"\n[dim]Showing 20 of {len(needs_rerun_jobs)} jobs[/dim]" + # Show some examples if verbose + if verbose and len(missing_jobs) > 0: + logging.info("Example missing jobs:") + for _i, (_, job) in enumerate(missing_df.head(5).iterrows()): + logging.info( + f" - {job['model_path']} | {job['task_path']} | n_shot={job['n_shot']}" ) - else: - console.print(detail_table) - - # Ask for confirmation - console.print( - f"\n[bold]Total jobs to reschedule: {len(needs_rerun_jobs)}[/bold]" - ) - - import questionary - from questionary import Style - - custom_style = Style( - [ - ("qmark", "fg:#673ab7 bold"), - ("question", "bold"), - ("answer", "fg:#f44336 bold"), - ("pointer", "fg:#673ab7 bold"), - ("highlighted", "fg:#673ab7 bold"), - ("selected", "fg:#cc5454"), - ] - ) - - save_and_schedule = questionary.confirm( - "\nSave failed jobs CSV and schedule re-evaluation?", - default=True, - style=custom_style, - ).ask() - - if save_and_schedule: - # Save the CSV - rerun_csv = output_csv.replace(".csv", "_needs_rerun.csv") - needs_rerun_df.to_csv(rerun_csv, index=False) - console.print(f"\n[green]✅ Jobs saved to: {rerun_csv}[/green]") - - # Ask if they want to schedule now - schedule_now = questionary.confirm( - "\nSchedule these jobs now?", - default=True, - style=custom_style, - ).ask() - - if schedule_now: - console.print("\n[yellow]To schedule these jobs, run:[/yellow]") - console.print( - f"[bold cyan]oellm schedule-eval --eval_csv_path {rerun_csv}[/bold cyan]" - ) - - else: - # Original behavior for check mode - # Save jobs that need rescheduling - rerun_csv = output_csv.replace(".csv", "_needs_rerun.csv") - needs_rerun_df.to_csv(rerun_csv, index=False) - logging.info(f"\nJobs needing reschedule saved to: {rerun_csv}") - logging.info( - f"You can re-run these with: [bold cyan]oellm schedule-eval --eval_csv_path {rerun_csv}[/bold cyan]" - ) - - # Save crashed jobs separately if any - if crashed_jobs: - crashed_csv = output_csv.replace(".csv", "_crashed.csv") - pd.DataFrame(crashed_jobs).to_csv(crashed_csv, index=False) - logging.info(f"Crashed jobs specifically saved to: {crashed_csv}") - - # Show some examples if verbose - if verbose and len(needs_rerun_jobs) > 0: - logging.info("\nExample jobs needing reschedule:") - for _i, (_, job) in enumerate(needs_rerun_df.head(5).iterrows()): - if ( - job["model_path"], - job["task_path"], - job["n_shot"], - ) in failed_jobs: - status = "CRASHED" - else: - status = "NEVER ATTEMPTED" - logging.info( - f" - [{status}] {job['model_path']} | {job['task_path']} | n_shot={job['n_shot']}" - ) - if len(needs_rerun_jobs) > 5: - logging.info(f" ... and {len(needs_rerun_jobs) - 5} more") - - if still_running_jobs and verbose: - logging.info( - f"\nNote: {len(still_running_jobs)} jobs appear to still be running/pending." - ) - logging.info( - "These were attempted but haven't completed yet. Check SLURM queue status." - ) + if len(missing_jobs) > 5: + logging.info(f" ... 
and {len(missing_jobs) - 5} more") def main(): + _filter_warnings() auto_cli( { "schedule-eval": schedule_evals, "build-csv": build_csv, "collect-results": collect_results, + "clean-cache": lambda: clear_task_cache(), }, as_positional=False, description="OELLM: Multi-cluster evaluation tool for language models", diff --git a/oellm/clusters.yaml b/oellm/resources/clusters.yaml similarity index 84% rename from oellm/clusters.yaml rename to oellm/resources/clusters.yaml index e48842b..0fa3f60 100644 --- a/oellm/clusters.yaml +++ b/oellm/resources/clusters.yaml @@ -5,10 +5,12 @@ shared: HF_HOME: "{EVAL_BASE_DIR}/hf_data" # where HuggingFace models and datasets are stored EVAL_OUTPUT_DIR: "{EVAL_BASE_DIR}/{USER}" # where evaluations are written GPUS_PER_NODE: 1 + HF_HUB_DISABLE_PROGRESS_BARS: "1" + HF_DATASETS_DISABLE_PROGRESS_BARS: "1" leonardo: hostname_pattern: "*.leonardo.local" # use this regexp to automatically assign environment variables corresponding to this YAML - EVAL_BASE_DIR: "/leonardo_work/AIFAC_L01_028/shared_evals" + EVAL_BASE_DIR: "/leonardo_work/AIFAC_L01_028/oellm-cli-shared-evals" PARTITION: "boost_usr_prod" # default partition to use ACCOUNT: "AIFAC_L01_028" # default account to use QUEUE_LIMIT: 1000 # maximum number of jobs that can be submitted as job/array, used to send only jobs that respects QOS @@ -26,7 +28,7 @@ jureca: lumi: hostname_pattern: "uan*" - EVAL_BASE_DIR: "/pfs/lustrep4/scratch/project_462000963/shared_evals" + EVAL_BASE_DIR: "/pfs/lustrep4/scratch/project_462000963/oellm-cli-shared-evals" PARTITION: "small-g" ACCOUNT: "project_462000963" QUEUE_LIMIT: 210 diff --git a/oellm/resources/task-groups.yaml b/oellm/resources/task-groups.yaml new file mode 100644 index 0000000..69ca6c8 --- /dev/null +++ b/oellm/resources/task-groups.yaml @@ -0,0 +1,197 @@ +task_groups: + open-sci-0.01: + description: "open-sci-ref 0.01 evals" + suite: lm-eval-harness + tasks: + - task: copa + n_shots: [0] + - task: social_iqa + n_shots: [0] + - task: openbookqa + n_shots: [0] + - task: lambada_openai + n_shots: [0] + - task: winogrande + n_shots: [0] + - task: mmlu + n_shots: [5] + - task: hellaswag + n_shots: [10] + - task: arc_easy + n_shots: [10] + - task: arc_challenge + n_shots: [10] + - task: commonsense_qa + n_shots: [10] + - task: piqa + n_shots: [10] + - task: boolq + n_shots: [10] + belebele-eu-5-shot: + description: "Belebele European language tasks" + suite: lm-eval-harness + n_shots: [5] + tasks: + - task: belebele_bul_Cyrl + - task: belebele_hrv_Latn + - task: belebele_ces_Latn + - task: belebele_dan_Latn + - task: belebele_nld_Latn + - task: belebele_eng_Latn + - task: belebele_est_Latn + - task: belebele_fin_Latn + - task: belebele_fra_Latn + - task: belebele_deu_Latn + - task: belebele_ell_Grek + - task: belebele_hun_Latn + - task: belebele_ita_Latn + - task: belebele_lvs_Latn + - task: belebele_lit_Latn + - task: belebele_mlt_Latn + - task: belebele_pol_Latn + - task: belebele_por_Latn + - task: belebele_ron_Latn + - task: belebele_slk_Latn + - task: belebele_slv_Latn + - task: belebele_spa_Latn + - task: belebele_swe_Latn + flores-200-eu-to-eng: + description: "Flores 200 EU to English translation" + suite: lighteval + n_shots: [0] + tasks: + - task: flores200:bul_Cyrl-eng_Latn + - task: flores200:ces_Latn-eng_Latn + - task: flores200:dan_Latn-eng_Latn + - task: flores200:deu_Latn-eng_Latn + - task: flores200:ell_Grek-eng_Latn + - task: flores200:est_Latn-eng_Latn + - task: flores200:fin_Latn-eng_Latn + - task: flores200:fra_Latn-eng_Latn + - task: 
flores200:gle_Latn-eng_Latn + - task: flores200:hrv_Latn-eng_Latn + - task: flores200:hun_Latn-eng_Latn + - task: flores200:ita_Latn-eng_Latn + - task: flores200:lit_Latn-eng_Latn + - task: flores200:lvs_Latn-eng_Latn + - task: flores200:mlt_Latn-eng_Latn + - task: flores200:nld_Latn-eng_Latn + - task: flores200:pol_Latn-eng_Latn + - task: flores200:por_Latn-eng_Latn + - task: flores200:ron_Latn-eng_Latn + - task: flores200:slk_Latn-eng_Latn + - task: flores200:slv_Latn-eng_Latn + - task: flores200:spa_Latn-eng_Latn + - task: flores200:swe_Latn-eng_Latn + flores-200-eng-to-eu: + description: "Flores 200 English to EU translation" + suite: lighteval + n_shots: [0] + tasks: + - task: flores200:eng_Latn-bul_Cyrl + - task: flores200:eng_Latn-ces_Latn + - task: flores200:eng_Latn-dan_Latn + - task: flores200:eng_Latn-deu_Latn + - task: flores200:eng_Latn-ell_Grek + - task: flores200:eng_Latn-est_Latn + - task: flores200:eng_Latn-fin_Latn + - task: flores200:eng_Latn-fra_Latn + - task: flores200:eng_Latn-gle_Latn + - task: flores200:eng_Latn-hrv_Latn + - task: flores200:eng_Latn-hun_Latn + - task: flores200:eng_Latn-ita_Latn + - task: flores200:eng_Latn-lit_Latn + - task: flores200:eng_Latn-lvs_Latn + - task: flores200:eng_Latn-mlt_Latn + - task: flores200:eng_Latn-nld_Latn + - task: flores200:eng_Latn-pol_Latn + - task: flores200:eng_Latn-por_Latn + - task: flores200:eng_Latn-ron_Latn + - task: flores200:eng_Latn-slk_Latn + - task: flores200:eng_Latn-slv_Latn + - task: flores200:eng_Latn-spa_Latn + - task: flores200:eng_Latn-swe_Latn + global-mmlu-eu: + description: "Global MMLU EU tasks" + suite: lm-eval-harness + n_shots: [5] + tasks: + - task: global_mmlu_full_cs + - task: global_mmlu_full_de + - task: global_mmlu_full_el + - task: global_mmlu_full_en + - task: global_mmlu_full_es + - task: global_mmlu_full_fr + - task: global_mmlu_full_it + - task: global_mmlu_full_lt + - task: global_mmlu_full_nl + - task: global_mmlu_full_pl + - task: global_mmlu_full_pt + - task: global_mmlu_full_ro + - task: global_mmlu_full_ru + - task: global_mmlu_full_sr + - task: global_mmlu_full_sv + - task: global_mmlu_full_tr + - task: global_mmlu_full_uk + - task: global_mmlu_full_he + mgsm-eu: + description: "EU Language GSM benchmarks in Aya Expanse" + suite: lm-eval-harness + n_shots: [5] + tasks: + - task: mgsm_native_cot_en + - task: mgsm_native_cot_de + - task: mgsm_native_cot_es + - task: mgsm_native_cot_fr + + generic-multilingual: + description: "Generic multilingual benchmarks in Aya Expanse" + suite: lm-eval-harness + n_shots: [0] + tasks: + - task: xwinograd + - task: xcopa + - task: xstorycloze + + include: + description: "INCLUDE benchmarks in Aya Expanse" + suite: lm-eval-harness + n_shots: [0] + tasks: + - task: include_base_44_albanian + - task: include_base_44_armenian + - task: include_base_44_azerbaijani + - task: include_base_44_basque + - task: include_base_44_belarusian + - task: include_base_44_bulgarian + - task: include_base_44_croatian + - task: include_base_44_dutch + - task: include_base_44_estonian + - task: include_base_44_finnish + - task: include_base_44_french + - task: include_base_44_georgian + - task: include_base_44_german + - task: include_base_44_greek + - task: include_base_44_hungarian + - task: include_base_44_italian + - task: include_base_44_lithuanian + - task: include_base_44_north macedonian + - task: include_base_44_polish + - task: include_base_44_portuguese + - task: include_base_44_russian + - task: include_base_44_serbian + - task: include_base_44_spanish + - 
task: include_base_44_turkish + - task: include_base_44_ukrainian + +super_groups: + oellm-multilingual: + description: "Combined Belebele EU set plus multilingual benchmarks" + task_groups: + - task: flores-200-eu-to-eng + - task: flores-200-eng-to-eu + - task: belebele-eu-5-shot + - task: global-mmlu-eu + - task: mgsm-eu + - task: generic-multilingual + - task: include diff --git a/oellm/template.sbatch b/oellm/resources/template.sbatch similarity index 60% rename from oellm/template.sbatch rename to oellm/resources/template.sbatch index 34c95c3..de1aa69 100644 --- a/oellm/template.sbatch +++ b/oellm/resources/template.sbatch @@ -20,6 +20,7 @@ export HF_XET_CACHE="$HF_HOME/xet" export HF_ASSETS_CACHE="$HF_HOME/assets" export HUGGINGFACE_HUB_CACHE="$HF_HOME/hub" export HUGGINGFACE_ASSETS_CACHE="$HF_HOME/assets" +export HF_DATASETS_CACHE="$HF_HOME/datasets" export HF_HUB_OFFLINE=1 # Path to the shared Singularity image that contains all runtime deps @@ -56,12 +57,13 @@ fi # Use `tail` and `head` to slice the CSV file for the tasks assigned to this job. # The +1 on START_INDEX accounts for the header row. tail -n +$((START_INDEX + 1)) "$CSV_PATH" | head -n $((END_INDEX - START_INDEX + 1)) | \ -while IFS=, read -r model_path task_path n_shot +while IFS=, read -r model_path task_path n_shot eval_suite do # Remove trailing carriage returns if script is edited on Windows model_path=$(echo "$model_path" | tr -d '\r') task_path=$(echo "$task_path" | tr -d '\r') n_shot=$(echo "$n_shot" | tr -d '\r') + eval_suite=$(echo "${{eval_suite:-lm_eval}}" | tr -d '\r') # Skip empty lines if [ -z "$model_path" ]; then @@ -73,6 +75,7 @@ do echo " Model: $model_path" echo " Task: $task_path" echo " N-shot: $n_shot" + echo " Suite: $eval_suite" echo "----------------------------------------------------" # Build bind paths: always mount the shared eval directory, and additionally @@ -91,16 +94,55 @@ do fi fi - - singularity exec $SINGULARITY_ARGS \ - --bind $BIND_PATHS \ - $EVAL_SIF_PATH \ - python -m lm_eval --model hf \ - --model_args pretrained="$model_path",trust_remote_code=True \ - --tasks "$task_path" \ - --num_fewshot "$n_shot" \ - --output_path "{evals_dir}/$(openssl rand -hex 5).json" \ - --trust_remote_code + suite_normalized=$(echo "$eval_suite" | tr '[:upper:]' '[:lower:]') + + case "$suite_normalized" in + lm_eval|lm-eval|lm-eval-harness) + singularity exec $SINGULARITY_ARGS \ + --bind $BIND_PATHS \ + $EVAL_SIF_PATH \ + python -m lm_eval --model hf \ + --model_args pretrained="$model_path",trust_remote_code=True \ + --tasks "$task_path" \ + --num_fewshot "$n_shot" \ + --output_path "{evals_dir}/$(openssl rand -hex 5).json" \ + --trust_remote_code + ;; + lighteval|light-eval) + LIGHT_TASK="$task_path" + + if [[ -f "$LIGHT_TASK" ]]; then + LIGHT_TASK_ARG="$LIGHT_TASK" + else + last_segment="${{LIGHT_TASK##*|}}" + if [[ "$LIGHT_TASK" == *"|"* && "$last_segment" =~ ^[0-9]+$ ]]; then + if [[ -n "$n_shot" && "$last_segment" != "$n_shot" ]]; then + LIGHT_TASK_ARG="${{LIGHT_TASK%|*}}|$n_shot" + else + LIGHT_TASK_ARG="$LIGHT_TASK" + fi + else + LIGHT_TASK_ARG="lighteval|${{LIGHT_TASK}}|$n_shot|0" + fi + fi + + RESULTS_SUBDIR="{evals_dir}/$(openssl rand -hex 5)" + mkdir -p "$RESULTS_SUBDIR" + + singularity exec $SINGULARITY_ARGS \ + --bind $BIND_PATHS \ + --env CUDA_VISIBLE_DEVICES=$SLURM_GPUS_PER_NODE \ + $EVAL_SIF_PATH \ + lighteval accelerate \ + "model_name=$model_path,trust_remote_code=True" \ + "$LIGHT_TASK_ARG" \ + --custom-tasks lighteval.tasks.multilingual.tasks \ + --output-dir "$RESULTS_SUBDIR" + 
;; + *) + echo "[warning] Unknown evaluation suite '$eval_suite'. Skipping." + ;; + esac echo "Evaluation finished for model: $model_path" diff --git a/oellm/task-groups.yaml b/oellm/task-groups.yaml deleted file mode 100644 index 465552c..0000000 --- a/oellm/task-groups.yaml +++ /dev/null @@ -1,86 +0,0 @@ -# Default task groups for interactive CSV builder -# Each group contains a list of tasks with their n_shot values -# Format: task_name,n_shot1,n_shot2,... - -task_groups: - open-sci-0.01: - description: "open-sci-ref 0.01 evals" - tasks: - - task: copa - n_shots: [0] - - task: social_iqa - n_shots: [0] - - task: openbookqa - n_shots: [0] - - task: lambada_openai - n_shots: [0] - - task: winogrande - n_shots: [0] - - task: mmlu - n_shots: [5] - - task: hellaswag - n_shots: [10] - - task: arc_easy - n_shots: [10] - - task: arc_challenge - n_shots: [10] - - task: commonsense_qa - n_shots: [10] - - task: piqa - n_shots: [10] - - task: boolq - n_shots: [10] - belebele-eu-5-shot: - description: "Belebele European language tasks" - tasks: - - task: belebele_bul_Cyrl - n_shots: [5] - - task: belebele_hrv_Latn - n_shots: [5] - - task: belebele_ces_Latn - n_shots: [5] - - task: belebele_dan_Latn - n_shots: [5] - - task: belebele_nld_Latn - n_shots: [5] - - task: belebele_eng_Latn - n_shots: [5] - - task: belebele_est_Latn - n_shots: [5] - - task: belebele_fin_Latn - n_shots: [5] - - task: belebele_fra_Latn - n_shots: [5] - - task: belebele_deu_Latn - n_shots: [5] - - task: belebele_ell_Grek - n_shots: [5] - - task: belebele_hun_Latn - n_shots: [5] - - task: belebele_ita_Latn - n_shots: [5] - - task: belebele_lvs_Latn - n_shots: [5] - - task: belebele_lit_Latn - n_shots: [5] - - task: belebele_mlt_Latn - n_shots: [5] - - task: belebele_pol_Latn - n_shots: [5] - - task: belebele_por_Latn - n_shots: [5] - - task: belebele_ron_Latn - n_shots: [5] - - task: belebele_slk_Latn - n_shots: [5] - - task: belebele_slv_Latn - n_shots: [5] - - task: belebele_spa_Latn - n_shots: [5] - - task: belebele_swe_Latn - n_shots: [5] - open-sci-and-belebele: - description: "Combined open-sci-0.01 and Belebele EU task groups" - groups: - - open-sci-0.01 - - belebele-eu-5-shot diff --git a/oellm/task_cache.py b/oellm/task_cache.py new file mode 100644 index 0000000..a320bee --- /dev/null +++ b/oellm/task_cache.py @@ -0,0 +1,330 @@ +import json +from contextlib import contextmanager +from contextvars import ContextVar +from datetime import datetime +from pathlib import Path + +TASK_CACHE_TTL_DAYS = 30 + + +_CURRENT_CAPTURE_BUFFER: ContextVar[list[dict] | None] = ContextVar( + "_CURRENT_CAPTURE_BUFFER", default=None +) + + +def get_task_cache_file() -> Path: + return Path(__file__).resolve().parent / "resources" / "task_map_cache.json" + + +def load_task_cache() -> dict: + cache_file = get_task_cache_file() + if cache_file.exists(): + with open(cache_file) as f: + return json.load(f) or {} + return {} + + +def save_task_cache(cache: dict) -> None: + cache_file = get_task_cache_file() + with open(cache_file, "w") as f: + json.dump(cache, f, indent=2, sort_keys=True) + + +def clear_task_cache() -> None: + cache_file = get_task_cache_file() + with open(cache_file, "w") as f: + json.dump({}, f) + + +def task_cache_key(framework: str, task_id: str) -> str: + return f"{framework}::{task_id}" + + +def task_cache_is_fresh(entry: dict, ttl_days: int = TASK_CACHE_TTL_DAYS) -> bool: + ts = float(entry.get("ts", 0)) + age_days = (datetime.now().timestamp() - ts) / 86400.0 + return age_days >= 0 and age_days < float(ttl_days) + + +def 
task_cache_lookup( + framework: str, task_id: str, ttl_days: int = TASK_CACHE_TTL_DAYS +) -> bool: + cache = load_task_cache() + key = task_cache_key(framework, task_id) + entry = cache.get(key) + if not isinstance(entry, dict): + return False + return task_cache_is_fresh(entry, ttl_days) + + +def task_cache_mark_resolved(framework: str, task_id: str) -> None: + cache = load_task_cache() + key = task_cache_key(framework, task_id) + entry = cache.get(key) if isinstance(cache.get(key), dict) else {} + entry["ts"] = datetime.now().timestamp() + cache[key] = entry + save_task_cache(cache) + + +def task_cache_get_payload(framework: str, task_id: str) -> dict | None: + cache = load_task_cache() + key = task_cache_key(framework, task_id) + entry = cache.get(key) + if not isinstance(entry, dict): + return None + payload = entry.get("payload") + return payload if isinstance(payload, dict) else None + + +def task_cache_set_payload(framework: str, task_id: str, payload: dict) -> None: + cache = load_task_cache() + key = task_cache_key(framework, task_id) + entry: dict = cache.get(key) if isinstance(cache.get(key), dict) else {} # type: ignore[assignment] + entry["ts"] = datetime.now().timestamp() + entry["payload"] = payload + cache[key] = entry + save_task_cache(cache) + + +def _canonical_key(call: dict) -> tuple: + t = call.get("type") + if t == "load_dataset": + return ( + t, + call.get("path"), + call.get("name"), + call.get("split"), + call.get("revision"), + ) + if t == "snapshot_download": + return ( + t, + call.get("repo_id"), + call.get("repo_type"), + call.get("revision"), + ) + if t == "hf_hub_download": + return ( + t, + call.get("repo_id"), + call.get("filename"), + call.get("repo_type"), + call.get("revision"), + ) + return (str(t),) + + +def dedupe_calls(calls: list[dict]) -> list[dict]: + if not isinstance(calls, list): + return [] + best: dict[tuple, dict] = {} + for c in calls: + if not isinstance(c, dict): + continue + key = _canonical_key(c) + existing = best.get(key) + if existing is None: + best[key] = c + continue + # Prefer trust_remote_code=True for load_dataset + if c.get("type") == "load_dataset": + if bool(c.get("trust_remote_code")) and not bool( + existing.get("trust_remote_code") + ): + best[key] = c + # Optionally drop snapshot_download if matching load_dataset exists + filtered: list[dict] = [] + load_keys = { + ("load_dataset", k[1], k[2], k[3], k[4]) + for k in best.keys() + if k and k[0] == "load_dataset" + } + for k, v in best.items(): + if k and k[0] == "snapshot_download": + # derive comparable key shape: (type, repo_id, None, None, revision) + comparable = ("load_dataset", k[1], None, None, k[3]) + if comparable in load_keys: + continue + filtered.append(v) + return filtered + + +@contextmanager +def capture_hf_dataset_calls(): + captured: list[dict] = [] + _buffer_token = _CURRENT_CAPTURE_BUFFER.set(captured) + + import datasets as _ds # type: ignore + import huggingface_hub as _hfh # type: ignore + + _orig_load_dataset = _ds.load_dataset + _orig_snapshot_download = _hfh.snapshot_download + _orig_hf_hub_download = _hfh.hf_hub_download + + def _load_dataset_proxy(path, *args, **kwargs): # noqa: ANN001 + name = ( + kwargs.get("name") + if "name" in kwargs + else (args[0] if len(args) > 0 else None) + ) + data_files = ( + kwargs.get("data_files") + if "data_files" in kwargs + else (args[1] if len(args) > 1 else None) + ) + split = ( + kwargs.get("split") + if "split" in kwargs + else (args[2] if len(args) > 2 else None) + ) + trust_remote_code = 
kwargs.get("trust_remote_code") + revision = kwargs.get("revision") + buf = _CURRENT_CAPTURE_BUFFER.get() + if isinstance(buf, list): + buf.append( + { + "type": "load_dataset", + "path": path, + "name": name, + "data_files": data_files, + "split": split, + "revision": revision, + "trust_remote_code": trust_remote_code, + } + ) + return _orig_load_dataset(path, *args, **kwargs) + + def _snapshot_download_proxy(*args, **kwargs): # noqa: ANN001 + repo_id = ( + kwargs.get("repo_id") + if "repo_id" in kwargs + else (args[0] if len(args) > 0 else None) + ) + repo_type = ( + kwargs.get("repo_type") + if "repo_type" in kwargs + else (args[1] if len(args) > 1 else None) + ) + revision = ( + kwargs.get("revision") + if "revision" in kwargs + else (args[2] if len(args) > 2 else None) + ) + buf = _CURRENT_CAPTURE_BUFFER.get() + if isinstance(buf, list): + buf.append( + { + "type": "snapshot_download", + "repo_id": repo_id, + "repo_type": repo_type, + "revision": revision, + } + ) + return _orig_snapshot_download(*args, **kwargs) + + def _hf_hub_download_proxy(*args, **kwargs): # noqa: ANN001 + repo_id = ( + kwargs.get("repo_id") + if "repo_id" in kwargs + else (args[0] if len(args) > 0 else None) + ) + filename = ( + kwargs.get("filename") + if "filename" in kwargs + else (args[1] if len(args) > 1 else None) + ) + repo_type = ( + kwargs.get("repo_type") + if "repo_type" in kwargs + else (args[2] if len(args) > 2 else None) + ) + revision = ( + kwargs.get("revision") + if "revision" in kwargs + else (args[3] if len(args) > 3 else None) + ) + buf = _CURRENT_CAPTURE_BUFFER.get() + if isinstance(buf, list): + buf.append( + { + "type": "hf_hub_download", + "repo_id": repo_id, + "filename": filename, + "repo_type": repo_type, + "revision": revision, + } + ) + return _orig_hf_hub_download(*args, **kwargs) + + _ds.load_dataset = _load_dataset_proxy # type: ignore[assignment] + _hfh.snapshot_download = _snapshot_download_proxy # type: ignore[assignment] + _hfh.hf_hub_download = _hf_hub_download_proxy # type: ignore[assignment] + + try: + yield captured + finally: + _ds.load_dataset = _orig_load_dataset # type: ignore[assignment] + _hfh.snapshot_download = _orig_snapshot_download # type: ignore[assignment] + _hfh.hf_hub_download = _orig_hf_hub_download # type: ignore[assignment] + _CURRENT_CAPTURE_BUFFER.reset(_buffer_token) + + +def prewarm_from_payload(payload: dict | None, *, trust_remote_code: bool = True) -> None: + if not isinstance(payload, dict): + return + calls = payload.get("calls") + if not isinstance(calls, list): + return + + from datasets import load_dataset # type: ignore + from huggingface_hub import hf_hub_download, snapshot_download # type: ignore + + for call in calls: + if not isinstance(call, dict): + continue + # Unified prewarm log message + if call.get("type") == "load_dataset": + path = call.get("path") + name = call.get("name") + else: + repo_id = call.get("repo_id") + filename = call.get("filename") + + if call.get("type") == "snapshot_download": + repo_id = call.get("repo_id") + if isinstance(repo_id, str) and repo_id: + snapshot_download( + repo_id=repo_id, + repo_type=call.get("repo_type") or "dataset", + revision=call.get("revision"), + ) + continue + if call.get("type") == "hf_hub_download": + repo_id = call.get("repo_id") + filename = call.get("filename") + if isinstance(repo_id, str) and isinstance(filename, str): + hf_hub_download( + repo_id=repo_id, + filename=filename, + repo_type=call.get("repo_type"), + revision=call.get("revision"), + ) + continue + path = 
call.get("path") + name = call.get("name") + data_files = call.get("data_files") + split = call.get("split") + revision = call.get("revision") + trc = call.get("trust_remote_code", trust_remote_code) + kwargs: dict = {} + if name is not None: + kwargs["name"] = name + if data_files is not None: + kwargs["data_files"] = data_files + if revision is not None: + kwargs["revision"] = revision + kwargs["trust_remote_code"] = bool(trc) + if split is not None: + load_dataset(path, split=split, **kwargs) + else: + load_dataset(path, **kwargs) diff --git a/oellm/task_groups.py b/oellm/task_groups.py index 8794521..73c7d35 100644 --- a/oellm/task_groups.py +++ b/oellm/task_groups.py @@ -1,122 +1,142 @@ -"""Utilities for loading and resolving task group definitions.""" - -from __future__ import annotations - -import logging -from pathlib import Path -from typing import Iterable, Sequence +from collections.abc import Iterable +from dataclasses import dataclass +from importlib.resources import files import yaml -DEFAULT_TASK_GROUPS_PATH = Path(__file__).parent / "task-groups.yaml" - - -def load_task_groups(path: str | Path | None = None) -> dict[str, dict]: - """Load task group definitions from a YAML file.""" - groups_path = Path(path) if path is not None else DEFAULT_TASK_GROUPS_PATH - if not groups_path.exists(): - raise FileNotFoundError( - f"Task groups file not found at {groups_path}. Please create it or " - "provide a custom path." - ) - with groups_path.open("r", encoding="utf-8") as fh: - data = yaml.safe_load(fh) or {} +@dataclass +class _Task: + name: str + n_shots: list[int] | None = None - task_groups = data.get("task_groups", {}) - if not isinstance(task_groups, dict): - raise ValueError( - "task_groups.yaml is malformed. Expected 'task_groups' to be a mapping." - ) - return task_groups +@dataclass +class TaskGroup: + name: str + tasks: list[_Task] + suite: str + description: str + n_shots: list[int] | None = None + def __post_init__(self): + for task in self.tasks: + if task.n_shots is None and self.n_shots is not None: + task.n_shots = self.n_shots + elif task.n_shots is None and self.n_shots is None: + raise ValueError( + f"N_shots is not set for task {task.name} and no default n_shots is set for the task group: {self.name}" + ) -def resolve_task_group( - group_name: str, - task_groups: dict[str, dict], - console=None, - _chain: Sequence[str] | None = None, -) -> list[dict]: - """Resolve a task group into its concrete task definitions.""" + @classmethod + def from_dict(cls, name: str, data: dict) -> "TaskGroup": + tasks = [] + for task_data in data["tasks"]: + task_name = task_data["task"] + task_n_shots = task_data.get("n_shots") + tasks.append(_Task(name=task_name, n_shots=task_n_shots)) + + return cls( + name=name, + tasks=tasks, + suite=data["suite"], + description=data["description"], + n_shots=data.get("n_shots"), + ) - if _chain is None: - _chain = [] - if group_name not in task_groups: - raise ValueError( - f"Task group '{group_name}' is not defined in task-groups.yaml." 
- ) +@dataclass +class TaskSuperGroup: + name: str + task_groups: list[TaskGroup] + description: str - if group_name in _chain: - cycle = " -> ".join(list(_chain) + [group_name]) - raise ValueError(f"Circular task group reference detected: {cycle}") - - group_data = task_groups.get(group_name) or {} - chain = list(_chain) + [group_name] - resolved_tasks: list[dict] = [] - - subgroups = group_data.get("groups", []) - if subgroups: - if not isinstance(subgroups, list): - raise ValueError( - f"Task group '{group_name}' has an invalid 'groups' section; expected a list of group names." - ) - for subgroup in subgroups: - if not isinstance(subgroup, str): + def __post_init__(self): + resolved_groups = [] + for group in self.task_groups: + if isinstance(group, str): raise ValueError( - f"Task group '{group_name}' references an invalid subgroup entry: {subgroup!r}" + f"Task group '{group}' not found in available task groups" ) - resolved_tasks.extend( - resolve_task_group(subgroup, task_groups, console=console, _chain=chain) - ) - - for task_item in group_data.get("tasks", []) or []: - if "task" not in task_item: - message = ( - f"Skipping malformed task entry in group '{group_name}': {task_item}" - ) - if console is not None: - console.print(f"[yellow]{message}[/yellow]") - else: - logging.warning(message) - continue - resolved_tasks.append( - { - "task": task_item["task"], - "n_shots": list(task_item.get("n_shots", [0])), - } - ) - - return resolved_tasks + resolved_groups.append(group) + self.task_groups = resolved_groups + + @classmethod + def from_dict( + cls, name: str, data: dict, available_task_groups: dict[str, TaskGroup] + ) -> "TaskSuperGroup": + task_groups = [] + for task_group_data in data["task_groups"]: + group_name = task_group_data["task"] + if group_name not in available_task_groups: + raise ValueError( + f"Task group '{group_name}' not found in available task groups" + ) + task_groups.append(available_task_groups[group_name]) + return cls( + name=name, + task_groups=task_groups, + description=data["description"], + ) -def flatten_task_groups( - group_names: Iterable[str], - task_groups: dict[str, dict], - *, - console=None, - on_duplicate=None, -) -> list[tuple[str, int]]: - """Flatten multiple task groups into (task, n_shot) pairs without duplicates.""" - seen: set[tuple[str, int]] = set() - flattened: list[tuple[str, int]] = [] +def _parse_task_groups( + requested_groups: list[str], +) -> dict[str, TaskSuperGroup | TaskGroup]: + data = ( + yaml.safe_load((files("oellm.resources") / "task-groups.yaml").read_text()) or {} + ) - for group_name in group_names: - resolved = resolve_task_group(group_name, task_groups, console=console) - for entry in resolved: - task_name = entry["task"] - for n_shot in entry.get("n_shots", [0]): - pair = (task_name, int(n_shot)) - if pair in seen: - if on_duplicate is not None: - on_duplicate(group_name, pair) - continue - seen.add(pair) - flattened.append(pair) + task_groups: dict[str, TaskGroup] = {} - return flattened + for task_group_name, task_data in data["task_groups"].items(): + task_groups[task_group_name] = TaskGroup.from_dict(task_group_name, task_data) + super_groups: dict[str, TaskSuperGroup] = {} + for super_group_name, super_group_data in data.get("super_groups", {}).items(): + super_groups[super_group_name] = TaskSuperGroup.from_dict( + super_group_name, super_group_data, task_groups + ) -__all__ = ["load_task_groups", "resolve_task_group", "flatten_task_groups"] + result = {**task_groups, **super_groups} + return { + group_name: 
group + for group_name, group in result.items() + if group_name in requested_groups + } + + +@dataclass +class TaskGroupResult: + task: str + n_shot: int + suite: str + + +def _expand_task_groups(group_names: Iterable[str]) -> list[TaskGroupResult]: + parsed = _parse_task_groups([str(n).strip() for n in group_names if str(n).strip()]) + missing = {str(n).strip() for n in group_names if str(n).strip()} - set(parsed.keys()) + if missing: + raise ValueError(f"Unknown task group(s): {', '.join(sorted(missing))}") + + results: list[TaskGroupResult] = [] + + for _, group in parsed.items(): + if isinstance(group, TaskGroup): + suite = group.suite + for t in group.tasks: + shots = [int(s) for s in (t.n_shots or [])] + for shot in shots: + results.append(TaskGroupResult(task=t.name, n_shot=shot, suite=suite)) + else: + for g in group.task_groups: + suite = g.suite + for t in g.tasks: + shots = [int(s) for s in (t.n_shots or [])] + for shot in shots: + results.append( + TaskGroupResult(task=t.name, n_shot=shot, suite=suite) + ) + + return results diff --git a/oellm/utils.py b/oellm/utils.py new file mode 100644 index 0000000..ef2ed5f --- /dev/null +++ b/oellm/utils.py @@ -0,0 +1,529 @@ +import builtins +import fnmatch +import logging +import os +import socket +import subprocess +import sys +from collections.abc import Iterable +from contextlib import contextmanager +from functools import wraps +from importlib.resources import files +from pathlib import Path + +import yaml +from rich.console import Console +from rich.logging import RichHandler + +from oellm.task_cache import ( + capture_hf_dataset_calls, + dedupe_calls, + prewarm_from_payload, + task_cache_get_payload, + task_cache_lookup, + task_cache_mark_resolved, + task_cache_set_payload, +) + +_RICH_CONSOLE: Console | None = None + + +def get_console() -> Console: + global _RICH_CONSOLE + if _RICH_CONSOLE is None: + _RICH_CONSOLE = Console() + return _RICH_CONSOLE + + +def _ensure_singularity_image(image_name: str) -> None: + from huggingface_hub import hf_hub_download + + image_path = Path(os.getenv("EVAL_BASE_DIR")) / image_name + + try: + console = get_console() + with console.status( + "Downloading latest Singularity image from HuggingFace", spinner="dots" + ): + hf_hub_download( + repo_id="openeurollm/evaluation_singularity_images", + filename=image_name, + repo_type="dataset", + local_dir=os.getenv("EVAL_BASE_DIR"), + ) + except Exception as e: + logging.warning( + "Failed to fetch latest container image from HuggingFace: %s", str(e) + ) + if image_path.exists(): + logging.info("Using existing Singularity image at %s", image_path) + else: + raise RuntimeError( + f"No container image found at {image_path} and failed to download from HuggingFace. " + f"Cannot proceed with evaluation scheduling." + ) from e + + +def _setup_logging(verbose: bool = False): + rich_handler = RichHandler( + console=get_console(), + show_time=True, + log_time_format="%H:%M:%S", + show_path=False, + markup=True, + rich_tracebacks=True, + ) + + class RichFormatter(logging.Formatter): + def format(self, record): + record.msg = f"{record.getMessage()}" + return record.msg + + rich_handler.setFormatter(RichFormatter()) + + root_logger = logging.getLogger() + root_logger.handlers = [] + root_logger.addHandler(rich_handler) + root_logger.setLevel(logging.DEBUG if verbose else logging.INFO) + + +def _load_cluster_env() -> None: + """ + Loads the correct cluster environment variables from `clusters.yaml` based on the hostname. 
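+
+    Illustrative example (hypothetical hostname; values from `clusters.yaml`): on a
+    LUMI login node such as `uan01`, the hostname matches the `uan*` pattern, so this
+    call exports `PARTITION=small-g`, `ACCOUNT=project_462000963` and the resolved
+    `EVAL_BASE_DIR` for that cluster.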
+ """ + clusters = yaml.safe_load((files("oellm.resources") / "clusters.yaml").read_text()) + hostname = socket.gethostname() + + shared_cfg = clusters.get("shared", {}) or {} + + cluster_cfg_raw: dict | None = None + for name, cfg in clusters.items(): + if name == "shared": + continue + pattern = cfg.get("hostname_pattern") + if isinstance(pattern, str) and fnmatch.fnmatch(hostname, pattern): + cluster_cfg_raw = dict(cfg) + break + if cluster_cfg_raw is None: + raise ValueError(f"No cluster found for hostname: {hostname}") + + cluster_cfg_raw.pop("hostname_pattern", None) + + class _Default(dict): + def __missing__(self, key): + return "{" + key + "}" + + base_ctx = _Default({**os.environ, **{k: str(v) for k, v in cluster_cfg_raw.items()}}) + + resolved_shared = {k: str(v).format_map(base_ctx) for k, v in shared_cfg.items()} + + ctx = _Default({**base_ctx, **resolved_shared}) + + resolved_cluster = {k: str(v).format_map(ctx) for k, v in cluster_cfg_raw.items()} + + final_env = {**resolved_shared, **resolved_cluster} + for k, v in final_env.items(): + os.environ[k] = v + + +def _num_jobs_in_queue() -> int: + user = os.environ.get("USER") + cmd: list[str] = ["squeue"] + if user: + cmd += ["-u", user] + cmd += ["-h", "-t", "pending,running", "-r", "-o", "%i"] + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + if result.stderr: + logging.warning(f"squeue error: {result.stderr.strip()}") + return 0 + + output = result.stdout.strip() + if not output: + return 0 + return sum(1 for line in output.splitlines() if line.strip()) + + +def _expand_local_model_paths(model: str | Path) -> list[Path]: + """ + Expands a local model path to include all checkpoints if it's a directory. + Recursively searches for models in subdirectories. + + Args: + model: Path to a model or directory containing models + + Returns: + List of paths to model directories containing safetensors files + """ + model_paths = [] + model_path = Path(model) + + if not model_path.exists() or not model_path.is_dir(): + return model_paths + + if any(model_path.glob("*.safetensors")): + model_paths.append(model_path) + return model_paths + + hf_path = model_path / "hf" + if hf_path.exists() and hf_path.is_dir(): + for subdir in hf_path.glob("*"): + if subdir.is_dir() and any(subdir.glob("*.safetensors")): + model_paths.append(subdir) + if model_paths: + return model_paths + + subdirs = [d for d in model_path.iterdir() if d.is_dir()] + + for subdir in subdirs: + if any(subdir.glob("*.safetensors")): + model_paths.append(subdir) + else: + hf_subpath = subdir / "hf" + if hf_subpath.exists() and hf_subpath.is_dir(): + for checkpoint_dir in hf_subpath.glob("*"): + if checkpoint_dir.is_dir() and any( + checkpoint_dir.glob("*.safetensors") + ): + model_paths.append(checkpoint_dir) + + if len(model_paths) > 1: + logging.info(f"Expanded '{model}' to {len(model_paths)} model checkpoints") + + return model_paths + + +def _process_model_paths(models: Iterable[str]): + """ + Processes model strings into a dict of model paths. + + Each model string can be a local path or a huggingface model identifier. + This function expands directory paths that contain multiple checkpoints. 
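+
+    Illustrative example (hypothetical inputs): given
+    ["/scratch/runs/my-model", "HuggingFaceTB/SmolLM2-135M"], the first entry is
+    expanded to any nested `hf/*` checkpoint directories containing `*.safetensors`
+    files, while the second is treated as a 🤗 hub identifier and pre-downloaded via
+    `snapshot_download` so compute nodes without internet access can read it from
+    the shared `HF_HOME` cache.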
+ """ + from huggingface_hub import snapshot_download + + console = get_console() + models_list = list(models) + + with console.status( + f"Processing models… 0/{len(models_list)}", spinner="dots" + ) as status: + for idx, model in enumerate(models_list, 1): + status.update(f"Checking model '{model}' ({idx}/{len(models_list)})") + per_model_paths: list[Path | str] = [] + + local_paths = _expand_local_model_paths(model) + if local_paths: + per_model_paths.extend(local_paths) + status.update(f"Using local model '{model}' ({idx}/{len(models_list)})") + else: + logging.info( + f"Model {model} not found locally, assuming it is a 🤗 hub model" + ) + logging.debug( + f"Downloading model {model} on the login node since the compute nodes may not have access to the internet" + ) + + if "," in model: + model_kwargs = dict( + [kv.split("=") for kv in model.split(",") if "=" in kv] + ) + + repo_id = model.split(",")[0] + + snapshot_kwargs = {} + if "revision" in model_kwargs: + snapshot_kwargs["revision"] = model_kwargs["revision"] + + status.update(f"Downloading '{repo_id}' ({idx}/{len(models_list)})") + try: + snapshot_download( + repo_id=repo_id, + cache_dir=Path(os.getenv("HF_HOME")) / "hub", + **snapshot_kwargs, + ) + per_model_paths.append(model) + except Exception as e: + logging.debug( + f"Failed to download model {model} from Hugging Face Hub. Continuing..." + ) + logging.debug(e) + else: + status.update(f"Downloading '{model}' ({idx}/{len(models_list)})") + snapshot_download( + repo_id=model, + cache_dir=Path(os.getenv("HF_HOME")) / "hub", + ) + per_model_paths.append(model) + + if not per_model_paths: + logging.warning( + f"Could not find any valid model for '{model}'. It will be skipped." + ) + + +def _pre_download_task_datasets( + tasks: Iterable[str], trust_remote_code: bool = True +) -> None: + processed: set[str] = set() + + misses: list[str] = [] + console = get_console() + with console.status("Checking lm-eval datasets…", spinner="dots") as status: + cache_hits = 0 + for task_name in tasks: + if not isinstance(task_name, str) or task_name in processed: + continue + processed.add(task_name) + if task_cache_lookup("lm-eval", task_name): + cache_hits += 1 + status.update( + f"Checking lm-eval datasets… {cache_hits} cached, {len(misses)} to prepare" + ) + continue + misses.append(task_name) + status.update( + f"Checking lm-eval datasets… {cache_hits} cached, {len(misses)} to prepare" + ) + + if not misses: + with console.status( + f"Using cached lm-eval datasets for {len(processed)} tasks…", + spinner="dots", + ) as status: + for task_name in processed: + if task_cache_lookup("lm-eval", task_name): + status.update(f"Loading cached dataset for '{task_name}'…") + prewarm_from_payload( + task_cache_get_payload("lm-eval", task_name), + trust_remote_code=trust_remote_code, + ) + return + + from datasets import DownloadMode # type: ignore + from lm_eval.tasks import TaskManager # type: ignore + + tm = TaskManager() + + with console.status( + f"Preparing lm-eval datasets… {len(misses)} remaining", + spinner="dots", + ) as status: + for idx, task_name in enumerate(misses, 1): + status.update(f"Preparing dataset for '{task_name}' ({idx}/{len(misses)})") + + task_config = { + "task": task_name, + "dataset_kwargs": {"trust_remote_code": trust_remote_code}, + } + + with capture_hf_dataset_calls() as captured_calls: + task_objects = tm.load_config(task_config) + + stack = [task_objects] + while stack: + current = stack.pop() + if isinstance(current, dict): + stack.extend(current.values()) + continue + 
if hasattr(current, "download") and callable(current.download): + try: + current.download( + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS + ) # type: ignore[arg-type] + except TypeError as e: + logging.error( + f"Failed to download dataset for task '{task_name}' with download_mode=REUSE_DATASET_IF_EXISTS: {e}" + ) + current.download() # type: ignore[misc] + + if captured_calls: + payload = {"calls": dedupe_calls(captured_calls)} + task_cache_set_payload("lm-eval", task_name, payload) + task_cache_mark_resolved("lm-eval", task_name) + logging.debug(f"Finished dataset preparation for task '{task_name}'.") + + +def _pre_download_lighteval_datasets(tasks: Iterable[str]) -> None: + seen: set[str] = set() + misses: list[str] = [] + tasks = [str(task).strip() for task in tasks] + console = get_console() + with console.status("Checking lighteval datasets…", spinner="dots") as status: + cache_hits = 0 + for task in tasks: + if not task or task in seen: + continue + seen.add(task) + if task_cache_lookup("lighteval", task): + cache_hits += 1 + status.update( + f"Checking lighteval datasets… {cache_hits} cached, {len(misses)} to prepare" + ) + continue + misses.append(task) + status.update( + f"Checking lighteval datasets… {cache_hits} cached, {len(misses)} to prepare" + ) + + if not misses: + with console.status( + f"Using cached lighteval datasets for {len(seen)} tasks…", + spinner="dots", + ): + for task in seen: + if task_cache_lookup("lighteval", task): + prewarm_from_payload( + task_cache_get_payload("lighteval", task), + trust_remote_code=True, + ) + return + + with console.status( + f"Preparing lighteval datasets… {len(misses)} remaining", + spinner="dots", + ) as status: + for idx, task in enumerate(misses, 1): + status.update(f"Preparing dataset for '{task}' ({idx}/{len(misses)})") + with capture_hf_dataset_calls() as captured_calls: + from lighteval.tasks.lighteval_task import LightevalTask + from lighteval.tasks.registry import ( + TRUNCATE_FEW_SHOTS_DEFAULTS, + Registry, + ) + + reg = Registry(custom_tasks="lighteval.tasks.multilingual.tasks") + truncate_default = int(TRUNCATE_FEW_SHOTS_DEFAULTS) + + spec = task + if "|" not in spec: + spec = f"lighteval|{spec}|0|{truncate_default}" + elif spec.count("|") == 1: + spec = f"{spec}|0|{truncate_default}" + elif spec.count("|") == 2: + spec = f"{spec}|{truncate_default}" + + configs = reg.get_tasks_configs(spec) + task_dict = reg.get_tasks_from_configs(configs) + LightevalTask.load_datasets(task_dict) + + payload = ( + {"calls": dedupe_calls(captured_calls)} + if captured_calls + else {"calls": []} + ) + task_cache_set_payload("lighteval", task, payload) + task_cache_mark_resolved("lighteval", task) + + +@contextmanager +def capture_third_party_output(verbose: bool = False): + """ + Suppresses print/logging.info/logging.debug originating from non-project modules + unless verbose=True. + + A call is considered "third-party" if its immediate caller's file path is not + under the repository root (parent of the `oellm` package directory). 
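+
+    Illustrative usage (`run_noisy_third_party_code` is a hypothetical callee):
+
+        with capture_third_party_output(verbose=False):
+            run_noisy_third_party_code()  # third-party print()/logging output is dropped
+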
+ """ + if verbose: + yield + return + + package_root = Path(__file__).resolve().parent + + def is_internal_stack(skip: int = 2, max_depth: int = 20) -> bool: + f = sys._getframe(skip) + depth = 0 + while f and depth < max_depth: + code = f.f_code + filename = code.co_filename if code else "" + if filename: + p = Path(filename).resolve() + name = code.co_name if code else "" + # Skip logging internals and our filtering wrappers to find the real caller + if "/logging/__init__.py" in filename or name.startswith("filtered_"): + f = f.f_back + depth += 1 + continue + return p.is_relative_to(package_root) + f = f.f_back + depth += 1 + return False + + orig_print = builtins.print + orig_logger_info = logging.Logger.info + orig_logger_debug = logging.Logger.debug + orig_module_info = logging.info + orig_module_debug = logging.debug + + def filtered_print(*args, **kwargs): + if is_internal_stack(): + return orig_print(*args, **kwargs) + # third-party: drop + return None + + def filtered_logger_info(self, msg, *args, **kwargs): + if is_internal_stack(): + return orig_logger_info(self, msg, *args, **kwargs) + return None + + def filtered_logger_debug(self, msg, *args, **kwargs): + if is_internal_stack(): + return orig_logger_debug(self, msg, *args, **kwargs) + return None + + def filtered_module_info(msg, *args, **kwargs): + if is_internal_stack(): + return orig_module_info(msg, *args, **kwargs) + return None + + def filtered_module_debug(msg, *args, **kwargs): + if is_internal_stack(): + return orig_module_debug(msg, *args, **kwargs) + return None + + builtins.print = filtered_print + logging.Logger.info = filtered_logger_info # type: ignore[assignment] + logging.Logger.debug = filtered_logger_debug # type: ignore[assignment] + logging.info = filtered_module_info # type: ignore[assignment] + logging.debug = filtered_module_debug # type: ignore[assignment] + + try: + yield + finally: + builtins.print = orig_print + logging.Logger.info = orig_logger_info # type: ignore[assignment] + logging.Logger.debug = orig_logger_debug # type: ignore[assignment] + logging.info = orig_module_info # type: ignore[assignment] + logging.debug = orig_module_debug # type: ignore[assignment] + + +def capture_third_party_output_from_kwarg( + verbose_kwarg: str = "verbose", default: bool = False +): + """ + Decorator factory that wraps the function execution inside + capture_third_party_output(verbose=kwargs.get(verbose_kwarg, default)). + """ + + def _decorator(func): + @wraps(func) + def _wrapper(*args, **kwargs): + verbose_value = bool(kwargs.get(verbose_kwarg, default)) + with capture_third_party_output(verbose=verbose_value): + return func(*args, **kwargs) + + return _wrapper + + return _decorator + + +def _filter_warnings(): + """ + Filters warnings from the lm_eval and lighteval libraries. 
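+
+    Both filters match on the emitting module name, so warnings raised from any
+    `lm_eval` or `lighteval` submodule are silenced for the rest of the process.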
+ """ + import warnings + + warnings.filterwarnings("ignore", module="lm_eval") + warnings.filterwarnings("ignore", module="lighteval") diff --git a/pyproject.toml b/pyproject.toml index 7c8e470..d699cba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,19 +3,27 @@ name = "oellm" version = "0.1.0" description = "OpenEuroLLM CLI" readme = "README.md" -requires-python = ">=3.12,<3.14" +requires-python = ">=3.12" dependencies = [ "pandas", - "jsonargparse[all]", - "datasets<4.0.0", + "jsonargparse", + "datasets", "rich", "torch", "lm-eval", + "lighteval[extended_tasks,multilingual] @ git+https://github.com/huggingface/lighteval.git@63424f4e795ecc577b90646381b374af3a627978", + "pydantic<2.12", "huggingface_hub", "pyyaml", "questionary", ] +[project.optional-dependencies] +dev = [ + "pytest>=8.4.1", + "pre-commit", +] + [project.scripts] oellm = "oellm.main:main" @@ -26,7 +34,7 @@ build-backend = "uv_build" [tool.uv.build-backend] module-name = "oellm" module-root = "" -include = ["oellm/clusters.yaml", "oellm/task-groups.yaml"] +include = ["oellm/resources/*"] [tool.uv.sources] torch = [ @@ -42,8 +50,8 @@ url = "https://download.pytorch.org/whl/cpu" explicit = true [tool.ruff] -line-length = 88 -target-version = "py38" +line-length = 90 +target-version = "py312" [tool.ruff.lint] select = [ @@ -70,8 +78,3 @@ quote-style = "double" indent-style = "space" skip-magic-trailing-comma = false line-ending = "auto" - -[dependency-groups] -dev = [ - "pytest>=8.4.1", -]
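For reference, a minimal local development setup using the `dev` extra from `pyproject.toml` might look like this (a sketch; assumes a local checkout and `uv` on the PATH):

```bash
uv venv && source .venv/bin/activate
uv pip install -e ".[dev]"
pre-commit install
pytest
```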