diff --git a/hobj/benchmarks/make_model.py b/hobj/benchmarks/make_model.py new file mode 100644 index 0000000..2df5f58 --- /dev/null +++ b/hobj/benchmarks/make_model.py @@ -0,0 +1,58 @@ +""" +This module provides an alternative interface for instantiating a linear learning model. +""" +from hobj.learning_models.linear import LinearLearner, RepresentationalModel +import hobj.learning_models.linear.update_rules as update_rules +from typing import Literal, Dict +import mref +import numpy as np +from typing import List + + +# %% +def make_linear_learner_from_features( + ref_to_features: Dict[mref.ImageRef, np.ndarray], + calibration_images: List[mref.ImageRef], + update_rule_name: Literal[ + 'Prototype', + 'Square', + 'Perceptron', + 'Hinge', + 'MAE', + 'Exponential', + 'CE', + 'REINFORCE' + ] = 'Square', + alpha: float = 1, +) -> LinearLearner: + """ + Instantiates a linear learning model from precomputed features. + :param ref_to_features: Dict[mref.ImageRef, np.ndarray], the features to use. + :param calibration_images: List[mref.ImageRef], the images that will be used to calibrate the features (i.e. for mean centering and ensuring they fit within a unit ball). + :param update_rule_name: str, the name of the update rule to use. + :param alpha: float, the learning rate. + :return: LinearLearner + """ + + f_calibration = np.array([ref_to_features[ref] for ref in calibration_images]) + mu_calibration = np.mean(f_calibration, axis=0) + norms_calibration = np.linalg.norm(f_calibration - mu_calibration, axis=1) + norm_cutoff = np.quantile(norms_calibration, 0.999) # Will clip the rest + + ref_to_calibrated_features = {} + for ref in ref_to_features: + f = ref_to_features[ref] + fc = f - mu_calibration + fcn = fc / norm_cutoff + norm_cur = np.linalg.norm(fcn) + if norm_cur > 1: + fcn = fcn / norm_cur + ref_to_calibrated_features[ref] = np.array(fcn) + + update_rule_name = getattr(update_rules, update_rule_name) + return LinearLearner( + representational_model=RepresentationalModel.from_precomputed_features( + image_ref_to_features=ref_to_calibrated_features + ), + update_rule=update_rule_name(alpha=alpha) + ) \ No newline at end of file diff --git a/hobj/benchmarks/mut_highvar_benchmark.py b/hobj/benchmarks/mut_highvar_benchmark.py index 27ad1ad..2d76f7e 100644 --- a/hobj/benchmarks/mut_highvar_benchmark.py +++ b/hobj/benchmarks/mut_highvar_benchmark.py @@ -84,3 +84,7 @@ def __init__(self): super().__init__( config=config ) + +if __name__ == '__main__': + experiment = MutatorHighVarBenchmark() + print(sorted(experiment.config.subtask_name_to_data.keys())) \ No newline at end of file diff --git a/hobj/config.py b/hobj/config.py index fe623da..07bc6e0 100644 --- a/hobj/config.py +++ b/hobj/config.py @@ -1,3 +1,3 @@ from pathlib import Path -cachedir: Path = Path.home() / 'hobj_cache2' +cachedir: Path = Path.home() / 'hobj_cache' diff --git a/pyproject.toml b/pyproject.toml index 919d6fd..e6b4da5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "pydantic>=2.10", "xarray>=2025.1", "pytest>=8.3", - #"mref @ git+https://github.com/himjl/mref.git", + "mref @ git+https://github.com/himjl/mref.git", ] [project.urls]