diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b45a188..00db357 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,21 +7,27 @@ on: - hobj/** # Changes to workflows - .github/workflows/ci.yml - # Changes to pyproject.toml + # Changes to project/dependency metadata - 'pyproject.toml' + - 'uv.lock' jobs: unit_tests: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] steps: - name: Checkout code uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.12' - cache: 'pip' - cache-dependency-path: setup.py # See https://github.com/actions/setup-python/blob/main/docs/advanced-usage.md#caching-packages - - name: Install hobj - run: pip3 install -e . + python-version: ${{ matrix.python-version }} + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + - name: Install dependencies + run: uv sync --locked --dev - name: Run pytests - run: pytest -s + run: uv run pytest -s diff --git a/.gitignore b/.gitignore index f511459..ac631e5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__/ *.egg-info/ dist/ +/data/ diff --git a/Makefile b/Makefile index c596605..bf3fd22 100644 --- a/Makefile +++ b/Makefile @@ -6,3 +6,6 @@ check: uv run ty check && \ uv run ruff check && \ uv run ruff format --check + +test: + uv run pytest tests diff --git a/examples/dev.py b/examples/dev.py deleted file mode 100644 index 29661e1..0000000 --- a/examples/dev.py +++ /dev/null @@ -1 +0,0 @@ -import hobj diff --git a/hobj/benchmarks/binary_classification/benchmark.py b/hobj/benchmarks/binary_classification/benchmark.py index 8f5b3f6..90820a6 100644 --- a/hobj/benchmarks/binary_classification/benchmark.py +++ b/hobj/benchmarks/binary_classification/benchmark.py @@ -1,9 +1,10 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + import numpy as np import pydantic import xarray as xr -from dataclasses import dataclass from tqdm import tqdm -from typing import List, Dict, Union, Tuple, Optional from hobj.benchmarks.binary_classification.estimator import LearningCurveStatistics from hobj.benchmarks.binary_classification.simulation import BinaryClassificationSubtask, BinaryClassificationSubtaskResult diff --git a/hobj/benchmarks/binary_classification/estimator.py b/hobj/benchmarks/binary_classification/estimator.py index 370f74c..0b80cb7 100644 --- a/hobj/benchmarks/binary_classification/estimator.py +++ b/hobj/benchmarks/binary_classification/estimator.py @@ -1,7 +1,8 @@ +from dataclasses import dataclass +from typing import Dict, List, Tuple + import numpy as np import xarray as xr -from dataclasses import dataclass -from typing import List, Dict, Tuple from hobj.benchmarks.binary_classification.simulation import BinaryClassificationSubtaskResult from hobj.stats import binomial as binomial_funcs @@ -193,4 +194,4 @@ def _get_bootstrap_resamples_by_session( return LearningCurveStatistics.BootstrapSamples( boot_k=boot_k, boot_n=boot_n, - ) \ No newline at end of file + ) diff --git a/hobj/benchmarks/binary_classification/simulation.py b/hobj/benchmarks/binary_classification/simulation.py index 6e364f4..ad35f89 100644 --- a/hobj/benchmarks/binary_classification/simulation.py +++ b/hobj/benchmarks/binary_classification/simulation.py @@ -1,9 +1,9 @@ +from typing import List, Optional, Union + import numpy as np import pydantic -from typing import List, Union, Optional - +from hobj.types import ImageId from hobj.learning_models import BinaryLearningModel -from mref import ImageRef # %% @@ -26,14 +26,14 @@ class BinaryClassificationSubtask(pydantic.BaseModel): frozen=True ) - classA: List[ImageRef] - classB: List[ImageRef] + classA: List[ImageId] + classB: List[ImageId] ntrials: int = pydantic.Field(description='The number of trials in the subtask.', gt=0) replace: bool = pydantic.Field(description='Whether to show stimulus images with replacement or not.') @pydantic.field_validator('classA', 'classB', mode='after') @classmethod - def sort_image_refs(cls, value: List[ImageRef]) -> List[ImageRef]: + def sort_image_refs(cls, value: List[ImageId]) -> List[ImageId]: return sorted(value) @pydantic.model_validator(mode='after') diff --git a/hobj/benchmarks/generalization/benchmark.py b/hobj/benchmarks/generalization/benchmark.py index 3c9e856..69266ac 100644 --- a/hobj/benchmarks/generalization/benchmark.py +++ b/hobj/benchmarks/generalization/benchmark.py @@ -1,12 +1,13 @@ +from dataclasses import dataclass +from typing import List, Tuple + import numpy as np import pydantic import xarray as xr -from dataclasses import dataclass from tqdm import tqdm -from typing import List, Tuple from hobj.benchmarks.generalization.estimator import GeneralizationStatistics -from hobj.benchmarks.generalization.simulator import GeneralizationSubtask, GeneralizationSessionResult +from hobj.benchmarks.generalization.simulator import GeneralizationSessionResult, GeneralizationSubtask from hobj.learning_models import BinaryLearningModel from hobj.stats.ci import estimate_basic_bootstrap_CI diff --git a/hobj/benchmarks/generalization/estimator.py b/hobj/benchmarks/generalization/estimator.py index cc77697..6dbd337 100644 --- a/hobj/benchmarks/generalization/estimator.py +++ b/hobj/benchmarks/generalization/estimator.py @@ -1,8 +1,9 @@ -import numpy as np import warnings -import xarray as xr from typing import List +import numpy as np +import xarray as xr + from hobj.benchmarks.generalization.simulator import GeneralizationSessionResult from hobj.stats import binomial as binomial_funcs diff --git a/hobj/benchmarks/generalization/simulator.py b/hobj/benchmarks/generalization/simulator.py index dc690ad..9bddd95 100644 --- a/hobj/benchmarks/generalization/simulator.py +++ b/hobj/benchmarks/generalization/simulator.py @@ -1,12 +1,11 @@ -from typing import List, Union, Dict, Optional +import collections +from typing import Dict, List, Optional, Union import numpy as np import pydantic -from mref import ImageRef from hobj.learning_models import BinaryLearningModel -import collections - +from hobj.types import ImageId # %% class GeneralizationSessionResult(pydantic.BaseModel): @@ -37,15 +36,15 @@ class GeneralizationSubtask(pydantic.BaseModel): frozen=True ) - support_imageA: ImageRef - support_imageB: ImageRef - test_imagesA: List[ImageRef] - test_imagesB: List[ImageRef] - image_ref_to_transformation: Dict[ImageRef, str] + support_imageA: ImageId + support_imageB: ImageId + test_imagesA: List[ImageId] + test_imagesB: List[ImageId] + image_ref_to_transformation: Dict[ImageId, str] @pydantic.field_validator('test_imagesA', 'test_imagesB', mode='after') @classmethod - def sort_image_refs(cls, value: List[ImageRef]) -> List[ImageRef]: + def sort_image_refs(cls, value: List[ImageId]) -> List[ImageId]: return sorted(value) @pydantic.model_validator(mode='after') diff --git a/hobj/benchmarks/make_model.py b/hobj/benchmarks/make_model.py index 2df5f58..0593021 100644 --- a/hobj/benchmarks/make_model.py +++ b/hobj/benchmarks/make_model.py @@ -1,18 +1,30 @@ """ This module provides an alternative interface for instantiating a linear learning model. """ -from hobj.learning_models.linear import LinearLearner, RepresentationalModel -import hobj.learning_models.linear.update_rules as update_rules -from typing import Literal, Dict -import mref +from functools import lru_cache +from typing import Literal + import numpy as np -from typing import List + +import hobj.learning_models.update_rules as update_rules +from hobj.learning_models import LinearLearner, RepresentationalModel +from hobj.types import ImageId + + +# %% +@lru_cache(maxsize=1) +def _get_calibration_image_ids() -> list[ImageId]: + """ + Returns the ImageIds of the warmup images that are used for calibrating the features of the linear learner. + Caches the result to avoid redundant computation. + """ + raise NotImplementedError # %% def make_linear_learner_from_features( - ref_to_features: Dict[mref.ImageRef, np.ndarray], - calibration_images: List[mref.ImageRef], + features: np.ndarray, + image_ids: list[ImageId], update_rule_name: Literal[ 'Prototype', 'Square', @@ -28,13 +40,13 @@ def make_linear_learner_from_features( """ Instantiates a linear learning model from precomputed features. :param ref_to_features: Dict[mref.ImageRef, np.ndarray], the features to use. - :param calibration_images: List[mref.ImageRef], the images that will be used to calibrate the features (i.e. for mean centering and ensuring they fit within a unit ball). :param update_rule_name: str, the name of the update rule to use. :param alpha: float, the learning rate. :return: LinearLearner """ - f_calibration = np.array([ref_to_features[ref] for ref in calibration_images]) + ref_to_features = {ref: features[i] for i, ref in enumerate(image_ids)} + f_calibration = np.array([ref_to_features[ref] for ref in _get_calibration_image_ids()]) mu_calibration = np.mean(f_calibration, axis=0) norms_calibration = np.linalg.norm(f_calibration - mu_calibration, axis=1) norm_cutoff = np.quantile(norms_calibration, 0.999) # Will clip the rest @@ -55,4 +67,4 @@ def make_linear_learner_from_features( image_ref_to_features=ref_to_calibrated_features ), update_rule=update_rule_name(alpha=alpha) - ) \ No newline at end of file + ) diff --git a/hobj/benchmarks/mut_highvar_benchmark.py b/hobj/benchmarks/mut_highvar_benchmark.py index 2d76f7e..5068538 100644 --- a/hobj/benchmarks/mut_highvar_benchmark.py +++ b/hobj/benchmarks/mut_highvar_benchmark.py @@ -1,10 +1,11 @@ +from typing import Dict, List + import numpy as np -from typing import List, Dict from hobj.benchmarks.binary_classification.benchmark import LearningCurveBenchmark, LearningCurveBenchmarkConfig, TargetSubtaskData from hobj.benchmarks.binary_classification.simulation import BinaryClassificationSubtask, BinaryClassificationSubtaskResult -from hobj.data.behavior import load_highvar_behavior -from hobj.data.images import MutatorHighVarImageset +from hobj.data_loaders.behavior import load_highvar_behavior +from hobj.data_loaders.images import MutatorHighVarImageset # %% @@ -21,7 +22,7 @@ def __init__(self): # Normalize data for benchmark: sha256_to_category = { - ref.sha256: imageset.get_annotation(image_ref=ref).category for ref in imageset.image_refs + ref.sha256: imageset.get_annotation(image_id=ref).category for ref in imageset.image_ids } subtask_name_to_results = {} @@ -46,8 +47,8 @@ def __init__(self): # Instantiate the subtask if it does not exist: if subtask_name not in subtask_name_to_subtask: subtask = BinaryClassificationSubtask( - classA=imageset.category_to_image_refs[cat0], - classB=imageset.category_to_image_refs[cat1], + classA=imageset.category_to_image_ids[cat0], + classB=imageset.category_to_image_ids[cat1], ntrials=100, replace=False, ) @@ -87,4 +88,4 @@ def __init__(self): if __name__ == '__main__': experiment = MutatorHighVarBenchmark() - print(sorted(experiment.config.subtask_name_to_data.keys())) \ No newline at end of file + print(sorted(experiment.config.subtask_name_to_data.keys())) diff --git a/hobj/benchmarks/mut_oneshot_benchmark.py b/hobj/benchmarks/mut_oneshot_benchmark.py index 226b1de..71851d3 100644 --- a/hobj/benchmarks/mut_oneshot_benchmark.py +++ b/hobj/benchmarks/mut_oneshot_benchmark.py @@ -2,13 +2,13 @@ # Coercing human data from typing import Dict, List -from hobj.benchmarks.generalization.benchmark import GeneralizationBenchmarkConfig, GeneralizationBenchmark, GeneralizationSessionResult +from hobj.benchmarks.generalization.benchmark import GeneralizationBenchmark, GeneralizationBenchmarkConfig, GeneralizationSessionResult from hobj.benchmarks.generalization.estimator import GeneralizationStatistics from hobj.benchmarks.generalization.simulator import GeneralizationSubtask -from hobj.data.behavior import load_oneshot_behavior -from hobj.data.images import MutatorOneShotImageset -from mref import ImageRef +from hobj.data_loaders.behavior import load_oneshot_behavior +from hobj.data_loaders.images import MutatorOneShotImageset +from hobj.types import ImageId # %% @@ -96,24 +96,24 @@ def __init__(self): # Map image refs to transformation ids image_ref_to_transformation_id = {} - cat_to_support_image: Dict[str, ImageRef] = {} - cat_to_test_images: Dict[str, List[ImageRef]] = {} + cat_to_support_image: Dict[str, ImageId] = {} + cat_to_test_images: Dict[str, List[ImageId]] = {} - for ref in imageset.image_refs: - annotation = imageset.get_annotation(image_ref=ref) + for image_id in imageset.image_ids: + annotation = imageset.get_annotation(image_id=image_id) transformation_id = f"{annotation.transformation} | {annotation.transformation_level}" - image_ref_to_transformation_id[ref] = transformation_id + image_ref_to_transformation_id[image_id] = transformation_id if annotation.transformation == 'original': if annotation.category not in cat_to_support_image: - cat_to_support_image[annotation.category] = ref + cat_to_support_image[annotation.category] = image_id else: raise ValueError(f"Multiple support images for category {annotation.category}") else: if annotation.category not in cat_to_test_images: cat_to_test_images[annotation.category] = [] - cat_to_test_images[annotation.category].append(ref) + cat_to_test_images[annotation.category].append(image_id) # Assemble subtask simulators subtasks = [] @@ -148,8 +148,8 @@ def __init__(self): observed_categories = set() for i_trial, sha in enumerate(session.stimulus_sha256_seq): - ref = ImageRef(sha256=sha) - annotation = imageset.get_annotation(image_ref=ref) + image_id = sha + annotation = imageset.get_annotation(image_id=image_id) # Add stimulus category to observed categories observed_categories.add(annotation.category) @@ -166,7 +166,7 @@ def __init__(self): ncatch += 1 else: assert annotation.transformation != 'original' - transformation_id = image_ref_to_transformation_id[ref] + transformation_id = image_ref_to_transformation_id[image_id] # Keep only benchmarked transformations if transformation_id in self.transformation_ids: diff --git a/hobj/data/images/__init__.py b/hobj/data/images/__init__.py deleted file mode 100644 index 0fc51d1..0000000 --- a/hobj/data/images/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from hobj.data.images.imagesets.highvar import MutatorHighVarImageset -from hobj.data.images.imagesets.oneshot import MutatorOneShotImageset -from hobj.data.images.imagesets.warmup import MutatorWarmupImageset -from hobj.data.images.imagesets.probe import ProbeImageset \ No newline at end of file diff --git a/hobj/data/images/imagesets/__init__.py b/hobj/data/images/imagesets/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/hobj/data/images/imagesets/highvar.py b/hobj/data/images/imagesets/highvar.py deleted file mode 100644 index f39778c..0000000 --- a/hobj/data/images/imagesets/highvar.py +++ /dev/null @@ -1,46 +0,0 @@ -import mref.media_references -from hobj.data.images.template import Imageset -import pydantic - -from typing import Dict, List - - - -class MutatorHighVarAnnotation(pydantic.BaseModel): - category: str = pydantic.Field( - examples=[ - 'MutatorB2000_4872', - 'MutatorB2000_419', - ], - pattern=r'^MutatorB2000_\d+$' - ) - - -class MutatorHighVarImageset(Imageset[MutatorHighVarAnnotation]): - manifest_url = 'https://hlbdatasets.s3.us-east-1.amazonaws.com/imagesets/mutator-highvar/mutator-highvar-manifest.json' - zipped_images_url = 'https://hlbdatasets.s3.us-east-1.amazonaws.com/imagesets/mutator-highvar/MutatorB2000_Subset128_FullVar_Train.zip' - annotation_schema = MutatorHighVarAnnotation - - def __init__(self): - super().__init__() - - self._category_to_image_refs: Dict[str, List[mref.media_references.ImageRef]] = {} - - for ref in self.image_refs: - annotation = self.get_annotation(image_ref=ref) - category = annotation.category - if category not in self._category_to_image_refs: - self._category_to_image_refs[category] = [] - self._category_to_image_refs[category].append(ref) - - for category in self._category_to_image_refs: - self._category_to_image_refs[category] = sorted(self._category_to_image_refs[category]) - - @property - def category_to_image_refs(self) -> Dict[str, List[mref.media_references.ImageRef]]: - return self._category_to_image_refs - - -if __name__ == '__main__': - - imageset = MutatorHighVarImageset() diff --git a/hobj/data/images/template.py b/hobj/data/images/template.py deleted file mode 100644 index f8f93d7..0000000 --- a/hobj/data/images/template.py +++ /dev/null @@ -1,152 +0,0 @@ -import tempfile -from abc import ABC -from pathlib import Path -from typing import Any, Dict, TypeVar, Generic, List - -import PIL.Image -import pydantic -from tqdm import tqdm - -from hobj.utils.file_io import unzip_file -from mref import ImageRef -from hobj.data.store import default_data_store -from mref import FileSystemStorage -import warnings - - -class ImageManifestEntry(pydantic.BaseModel, ABC): - sha256: str = pydantic.Field(pattern=r'^[a-f0-9]{64}$') - relpath: Path = pydantic.Field( - description='The relative path to the image file inside of the unzipped imageset directory.' - ) - annotation: Any = pydantic.Field( - description='Arbitrary annotation data for the image, in a JSON-valid format.' - ) - - -class ImageManifest(pydantic.BaseModel): - entries: Dict[str, ImageManifestEntry] = pydantic.Field( - description='A mapping from a unique image ID to image manifest entries.' - ) - - -# This is a type variable for ImageManifestEntry.annotation: -IA = TypeVar('IA') - - -class Imageset(Generic[IA], ABC): - """ - Imagesets are a combination of: - - Images - - Annotations on those images - """ - - manifest_url: str - zipped_images_url: str - annotation_schema: IA - - def __init__( - self, - data_store: FileSystemStorage = None, - redownload=False, - ): - """ - Unwrap the image manifest and save the images to the cache. - """ - - if not data_store: - self.data_store: FileSystemStorage = default_data_store - - # Load the manifest if it is already cached - manifest_data = self.data_store.download_json_from_url(url=self.manifest_url, register=True) # Todo - image_manifest = ImageManifest(**manifest_data) - - self._register_image_urls(manifest=image_manifest, redownload=redownload) - self._manifest = image_manifest - - self._image_id_to_annotation: Dict[str, IA] = {} - self._image_id_to_sha256: Dict[str, str] = {} - self._sha256_to_image_ids: Dict[str, List[str]] = {} - self._image_refs: List[ImageRef] = [] - - for image_id, entry in image_manifest.entries.items(): - image_ref = ImageRef(sha256=entry.sha256) - self._image_refs.append(image_ref) - self._image_id_to_sha256[image_id] = image_ref.sha256 - self._image_id_to_annotation[image_id] = self.annotation_schema(**entry.annotation) - if image_ref.sha256 not in self._sha256_to_image_ids: - self._sha256_to_image_ids[image_ref.sha256] = [] - self._sha256_to_image_ids[image_ref.sha256].append(image_id) - - def _register_image_urls(self, manifest: ImageManifest, redownload: bool = False): - """ - Ensures the entries of the manifest are registered in the data store. - """ - - num_undownloaded_images = 0 - for image_id, manifest_entry in manifest.entries.items(): - ref = ImageRef(sha256=manifest_entry.sha256) - if not self.data_store.check_data_exists(ref=ref): - num_undownloaded_images += 1 - - if num_undownloaded_images == 0: - return - - print(f'Missing {num_undownloaded_images}/{len(manifest.entries)} images for this imageset.') - - # Download the images - zipped_images_path = self.data_store.download_zip_path_from_url(url=self.zipped_images_url, register=True) # todo type correctly - - # Make a tempdir to unzip the images - with tempfile.TemporaryDirectory() as tempdir: - tempdir = Path(tempdir) - unzip_file(zip_path=zipped_images_path, output_dir=tempdir) - - # Register the images - pbar = tqdm(total=len(manifest.entries)) - for image_id, manifest_entry in manifest.entries.items(): - reported_sha256 = manifest_entry.sha256 - relpath = manifest_entry.relpath - image_path = tempdir / relpath - image_data = PIL.Image.open(image_path) - - image_ref = ImageRef.from_image(image=image_data) - - if not image_ref.sha256 == reported_sha256: - raise ValueError(f"SHA256 mismatch for image {manifest_entry}: {image_ref.sha256} != {reported_sha256}") - - # Store image - self.data_store.register_image(image=image_data) - pbar.update(1) - - @property - def image_refs(self) -> List[ImageRef]: - """ - List of image refs in this imageset. - :return: - """ - return self._image_refs - - def get_annotation(self, *, image_ref: ImageRef = None, sha256: str = None, ) -> IA: - """ - Get the annotation for a given image. If an image has multiple annotations, this will throw an error. - """ - - if sha256 is None: - sha256 = image_ref.sha256 - - image_ids = self._sha256_to_image_ids[sha256] - if len(image_ids) > 1: - warnings.warn(f"Image {sha256} has multiple annotations: {image_ids}. Returning the first one.") - image_id = image_ids[0] - entry = self._image_id_to_annotation[image_id] - return entry - - def __len__(self) -> int: - return len(self.image_refs) - - def __repr__(self): - return f"{self.__class__.__name__}({len(self)})" - - def __str__(self): - return f"Imageset(n={len(self)})" diff --git a/hobj/data/store.py b/hobj/data/store.py deleted file mode 100644 index c10786a..0000000 --- a/hobj/data/store.py +++ /dev/null @@ -1,13 +0,0 @@ - -import hobj.config - - -import mref -__all__ = [ - 'default_data_store' -] - - -# %% -# Default data store -default_data_store = mref.FileSystemStorage(cachedir=hobj.config.cachedir) diff --git a/hobj/data/__init__.py b/hobj/data_loaders/__init__.py similarity index 100% rename from hobj/data/__init__.py rename to hobj/data_loaders/__init__.py diff --git a/hobj/data/behavior.py b/hobj/data_loaders/behavior.py similarity index 77% rename from hobj/data/behavior.py rename to hobj/data_loaders/behavior.py index d44ade8..3452ce7 100644 --- a/hobj/data/behavior.py +++ b/hobj/data_loaders/behavior.py @@ -1,9 +1,11 @@ -from typing import List, Literal import datetime +import json +from pathlib import Path +from typing import List, Literal import pydantic -from hobj.data.store import default_data_store +from hobj.utils.file_io import download_json __all__ = ['load_highvar_behavior', 'load_oneshot_behavior'] @@ -53,16 +55,22 @@ def validate_lengths(self) -> 'HumanLearningSession': # %% def _load_learning_sessions( dataset_url: str, - redownload: bool + cache_filename: str, + cachedir: Path | None = None, + redownload: bool = False, ) -> List[HumanLearningSession]: - data_store = default_data_store - # Download the data: - json_data = data_store.download_json_from_url( - url=dataset_url, - register=True, - ) + repo_root = Path(__file__).resolve().parents[2] + cache_root = (cachedir if cachedir is not None else repo_root / 'data').resolve() + behavior_dir = cache_root / 'behavior' + behavior_dir.mkdir(parents=True, exist_ok=True) + dataset_path = behavior_dir / cache_filename + + if redownload or not dataset_path.exists(): + json_data = download_json(dataset_url) + dataset_path.write_text(json.dumps(json_data, indent=2)) + else: + json_data = json.loads(dataset_path.read_text()) - # class LearningSessionDataset(pydantic.BaseModel): sessions: List[HumanLearningSession] @@ -72,8 +80,9 @@ class LearningSessionDataset(pydantic.BaseModel): # %% Data loaders def load_highvar_behavior( + remove_probe_trials: bool = True, + cachedir: Path | None = None, redownload: bool = False, - remove_probe_trials: bool = True ) -> List[HumanLearningSession]: """ Load the "raw" human learning data from Experiment 1 of Lee and DiCarlo 2023. @@ -82,7 +91,9 @@ def load_highvar_behavior( sessions = _load_learning_sessions( dataset_url='https://hlbdatasets.s3.us-east-1.amazonaws.com/behavior/mutator-highvar-human-learning-data.json', - redownload=redownload + cache_filename='mutator-highvar-human-learning-data.json', + cachedir=cachedir, + redownload=redownload, ) if not remove_probe_trials: @@ -109,7 +120,10 @@ def filter(vals: list): return filtered_sessions -def load_oneshot_behavior(redownload: bool = False) -> List[HumanLearningSession]: +def load_oneshot_behavior( + cachedir: Path | None = None, + redownload: bool = False, +) -> List[HumanLearningSession]: """ Load the "raw" human learning data from Experiment 2 of Lee and DiCarlo 2023. :return: @@ -117,6 +131,8 @@ def load_oneshot_behavior(redownload: bool = False) -> List[HumanLearningSession sessions = _load_learning_sessions( dataset_url='https://hlbdatasets.s3.us-east-1.amazonaws.com/behavior/mutator-oneshot-human-learning-data.json', + cache_filename='mutator-oneshot-human-learning-data.json', + cachedir=cachedir, redownload=redownload ) diff --git a/hobj/data_loaders/images/__init__.py b/hobj/data_loaders/images/__init__.py new file mode 100644 index 0000000..1e11a53 --- /dev/null +++ b/hobj/data_loaders/images/__init__.py @@ -0,0 +1,5 @@ +from hobj.data_loaders.images.highvar import MutatorHighVarImageset +from hobj.data_loaders.images.oneshot import MutatorOneShotImageset +from hobj.data_loaders.images.probe import ProbeImageset +from hobj.data_loaders.images.warmup import MutatorWarmupImageset + diff --git a/hobj/data_loaders/images/highvar.py b/hobj/data_loaders/images/highvar.py new file mode 100644 index 0000000..ea98517 --- /dev/null +++ b/hobj/data_loaders/images/highvar.py @@ -0,0 +1,47 @@ +from pathlib import Path +import pydantic + +from hobj.data_loaders.images.template import Imageset + +from hobj.types import ImageId + +# %% + +class MutatorHighVarAnnotation(pydantic.BaseModel): + category: str = pydantic.Field( + examples=[ + 'MutatorB2000_4872', + 'MutatorB2000_419', + ], + pattern=r'^MutatorB2000_\d+$' + ) + + +class MutatorHighVarImageset(Imageset[MutatorHighVarAnnotation]): + manifest_url = 'https://hlbdatasets.s3.us-east-1.amazonaws.com/imagesets/mutator-highvar/mutator-highvar-manifest.json' + zipped_images_url = 'https://hlbdatasets.s3.us-east-1.amazonaws.com/imagesets/mutator-highvar/MutatorB2000_Subset128_FullVar_Train.zip' + annotation_schema = MutatorHighVarAnnotation + + def __init__(self, cachedir: Path | None = None, redownload: bool = False): + super().__init__(cachedir=cachedir, redownload=redownload) + + self._category_to_image_ids: dict[str, list[ImageId]] = {} + + for ref in self.image_ids: + annotation = self.get_annotation(image_id=ref) + category = annotation.category + if category not in self._category_to_image_ids: + self._category_to_image_ids[category] = [] + self._category_to_image_ids[category].append(ref) + + for category in self._category_to_image_ids: + self._category_to_image_ids[category] = sorted(self._category_to_image_ids[category]) + + @property + def category_to_image_ids(self) -> dict[str, list[ImageId]]: + return self._category_to_image_ids + + +if __name__ == '__main__': + + imageset = MutatorHighVarImageset() diff --git a/hobj/data/images/imagesets/oneshot.py b/hobj/data_loaders/images/oneshot.py similarity index 86% rename from hobj/data/images/imagesets/oneshot.py rename to hobj/data_loaders/images/oneshot.py index 80d1646..15c58fd 100644 --- a/hobj/data/images/imagesets/oneshot.py +++ b/hobj/data_loaders/images/oneshot.py @@ -1,9 +1,9 @@ - +from pathlib import Path from typing import Literal import pydantic -from hobj.data.images.template import Imageset +from hobj.data_loaders.images.template import Imageset class MutatorOneShotAnnotation(pydantic.BaseModel): @@ -48,6 +48,9 @@ class MutatorOneShotImageset(Imageset[MutatorOneShotAnnotation]): zipped_images_url = 'https://hlbdatasets.s3.us-east-1.amazonaws.com/imagesets/mutator-oneshot/MutatorB2000_Oneshot64.zip' annotation_schema = MutatorOneShotAnnotation + def __init__(self, cachedir: Path | None = None, redownload: bool = False): + super().__init__(cachedir=cachedir, redownload=redownload) + if __name__ == '__main__': - imageset = MutatorOneShotImageset() \ No newline at end of file + imageset = MutatorOneShotImageset() diff --git a/hobj/data/images/imagesets/probe.py b/hobj/data_loaders/images/probe.py similarity index 91% rename from hobj/data/images/imagesets/probe.py rename to hobj/data_loaders/images/probe.py index d9a7a0f..473ce35 100644 --- a/hobj/data/images/imagesets/probe.py +++ b/hobj/data_loaders/images/probe.py @@ -1,8 +1,10 @@ -from hobj.data.images.template import Imageset -import pydantic from typing import Literal +import pydantic + +from hobj.data_loaders.images.template import Imageset + class ProbeAnnotation(pydantic.BaseModel): color: Literal['blue', 'orange'] diff --git a/hobj/data_loaders/images/template.py b/hobj/data_loaders/images/template.py new file mode 100644 index 0000000..51e527e --- /dev/null +++ b/hobj/data_loaders/images/template.py @@ -0,0 +1,165 @@ +import json +from abc import ABC +from pathlib import Path +from typing import Any, Dict, Generic, List, TypeVar + +import PIL.Image +import pydantic + +from hobj.types import ImageId +from hobj.utils.file_io import download_file, download_json, unzip_file +from hobj.utils.hash import hash_image + + +# %% +class ImageManifestEntry(pydantic.BaseModel, ABC): + sha256: str = pydantic.Field(pattern=r'^[a-f0-9]{64}$') + relpath: Path = pydantic.Field( + description='The relative path to the image file inside of the unzipped imageset directory.' + ) + annotation: Any = pydantic.Field( + description='Arbitrary annotation data for the image, in a JSON-valid format.' + ) + + +class ImageManifest(pydantic.BaseModel): + entries: Dict[ImageId, ImageManifestEntry] = pydantic.Field( + description='A mapping from a unique image ID to image manifest entries.' + ) + + +# This is a type variable for ImageManifestEntry.annotation: +IA = TypeVar('IA') + + +class Imageset(Generic[IA], ABC): + """ + Imagesets are a combination of: + - Images + - Annotations on those images + """ + + manifest_url: str + zipped_images_url: str + annotation_schema: IA + + def __init__( + self, + cachedir: Path | None = None, + redownload=False, + ): + """ + Download and materialize the imageset into a local cache directory. + """ + + repo_root = Path(__file__).resolve().parents[3] + self.cachedir = (cachedir if cachedir is not None else repo_root / 'data').resolve() + self.cachedir.mkdir(parents=True, exist_ok=True) + + self._dataset_dir = self.cachedir / self.__class__.__name__ + self._dataset_dir.mkdir(parents=True, exist_ok=True) + self._images_dir = self._dataset_dir / 'images' + + manifest_data = self._load_manifest_json(redownload=redownload) + image_manifest = ImageManifest(**manifest_data) + self._manifest = image_manifest + self._ensure_images_present(manifest=image_manifest, redownload=redownload) + + self._image_id_to_annotation: Dict[ImageId, IA] = {} + self._image_id_to_sha256: Dict[ImageId, str] = {} + self._image_id_to_relpath: Dict[ImageId, Path] = {} + self._image_ids: List[ImageId] = [] + + for image_id, entry in image_manifest.entries.items(): + self._image_ids.append(image_id) + self._image_id_to_sha256[image_id] = entry.sha256 + self._image_id_to_relpath[image_id] = entry.relpath + self._image_id_to_annotation[image_id] = self.annotation_schema(**entry.annotation) + + def _load_manifest_json(self, redownload: bool) -> dict[str, Any]: + if redownload or not self.manifest_path.exists(): + manifest_data = download_json(self.manifest_url) + self.manifest_path.write_text(json.dumps(manifest_data, indent=2)) + return json.loads(self.manifest_path.read_text()) + + def _ensure_images_present(self, manifest: ImageManifest, redownload: bool = False) -> None: + """ + Ensure the images for this imageset exist locally. + """ + if redownload or not self._all_images_present(manifest): + self._download_and_extract_images(manifest) + + def _all_images_present(self, manifest: ImageManifest) -> bool: + for entry in manifest.entries.values(): + if not (self._images_dir / entry.relpath).exists(): + return False + return True + + def _download_and_extract_images(self, manifest: ImageManifest) -> None: + self._dataset_dir.mkdir(parents=True, exist_ok=True) + if self._images_dir.exists(): + for path in sorted(self._images_dir.rglob('*'), reverse=True): + if path.is_file(): + path.unlink() + elif path.is_dir(): + path.rmdir() + self._images_dir.rmdir() + + download_file(self.zipped_images_url, self.archive_path) + unzip_file(zip_path=self.archive_path, output_dir=self._images_dir) + self._verify_images(manifest) + + def _verify_images(self, manifest: ImageManifest) -> None: + for image_id, entry in manifest.entries.items(): + image_path = self._images_dir / entry.relpath + if not image_path.exists(): + raise FileNotFoundError(f"Missing image file for {image_id}: {image_path}") + + with PIL.Image.open(image_path) as image_data: + observed_sha256 = hash_image(image_data) + + if observed_sha256 != entry.sha256: + raise ValueError( + f"SHA256 mismatch for image {image_id}: {observed_sha256} != {entry.sha256}" + ) + + @property + def manifest_path(self) -> Path: + return self._dataset_dir / 'manifest.json' + + @property + def archive_path(self) -> Path: + archive_name = Path(self.zipped_images_url).name + return self._dataset_dir / archive_name + + @property + def images_dir(self) -> Path: + return self._images_dir + + @property + def image_ids(self) -> list[ImageId]: + """ + List of image refs in this imageset. + :return: + """ + return self._image_ids + + def get_annotation(self, *, image_id: ImageId) -> IA: + """ + Get the annotation for a given image. If an image has multiple annotations, this will throw an error. + """ + + entry = self._image_id_to_annotation[image_id] + return entry + + def get_image_path(self, *, image_id: ImageId) -> Path: + return self.images_dir / self._image_id_to_relpath[image_id] + + def __len__(self) -> int: + return len(self.image_ids) + + def __repr__(self): + return f"{self.__class__.__name__}({len(self)})" + + def __str__(self): + return f"Imageset(n={len(self)})" diff --git a/hobj/data/images/imagesets/warmup.py b/hobj/data_loaders/images/warmup.py similarity index 92% rename from hobj/data/images/imagesets/warmup.py rename to hobj/data_loaders/images/warmup.py index 1a53dc8..3fdd9c9 100644 --- a/hobj/data/images/imagesets/warmup.py +++ b/hobj/data_loaders/images/warmup.py @@ -1,7 +1,6 @@ import pydantic -from hobj.data.images.template import Imageset - +from hobj.data_loaders.images.template import Imageset class MutatorWarmupAnnotation(pydantic.BaseModel): diff --git a/hobj/data_loaders/store.py b/hobj/data_loaders/store.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/hobj/data_loaders/store.py @@ -0,0 +1 @@ + diff --git a/hobj/learning_models/__init__.py b/hobj/learning_models/__init__.py index 97e597c..ad3411b 100644 --- a/hobj/learning_models/__init__.py +++ b/hobj/learning_models/__init__.py @@ -1,15 +1,15 @@ import typing from abc import ABC, abstractmethod -import PIL.Image import numpy as np -import mref.media_references +from hobj.learning_models.representation import RepresentationalModel +from hobj.learning_models.update_rules import UpdateRule +from hobj.types import ImageId # %% class BinaryLearningModel(ABC): - @abstractmethod def reset_state(self, seed: typing.Union[int, None]) -> None: """ @@ -22,7 +22,7 @@ def reset_state(self, seed: typing.Union[int, None]) -> None: @abstractmethod def get_response( self, - image: typing.Union[mref.media_references.ImageRef, PIL.Image], + image: ImageId, ) -> typing.Literal[0, 1]: """ This function takes the current stimulus image (given either as a PIL.Image or a ImageRef) and returns one of two possible actions (parameterized by an integer). @@ -54,7 +54,7 @@ def reset_state(self, seed: typing.Union[int, None]) -> None: def get_response( self, - image: typing.Union[mref.media_references.ImageRef, PIL.Image], + image: ImageId, ) -> typing.Literal[0, 1]: action = self.random_generator.integers(2) action = int(action) @@ -62,3 +62,71 @@ def get_response( def give_feedback(self, reward: float) -> None: return + + +# %% +class LinearLearner(BinaryLearningModel): + def __init__( + self, + representational_model: RepresentationalModel, + update_rule: UpdateRule, + ): + + self.representational_model = representational_model + self.update_rule = update_rule + + # State variables + self.w = None + self.b = None + self._f_last = None + self._logits_last = None + self._action_last = None + self._generator: np.random.Generator = np.random.default_rng() + + # Initialize state + self.reset_state(seed=0) + return + + def reset_state(self, seed: int) -> None: + """ + :param seed: + :return: + """ + self.update_rule.reset() + self.w = np.zeros((self.representational_model.d, 2)) + self.b = np.zeros((2,)) + self._f_last = None + self._logits_last = None + self._action_last = None + self._generator = np.random.default_rng(seed=seed) + + return + + def get_response( + self, + image: ImageId, + ) -> typing.Literal[0, 1]: + + f = self.representational_model.get_features(image=image) + logits = f @ self.w + self.b + action = self._random_tiebreaking_argmax(logits[0], logits[1]) + + # Update internal state with traces + self._f_last = f + self._logits_last = logits + self._action_last = action + return action + + def give_feedback(self, reward: float) -> None: + delta_w, delta_b = self.update_rule.get_update(x=self._f_last, w=self.w, b=self.b, logits=self._logits_last, action=self._action_last, reward=reward) # [action] + self.w += delta_w + self.b += delta_b + return + + def _random_tiebreaking_argmax(self, logit0, logit1) -> typing.Literal[0, 1]: + if logit0 > logit1: + return 0 + elif logit0 < logit1: + return 1 + else: + return 0 if self._generator.random() < 0.5 else 1 diff --git a/hobj/learning_models/linear/__init__.py b/hobj/learning_models/linear/__init__.py deleted file mode 100644 index 77b934d..0000000 --- a/hobj/learning_models/linear/__init__.py +++ /dev/null @@ -1,75 +0,0 @@ -import PIL.Image -import numpy as np -import typing - -import mref.media_references -from hobj.learning_models import BinaryLearningModel -from hobj.learning_models.linear.representation import RepresentationalModel -from hobj.learning_models.linear.update_rules import UpdateRule - - -class LinearLearner(BinaryLearningModel): - def __init__( - self, - representational_model: RepresentationalModel, - update_rule: UpdateRule, - ): - - self.representational_model = representational_model - self.update_rule = update_rule - - # State variables - self.w = None - self.b = None - self._f_last = None - self._logits_last = None - self._action_last = None - self._generator: np.random.Generator = np.random.default_rng() - - # Initialize state - self.reset_state(seed=0) - return - - def reset_state(self, seed: int) -> None: - """ - :param seed: - :return: - """ - self.update_rule.reset() - self.w = np.zeros((self.representational_model.d, 2)) - self.b = np.zeros((2,)) - self._f_last = None - self._logits_last = None - self._action_last = None - self._generator = np.random.default_rng(seed=seed) - - return - - def get_response( - self, - image: typing.Union[mref.media_references.ImageRef, PIL.Image] - ) -> typing.Literal[0, 1]: - - f = self.representational_model.get_features(image=image) - logits = f @ self.w + self.b - action = self._random_tiebreaking_argmax(logits[0], logits[1]) - - # Update internal state with traces - self._f_last = f - self._logits_last = logits - self._action_last = action - return action - - def give_feedback(self, reward: float) -> None: - delta_w, delta_b = self.update_rule.get_update(x=self._f_last, w=self.w, b=self.b, logits=self._logits_last, action=self._action_last, reward=reward) # [action] - self.w += delta_w - self.b += delta_b - return - - def _random_tiebreaking_argmax(self, logit0, logit1) -> typing.Literal[0, 1]: - if logit0 > logit1: - return 0 - elif logit0 < logit1: - return 1 - else: - return 0 if self._generator.random() < 0.5 else 1 diff --git a/hobj/learning_models/linear/representation.py b/hobj/learning_models/representation.py similarity index 79% rename from hobj/learning_models/linear/representation.py rename to hobj/learning_models/representation.py index 846870d..bd3d1f4 100644 --- a/hobj/learning_models/linear/representation.py +++ b/hobj/learning_models/representation.py @@ -1,11 +1,11 @@ -import numpy as np +from typing import Callable, Dict -from typing import Union, Dict, Callable -import PIL.Image +import numpy as np -from mref import ImageRef +from hobj.types import ImageId +# %% class RepresentationalModel: """ Class which maps images to feature vectors of shape (d,). @@ -14,7 +14,7 @@ class RepresentationalModel: def __init__( self, d: int, - image_to_features_func: Callable[[Union[ImageRef, PIL.Image]], np.ndarray], + image_to_features_func: Callable[[ImageId], np.ndarray], ): if not isinstance(d, int): @@ -32,7 +32,7 @@ def d(self) -> int: def get_features( self, - image: Union[ImageRef, PIL.Image] + image: ImageId ) -> np.ndarray: """ Returns a feature vector for the image_url. @@ -52,7 +52,7 @@ def get_features( @classmethod def from_precomputed_features( cls, - image_ref_to_features: Dict[ImageRef, np.ndarray] + image_ref_to_features: Dict[ImageId, np.ndarray] ) -> 'RepresentationalModel': """ @@ -61,12 +61,8 @@ def from_precomputed_features( If get_features is called with an ImageRef (or PIL.Image with an ImageRef) not in image_ref_to_features, a KeyError will be raised. """ - def image_to_features_func(image: Union[ImageRef, PIL.Image]) -> np.ndarray: - if isinstance(image, PIL.Image.Image): - ref = ImageRef.from_image(image) - else: - ref = image - return image_ref_to_features[ref] + def image_to_features_func(image: ImageId) -> np.ndarray: + return image_ref_to_features[image] # Ensure all feature vectors are the same shape d = None @@ -83,4 +79,4 @@ def image_to_features_func(image: Union[ImageRef, PIL.Image]) -> np.ndarray: if not f.shape[0] == d: raise ValueError(f"Expected feature vector to be of shape ({d},), but got {f.shape}") - return cls(d=d, image_to_features_func=image_to_features_func) \ No newline at end of file + return cls(d=d, image_to_features_func=image_to_features_func) diff --git a/hobj/learning_models/tests/test_dummy_learner.py b/hobj/learning_models/tests/test_dummy_learner.py deleted file mode 100644 index 9c8656c..0000000 --- a/hobj/learning_models/tests/test_dummy_learner.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - -import mref.media_references -from hobj.learning_models import RandomGuesser - -@pytest.fixture -def dummy_learner() -> RandomGuesser: - return RandomGuesser(seed=0) - -@pytest.fixture -def test_image() -> mref.media_references.ImageRef: - return mref.media_references.ImageRef(sha256='0' * 64) - - -def test_dummy_learner_deterministic(dummy_learner, test_image): - - ntests = 10 - actions = [] - expected_actions = [1, 1, 1, 0, 0, 0, 0, 0, 0, 1] - for i in range(ntests): - a = dummy_learner.get_response(image=test_image) - actions.append(a) - - assert actions == expected_actions \ No newline at end of file diff --git a/hobj/learning_models/linear/update_rules.py b/hobj/learning_models/update_rules.py similarity index 99% rename from hobj/learning_models/linear/update_rules.py rename to hobj/learning_models/update_rules.py index 6911f11..3297e26 100644 --- a/hobj/learning_models/linear/update_rules.py +++ b/hobj/learning_models/update_rules.py @@ -1,8 +1,9 @@ -import numpy as np -import scipy.special from abc import ABC, abstractmethod from typing import Tuple, Union +import numpy as np +import scipy.special + class UpdateRule(ABC): def __init__(self, alpha: float): @@ -52,6 +53,7 @@ def get_update( return delta_w, delta_b +# %% class Prototype(UpdateRule): """ Simulates the decision boundary implemented by a prototype learner. diff --git a/hobj/stats/ci.py b/hobj/stats/ci.py index 9a713d4..13979a5 100644 --- a/hobj/stats/ci.py +++ b/hobj/stats/ci.py @@ -1,4 +1,5 @@ from typing import Tuple, Union + import numpy as np @@ -11,7 +12,7 @@ def estimate_basic_bootstrap_CI( Estimates the basic confidence interval for a given point estimate(s) using the bootstrap method. :param alpha: Sets the width of the confidence interval to be 1 - alpha. Must be in the range (0, 1). :param point_estimate: The point estimate(s) for which the confidence interval is to be estimated. - :param bootstrapped_point_estimate: Bootstrap resamples of the point estimate in question. + :param bootstrapped_point_estimates: Bootstrap resamples of the point estimate in question. :return: A tuple containing the lower and upper bounds of the confidence interval. """ diff --git a/hobj/types.py b/hobj/types.py new file mode 100644 index 0000000..0f0952f --- /dev/null +++ b/hobj/types.py @@ -0,0 +1 @@ +type ImageId = str diff --git a/hobj/utils/file_io.py b/hobj/utils/file_io.py index 86be9cb..6778353 100644 --- a/hobj/utils/file_io.py +++ b/hobj/utils/file_io.py @@ -37,16 +37,13 @@ def download_file(url: str, output_path: Path) -> None: if not output_path.parent.exists(): output_path.parent.mkdir(parents=True) - size, unit = get_bytes_size(num_bytes=total_size_in_bytes) - - with tqdm(total=size, unit=unit, unit_scale=True, disable=False, desc='Download progress') as progress_bar: + with tqdm(total=total_size_in_bytes, unit='B', unit_scale=True, disable=False, desc='Download progress') as progress_bar: with open(output_path.as_posix(), 'wb') as file: # Iterate over the response data in chunks and write to file for chunk in response.iter_content(chunk_size=1024): if chunk: file.write(chunk) - chunk_size, _ = get_bytes_size(num_bytes=len(chunk), output_units=unit) - progress_bar.update(chunk_size) + progress_bar.update(len(chunk)) file.flush() @@ -93,4 +90,4 @@ def download_json(url: str) -> Any: data = response.read().decode('utf-8') json_data = json.loads(data) - return json_data \ No newline at end of file + return json_data diff --git a/hobj/utils/hash.py b/hobj/utils/hash.py index 7431058..f321fbe 100644 --- a/hobj/utils/hash.py +++ b/hobj/utils/hash.py @@ -1,9 +1,10 @@ -import numpy as np import hashlib + import PIL.Image +import numpy as np -def hash_image(image: PIL.Image) -> str: +def hash_image(image: PIL.Image.Image) -> str: """ Hash an image based on its np.uint8 representation. :param image: diff --git a/pyproject.toml b/pyproject.toml index 4104969..526a8b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "tqdm>=4.67", "pydantic>=2.10", "xarray>=2025.1", + "scipy>=1.17.1", ] [dependency-groups] diff --git a/site/readme_images/run_make_images.py b/site/readme_images/run_make_images.py index b55fb68..c7ec90c 100644 --- a/site/readme_images/run_make_images.py +++ b/site/readme_images/run_make_images.py @@ -2,11 +2,9 @@ # %% import matplotlib.pyplot as plt - from hobj.data.images import MutatorHighVarImageset from hobj.benchmarks import MutatorHighVarBenchmark import numpy as np - imageset = MutatorHighVarImageset() benchmark = MutatorHighVarBenchmark() target_stats = benchmark.target_statistics diff --git a/tests/test_dummy_learner.py b/tests/test_dummy_learner.py new file mode 100644 index 0000000..f164712 --- /dev/null +++ b/tests/test_dummy_learner.py @@ -0,0 +1,19 @@ +import pytest + +from hobj.learning_models import RandomGuesser + +@pytest.fixture +def dummy_learner() -> RandomGuesser: + return RandomGuesser(seed=0) + + +def test_dummy_learner_deterministic(dummy_learner): + + ntests = 10 + actions = [] + expected_actions = [1, 1, 1, 0, 0, 0, 0, 0, 0, 1] + for i in range(ntests): + a = dummy_learner.get_response(image='hi') + actions.append(a) + + assert actions == expected_actions diff --git a/hobj/data/tests/test_load_behavior.py b/tests/test_load_behavior.py similarity index 70% rename from hobj/data/tests/test_load_behavior.py rename to tests/test_load_behavior.py index c76d948..83b6c8d 100644 --- a/hobj/data/tests/test_load_behavior.py +++ b/tests/test_load_behavior.py @@ -1,4 +1,4 @@ -from hobj.data.behavior import load_highvar_behavior, load_oneshot_behavior +from hobj.data_loaders.behavior import load_highvar_behavior, load_oneshot_behavior def test_load_highvar(): diff --git a/hobj/benchmarks/binary_classification/tests/test_simulate_subtask.py b/tests/test_simulate_subtask.py similarity index 74% rename from hobj/benchmarks/binary_classification/tests/test_simulate_subtask.py rename to tests/test_simulate_subtask.py index 3503df8..3772a21 100644 --- a/hobj/benchmarks/binary_classification/tests/test_simulate_subtask.py +++ b/tests/test_simulate_subtask.py @@ -1,18 +1,18 @@ -from hobj.benchmarks.binary_classification.simulation import BinaryClassificationSubtask -from hobj.learning_models import RandomGuesser -from mref import ImageRef from typing import List + import PIL.Image import numpy as np +from hobj.benchmarks.binary_classification.simulation import BinaryClassificationSubtask +from hobj.learning_models import RandomGuesser +from hobj.types import ImageId + -def create_image_refs(nimages_per_class: int, seed: int) -> List[ImageRef]: +def create_image_refs(nimages_per_class: int, seed: int) -> List[ImageId]: images = [] np.random.seed(seed) - for _ in range(nimages_per_class): - image = PIL.Image.fromarray(np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)) - image_ref = ImageRef.from_image(image=image) - images.append(image_ref) + for i in range(nimages_per_class): + images.append(f'seed{seed}_image{i}') return images @@ -36,6 +36,3 @@ def test_simulate_subtask(): ) assert len(result.perf_seq) == ntrials - - -# Todo: test deterministic \ No newline at end of file diff --git a/uv.lock b/uv.lock index f65af68..8d370d2 100644 --- a/uv.lock +++ b/uv.lock @@ -157,6 +157,7 @@ dependencies = [ { name = "pandas" }, { name = "pydantic" }, { name = "requests" }, + { name = "scipy" }, { name = "tqdm" }, { name = "xarray" }, ] @@ -173,6 +174,7 @@ requires-dist = [ { name = "pandas", specifier = ">=2.2" }, { name = "pydantic", specifier = ">=2.10" }, { name = "requests", specifier = ">=2.32" }, + { name = "scipy", specifier = ">=1.17.1" }, { name = "tqdm", specifier = ">=4.67" }, { name = "xarray", specifier = ">=2025.1" }, ] @@ -528,6 +530,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "scipy" +version = "1.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7a/97/5a3609c4f8d58b039179648e62dd220f89864f56f7357f5d4f45c29eb2cc/scipy-1.17.1.tar.gz", hash = "sha256:95d8e012d8cb8816c226aef832200b1d45109ed4464303e997c5b13122b297c0", size = 30573822, upload-time = "2026-02-23T00:26:24.851Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/75/b4ce781849931fef6fd529afa6b63711d5a733065722d0c3e2724af9e40a/scipy-1.17.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1f95b894f13729334fb990162e911c9e5dc1ab390c58aa6cbecb389c5b5e28ec", size = 31613675, upload-time = "2026-02-23T00:16:00.13Z" }, + { url = "https://files.pythonhosted.org/packages/f7/58/bccc2861b305abdd1b8663d6130c0b3d7cc22e8d86663edbc8401bfd40d4/scipy-1.17.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:e18f12c6b0bc5a592ed23d3f7b891f68fd7f8241d69b7883769eb5d5dfb52696", size = 28162057, upload-time = "2026-02-23T00:16:09.456Z" }, + { url = "https://files.pythonhosted.org/packages/6d/ee/18146b7757ed4976276b9c9819108adbc73c5aad636e5353e20746b73069/scipy-1.17.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a3472cfbca0a54177d0faa68f697d8ba4c80bbdc19908c3465556d9f7efce9ee", size = 20334032, upload-time = "2026-02-23T00:16:17.358Z" }, + { url = "https://files.pythonhosted.org/packages/ec/e6/cef1cf3557f0c54954198554a10016b6a03b2ec9e22a4e1df734936bd99c/scipy-1.17.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:766e0dc5a616d026a3a1cffa379af959671729083882f50307e18175797b3dfd", size = 22709533, upload-time = "2026-02-23T00:16:25.791Z" }, + { url = "https://files.pythonhosted.org/packages/4d/60/8804678875fc59362b0fb759ab3ecce1f09c10a735680318ac30da8cd76b/scipy-1.17.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:744b2bf3640d907b79f3fd7874efe432d1cf171ee721243e350f55234b4cec4c", size = 33062057, upload-time = "2026-02-23T00:16:36.931Z" }, + { url = "https://files.pythonhosted.org/packages/09/7d/af933f0f6e0767995b4e2d705a0665e454d1c19402aa7e895de3951ebb04/scipy-1.17.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43af8d1f3bea642559019edfe64e9b11192a8978efbd1539d7bc2aaa23d92de4", size = 35349300, upload-time = "2026-02-23T00:16:49.108Z" }, + { url = "https://files.pythonhosted.org/packages/b4/3d/7ccbbdcbb54c8fdc20d3b6930137c782a163fa626f0aef920349873421ba/scipy-1.17.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cd96a1898c0a47be4520327e01f874acfd61fb48a9420f8aa9f6483412ffa444", size = 35127333, upload-time = "2026-02-23T00:17:01.293Z" }, + { url = "https://files.pythonhosted.org/packages/e8/19/f926cb11c42b15ba08e3a71e376d816ac08614f769b4f47e06c3580c836a/scipy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4eb6c25dd62ee8d5edf68a8e1c171dd71c292fdae95d8aeb3dd7d7de4c364082", size = 37741314, upload-time = "2026-02-23T00:17:12.576Z" }, + { url = "https://files.pythonhosted.org/packages/95/da/0d1df507cf574b3f224ccc3d45244c9a1d732c81dcb26b1e8a766ae271a8/scipy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:d30e57c72013c2a4fe441c2fcb8e77b14e152ad48b5464858e07e2ad9fbfceff", size = 36607512, upload-time = "2026-02-23T00:17:23.424Z" }, + { url = "https://files.pythonhosted.org/packages/68/7f/bdd79ceaad24b671543ffe0ef61ed8e659440eb683b66f033454dcee90eb/scipy-1.17.1-cp311-cp311-win_arm64.whl", hash = "sha256:9ecb4efb1cd6e8c4afea0daa91a87fbddbce1b99d2895d151596716c0b2e859d", size = 24599248, upload-time = "2026-02-23T00:17:34.561Z" }, + { url = "https://files.pythonhosted.org/packages/35/48/b992b488d6f299dbe3f11a20b24d3dda3d46f1a635ede1c46b5b17a7b163/scipy-1.17.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:35c3a56d2ef83efc372eaec584314bd0ef2e2f0d2adb21c55e6ad5b344c0dcb8", size = 31610954, upload-time = "2026-02-23T00:17:49.855Z" }, + { url = "https://files.pythonhosted.org/packages/b2/02/cf107b01494c19dc100f1d0b7ac3cc08666e96ba2d64db7626066cee895e/scipy-1.17.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fcb310ddb270a06114bb64bbe53c94926b943f5b7f0842194d585c65eb4edd76", size = 28172662, upload-time = "2026-02-23T00:18:01.64Z" }, + { url = "https://files.pythonhosted.org/packages/cf/a9/599c28631bad314d219cf9ffd40e985b24d603fc8a2f4ccc5ae8419a535b/scipy-1.17.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cc90d2e9c7e5c7f1a482c9875007c095c3194b1cfedca3c2f3291cdc2bc7c086", size = 20344366, upload-time = "2026-02-23T00:18:12.015Z" }, + { url = "https://files.pythonhosted.org/packages/35/f5/906eda513271c8deb5af284e5ef0206d17a96239af79f9fa0aebfe0e36b4/scipy-1.17.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c80be5ede8f3f8eded4eff73cc99a25c388ce98e555b17d31da05287015ffa5b", size = 22704017, upload-time = "2026-02-23T00:18:21.502Z" }, + { url = "https://files.pythonhosted.org/packages/da/34/16f10e3042d2f1d6b66e0428308ab52224b6a23049cb2f5c1756f713815f/scipy-1.17.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e19ebea31758fac5893a2ac360fedd00116cbb7628e650842a6691ba7ca28a21", size = 32927842, upload-time = "2026-02-23T00:18:35.367Z" }, + { url = "https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02ae3b274fde71c5e92ac4d54bc06c42d80e399fec704383dcd99b301df37458", size = 35235890, upload-time = "2026-02-23T00:18:49.188Z" }, + { url = "https://files.pythonhosted.org/packages/c5/5c/9d7f4c88bea6e0d5a4f1bc0506a53a00e9fcb198de372bfe4d3652cef482/scipy-1.17.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a604bae87c6195d8b1045eddece0514d041604b14f2727bbc2b3020172045eb", size = 35003557, upload-time = "2026-02-23T00:18:54.74Z" }, + { url = "https://files.pythonhosted.org/packages/65/94/7698add8f276dbab7a9de9fb6b0e02fc13ee61d51c7c3f85ac28b65e1239/scipy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f590cd684941912d10becc07325a3eeb77886fe981415660d9265c4c418d0bea", size = 37625856, upload-time = "2026-02-23T00:19:00.307Z" }, + { url = "https://files.pythonhosted.org/packages/a2/84/dc08d77fbf3d87d3ee27f6a0c6dcce1de5829a64f2eae85a0ecc1f0daa73/scipy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:41b71f4a3a4cab9d366cd9065b288efc4d4f3c0b37a91a8e0947fb5bd7f31d87", size = 36549682, upload-time = "2026-02-23T00:19:07.67Z" }, + { url = "https://files.pythonhosted.org/packages/bc/98/fe9ae9ffb3b54b62559f52dedaebe204b408db8109a8c66fdd04869e6424/scipy-1.17.1-cp312-cp312-win_arm64.whl", hash = "sha256:f4115102802df98b2b0db3cce5cb9b92572633a1197c77b7553e5203f284a5b3", size = 24547340, upload-time = "2026-02-23T00:19:12.024Z" }, +] + [[package]] name = "six" version = "1.17.0"