From 8d1b08e5ba99117db47416daca7f66f71a1cdf4d Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 10 Feb 2026 13:23:59 -0800 Subject: [PATCH 1/3] Upstream MultiTypeExperiment features into Experiment (#4873) Summary: These changes will enable us to deprecate multitypeexperment, simplifying the Ax data model ahead of storage changes. 1. In Experiment make the default_trial_type a new Key.DEFAULT_TRIAL_TYPE value instead of None 2. Move over logic for bookkeeping metric -> trial_type and runner -> trial_type mappings 3. Treat LONG_ and SHORT_RUN trial types as special cases which map to DEFAULT_TRIAL_TYPE (i.e. if a Trial has trial_type=LONG_RUN then use whichever metrics and runners are mapped to DEFAULT_TRIAL_TYPE 4. Fix tests which expect the default_trial_type of an Experiment to be None This diff allows us to remove all isinstance(foo, MultiTypeExperiment) checks in Ax in the next diff, then to deprecate MultiTypeExperiment entirely. Differential Revision: D91618283 --- ax/core/base_trial.py | 9 +- ax/core/experiment.py | 237 +++++++++++++++++++++++++----- ax/core/tests/test_batch_trial.py | 3 +- ax/core/tests/test_experiment.py | 7 +- ax/core/tests/test_observation.py | 10 ++ ax/core/tests/test_trial.py | 4 +- ax/orchestration/orchestrator.py | 28 ++-- ax/storage/json_store/decoder.py | 14 +- ax/storage/json_store/decoders.py | 9 +- ax/storage/sqa_store/decoder.py | 31 +++- ax/storage/sqa_store/encoder.py | 7 +- ax/utils/common/constants.py | 1 + 12 files changed, 292 insertions(+), 68 deletions(-) diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index ee201603067..77b8160cfae 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -98,16 +98,15 @@ def __init__( self._ttl_seconds: int | None = ttl_seconds self._index: int = self._experiment._attach_trial(self, index=index) - trial_type = ( + self._trial_type: str = ( trial_type if trial_type is not None else self._experiment.default_trial_type ) - if not self._experiment.supports_trial_type(trial_type): + if not self._experiment.supports_trial_type(self._trial_type): raise ValueError( - f"Trial type {trial_type} is not supported by the experiment." + f"Trial type {self._trial_type} is not supported by the experiment." ) - self._trial_type = trial_type self.__status: TrialStatus | None = None # Uses `_status` setter, which updates trial statuses to trial indices @@ -285,7 +284,7 @@ def stop_metadata(self) -> dict[str, Any]: return self._stop_metadata @property - def trial_type(self) -> str | None: + def trial_type(self) -> str: """The type of the trial. Relevant for experiments containing different kinds of trials diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 071fe83f802..39b466faf02 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -100,7 +100,7 @@ def __init__( default_data_type: Any = None, auxiliary_experiments_by_purpose: None | (dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]) = None, - default_trial_type: str | None = None, + default_trial_type: str = Keys.DEFAULT_TRIAL_TYPE.value, ) -> None: """Inits Experiment. @@ -123,6 +123,8 @@ def __init__( default_data_type: Deprecated and ignored. auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for different purposes (e.g., transfer learning). + default_trial_type: Default trial type for trials on this experiment. + Defaults to Keys.DEFAULT_TRIAL_TYPE. """ if default_data_type is not None: warnings.warn( @@ -150,10 +152,16 @@ def __init__( self._properties: dict[str, Any] = properties or {} # Initialize trial type to runner mapping - self._default_trial_type = default_trial_type - self._trial_type_to_runner: dict[str | None, Runner | None] = { - default_trial_type: runner + self._default_trial_type: str = ( + default_trial_type or Keys.DEFAULT_TRIAL_TYPE.value + ) + self._trial_type_to_runner: dict[str, Runner | None] = { + self._default_trial_type: runner } + + # Maps metric names to their trial types. Every metric must have an entry. + self._metric_to_trial_type: dict[str, str] = {} + # Used to keep track of whether any trials on the experiment # specify a TTL. Since trials need to be checked for their TTL's # expiration often, having this attribute helps avoid unnecessary @@ -413,16 +421,46 @@ def runner(self) -> Runner | None: def runner(self, runner: Runner | None) -> None: """Set the default runner and update trial type mapping.""" self._runner = runner - if runner is not None: - self._trial_type_to_runner[self._default_trial_type] = runner - else: - self._trial_type_to_runner = {None: None} + self._trial_type_to_runner[self._default_trial_type] = runner @runner.deleter def runner(self) -> None: """Delete the runner.""" self._runner = None - self._trial_type_to_runner = {None: None} + self._trial_type_to_runner[self._default_trial_type] = None + + def add_trial_type(self, trial_type: str, runner: Runner) -> "Experiment": + """Add a new trial_type to be supported by this experiment. + + Args: + trial_type: The new trial_type to be added. + runner: The default runner for trials of this type. + + Returns: + The experiment with the new trial type added. + """ + if self.supports_trial_type(trial_type): + raise ValueError(f"Experiment already contains trial_type `{trial_type}`") + + self._trial_type_to_runner[trial_type] = runner + return self + + def update_runner(self, trial_type: str, runner: Runner) -> "Experiment": + """Update the default runner for an existing trial_type. + + Args: + trial_type: The trial_type to update. + runner: The new runner for trials of this type. + + Returns: + The experiment with the updated runner. + """ + if not self.supports_trial_type(trial_type): + raise ValueError(f"Experiment does not contain trial_type `{trial_type}`") + + self._trial_type_to_runner[trial_type] = runner + self._runner = runner + return self @property def parameters(self) -> dict[str, Parameter]: @@ -489,13 +527,25 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: f"`{Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF.value}` " "property that is set to `True` on this experiment." ) + + # Remove old OC metrics from trial type mapping + prev_optimization_config = self._optimization_config + if prev_optimization_config is not None: + for metric_name in prev_optimization_config.metrics.keys(): + self._metric_to_trial_type.pop(metric_name, None) + for metric_name in optimization_config.metrics.keys(): if metric_name in self._tracking_metrics: self.remove_tracking_metric(metric_name) + # add metrics from the previous optimization config that are not in the new # optimization config as tracking metrics - prev_optimization_config = self._optimization_config self._optimization_config = optimization_config + + # Map new OC metrics to default trial type + for metric_name in optimization_config.metrics.keys(): + self._metric_to_trial_type[metric_name] = self._default_trial_type + if prev_optimization_config is not None: metrics_to_track = ( set(prev_optimization_config.metrics.keys()) @@ -505,6 +555,16 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: for metric_name in metrics_to_track: self.add_tracking_metric(prev_optimization_config.metrics[metric_name]) + # Clean up any stale entries in _metric_to_trial_type that don't correspond + # to actual metrics (can happen when same optimization_config object is + # mutated and reassigned). + current_metric_names = set(self.metrics.keys()) + stale_metric_names = ( + set(self._metric_to_trial_type.keys()) - current_metric_names + ) + for metric_name in stale_metric_names: + self._metric_to_trial_type.pop(metric_name, None) + @property def is_moo_problem(self) -> bool: """Whether the experiment's optimization config contains multiple objectives.""" @@ -553,12 +613,25 @@ def immutable_search_space_and_opt_config(self) -> bool: def tracking_metrics(self) -> list[Metric]: return list(self._tracking_metrics.values()) - def add_tracking_metric(self, metric: Metric) -> Self: + def add_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Self: """Add a new metric to the experiment. Args: metric: Metric to be added. + trial_type: The trial type for which this metric is used. If not + provided, defaults to the experiment's default trial type. """ + effective_trial_type = ( + trial_type if trial_type is not None else self._default_trial_type + ) + + if not self.supports_trial_type(effective_trial_type): + raise ValueError(f"`{effective_trial_type}` is not a supported trial type.") + if metric.name in self._tracking_metrics: raise ValueError( f"Metric `{metric.name}` already defined on experiment. " @@ -574,9 +647,14 @@ def add_tracking_metric(self, metric: Metric) -> Self: ) self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = effective_trial_type return self - def add_tracking_metrics(self, metrics: list[Metric]) -> Self: + def add_tracking_metrics( + self, + metrics: list[Metric], + metrics_to_trial_types: dict[str, str] | None = None, + ) -> Self: """Add a list of new metrics to the experiment. If any of the metrics are already defined on the experiment, @@ -584,23 +662,58 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Self: Args: metrics: Metrics to be added. + metrics_to_trial_types: Optional mapping from metric names to + corresponding trial types. If not provided for a metric, + the experiment's default trial type is used. """ - # Before setting any metrics, we validate none are already on - # the experiment + metrics_to_trial_types = metrics_to_trial_types or {} for metric in metrics: - self.add_tracking_metric(metric) + self.add_tracking_metric( + metric=metric, + trial_type=metrics_to_trial_types.get(metric.name), + ) return self - def update_tracking_metric(self, metric: Metric) -> Self: + def update_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Self: """Redefine a metric that already exists on the experiment. Args: metric: New metric definition. + trial_type: The trial type for which this metric is used. If not + provided, keeps the existing trial type mapping. """ if metric.name not in self._tracking_metrics: raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.") + # Validate trial type if provided + effective_trial_type = ( + trial_type + if trial_type is not None + else self._metric_to_trial_type.get(metric.name, self._default_trial_type) + ) + + # Check that optimization config metrics stay on default trial type + oc = self.optimization_config + oc_metrics = oc.metrics if oc else {} + if ( + metric.name in oc_metrics + and effective_trial_type != self._default_trial_type + ): + raise ValueError( + f"Metric `{metric.name}` must remain a " + f"`{self._default_trial_type}` metric because it is part of the " + "optimization_config." + ) + + if not self.supports_trial_type(effective_trial_type): + raise ValueError(f"`{effective_trial_type}` is not a supported trial type.") + self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = effective_trial_type return self def remove_tracking_metric(self, metric_name: str) -> Self: @@ -613,6 +726,7 @@ def remove_tracking_metric(self, metric_name: str) -> Self: raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") del self._tracking_metrics[metric_name] + self._metric_to_trial_type.pop(metric_name, None) return self @property @@ -852,8 +966,21 @@ def _fetch_trial_data( ) -> dict[str, MetricFetchResult]: trial = self.trials[trial_index] + # If metrics are not provided, fetch all metrics on the experiment for the + # relevant trial type, or the default trial type as a fallback. Otherwise, + # fetch provided metrics. + if metrics is None: + resolved_metrics = [ + metric + for metric in list(self.metrics.values()) + if self._metric_to_trial_type.get(metric.name, self._default_trial_type) + == trial.trial_type + ] + else: + resolved_metrics = metrics + trial_data = self._lookup_or_fetch_trials_results( - trials=[trial], metrics=metrics, **kwargs + trials=[trial], metrics=resolved_metrics, **kwargs ) if trial_index in trial_data: @@ -1548,39 +1675,79 @@ def __repr__(self) -> str: # overridden in the MultiTypeExperiment class. @property - def default_trial_type(self) -> str | None: - """Default trial type assigned to trials in this experiment. - - In the base experiment class this is always None. For experiments - with multiple trial types, use the MultiTypeExperiment class. - """ + def default_trial_type(self) -> str: + """Default trial type assigned to trials in this experiment.""" return self._default_trial_type - def runner_for_trial_type(self, trial_type: str | None) -> Runner | None: + def runner_for_trial_type(self, trial_type: str) -> Runner | None: """The default runner to use for a given trial type. Looks up the appropriate runner for this trial type in the trial_type_to_runner. """ + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. + if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return self._trial_type_to_runner[Keys.DEFAULT_TRIAL_TYPE] + if not self.supports_trial_type(trial_type): raise ValueError(f"Trial type `{trial_type}` is not supported.") if (runner := self._trial_type_to_runner.get(trial_type)) is None: return self.runner # return the default runner return runner - def supports_trial_type(self, trial_type: str | None) -> bool: + def supports_trial_type(self, trial_type: str) -> bool: """Whether this experiment allows trials of the given type. - The base experiment class only supports None. For experiments - with multiple trial types, use the MultiTypeExperiment class. + Checks if the trial type is registered in the trial_type_to_runner mapping. """ - return ( - trial_type is None - # We temporarily allow "short run" and "long run" trial - # types in single-type experiments during development of - # a new ``GenerationStrategy`` that needs them. - or trial_type == Keys.SHORT_RUN - or trial_type == Keys.LONG_RUN - ) + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. + if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return True + + return trial_type in self._trial_type_to_runner + + @property + def is_multi_type(self) -> bool: + """Returns True if this experiment has multiple trial types registered.""" + return len(self._trial_type_to_runner) > 1 + + @property + def metric_to_trial_type(self) -> dict[str, str]: + """Read-only mapping of metric names to trial types.""" + return self._metric_to_trial_type.copy() + + def metrics_for_trial_type(self, trial_type: str) -> list[Metric]: + """Returns metrics associated with a specific trial type. + + Args: + trial_type: The trial type to get metrics for. + + Returns: + List of metrics associated with the given trial type. + """ + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. + if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return [ + self.metrics[metric_name] + for metric_name, metric_trial_type in self._metric_to_trial_type.items() + if metric_trial_type == Keys.DEFAULT_TRIAL_TYPE + ] + + if not self.supports_trial_type(trial_type): + raise ValueError(f"Trial type `{trial_type}` is not supported.") + return [ + self.metrics[metric_name] + for metric_name, metric_trial_type in self._metric_to_trial_type.items() + if metric_trial_type == trial_type + ] def attach_trial( self, diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index 8a2c201f4b9..97901dd3be9 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -410,8 +410,9 @@ def test_clone_to(self, _) -> None: cloned_batch._time_created = batch._time_created self.assertEqual(cloned_batch, batch) # test cloning with clear_trial_type=True + # When clear_trial_type=True, uses experiment's default_trial_type cloned_batch = batch.clone_to(clear_trial_type=True) - self.assertIsNone(cloned_batch.trial_type) + self.assertEqual(cloned_batch.trial_type, self.experiment.default_trial_type) self.assertEqual( cloned_batch.generation_method_str, f"{MANUAL_GENERATION_METHOD_STR}, {STATUS_QUO_GENERATION_METHOD_STR}", diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 9903fbeb578..c87ec407bb5 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -1417,7 +1417,7 @@ def test_clone_with(self) -> None: cloned_experiment._time_created = experiment._time_created self.assertEqual(cloned_experiment, experiment) - # test clear_trial_type + # test clear_trial_type - uses experiment's default_trial_type experiment = get_branin_experiment( with_batch=True, num_batch_trial=1, @@ -1427,7 +1427,10 @@ def test_clone_with(self) -> None: with self.assertRaisesRegex(ValueError, ".* foo is not supported by the exp"): experiment.clone_with() cloned_experiment = experiment.clone_with(clear_trial_type=True) - self.assertIsNone(cloned_experiment.trials[0].trial_type) + self.assertEqual( + cloned_experiment.trials[0].trial_type, + cloned_experiment.default_trial_type, + ) # Test cloning with specific properties to keep experiment_w_props = get_branin_experiment() diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index c18ccf25deb..0817cac286a 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -268,6 +268,8 @@ def test_ObservationsFromData(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) @@ -525,6 +527,8 @@ def test_ObservationsFromDataAbandoned(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: ( Trial(experiment, GeneratorRun(arms=[arms[obs["arm_name"]]])) @@ -637,6 +641,8 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) @@ -744,6 +750,8 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial( } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { 0: BatchTrial(experiment, GeneratorRun(arms=list(arms_by_name.values()))) } @@ -885,6 +893,8 @@ def test_ObservationsWithCandidateMetadata(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 33eef439201..0fa7d3aecfe 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -460,9 +460,9 @@ def test_clone_to(self) -> None: # check that trial_type is cloned correctly self.assertEqual(new_trial.trial_type, "foo") - # test clear_trial_type + # test clear_trial_type - uses experiment's default_trial_type new_trial = self.trial.clone_to(clear_trial_type=True) - self.assertIsNone(new_trial.trial_type) + self.assertEqual(new_trial.trial_type, new_experiment.default_trial_type) def test_update_trial_status_on_clone(self) -> None: for status in [ diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index 8d4804893b7..a54e5d56496 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -58,7 +58,7 @@ set_ax_logger_levels, ) from ax.utils.common.timeutils import current_timestamp_in_millis -from pyre_extensions import assert_is_instance, none_throws +from pyre_extensions import none_throws NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \ @@ -364,7 +364,7 @@ def options(self, options: OrchestratorOptions) -> None: self._validate_runner_and_implemented_metrics(experiment=self.experiment) @property - def trial_type(self) -> str | None: + def trial_type(self) -> str: """Trial type for the experiment this Orchestrator is running. This returns None if the experiment is not a MultitypeExperiment @@ -374,8 +374,10 @@ def trial_type(self) -> str | None: experiment is a MultiTypeExperiment and None otherwise. """ if isinstance(self.experiment, MultiTypeExperiment): - return self.options.mt_experiment_trial_type - return None + return ( + self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value + ) + return Keys.DEFAULT_TRIAL_TYPE.value @property def running_trials(self) -> list[BaseTrial]: @@ -487,12 +489,8 @@ def runner(self) -> Runner: """``Runner`` specified on the experiment associated with this ``Orchestrator`` instance. """ - if self.trial_type is not None: - runner = assert_is_instance( - self.experiment, MultiTypeExperiment - ).runner_for_trial_type(trial_type=none_throws(self.trial_type)) - else: - runner = self.experiment.runner + runner = self.experiment.runner_for_trial_type(trial_type=self.trial_type) + if runner is None: raise UnsupportedError( "`Orchestrator` requires that experiment specifies a `Runner`." @@ -2034,11 +2032,11 @@ def _fetch_and_process_trials_data_results( try: kwargs = deepcopy(self.options.fetch_kwargs) - if self.trial_type is not None: - metrics = assert_is_instance( - self.experiment, MultiTypeExperiment - ).metrics_for_trial_type(trial_type=none_throws(self.trial_type)) - kwargs["metrics"] = metrics + metrics = self.experiment.metrics_for_trial_type( + trial_type=none_throws(self.trial_type) + ) + kwargs["metrics"] = metrics + results = self.experiment.fetch_trials_data_results( trial_indices=trial_indices, **kwargs, diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..213994b6f06 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -59,6 +59,7 @@ CORE_DECODER_REGISTRY, ) from ax.storage.utils import data_by_trial_to_data +from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.serialization import ( extract_init_args, @@ -727,8 +728,19 @@ def experiment_from_json( } ) experiment._arms_by_name = {} - if _trial_type_to_runner is not None: + + # Handle backwards compatibility issue where some Experiments support None + # trial types. + if ( + _trial_type_to_runner is not None + and len(_trial_type_to_runner) > 0 + and ({*_trial_type_to_runner.keys()} != {None}) + ): experiment._trial_type_to_runner = _trial_type_to_runner + else: + experiment._trial_type_to_runner = { + Keys.DEFAULT_TRIAL_TYPE.value: experiment.runner + } _load_experiment_info( exp=experiment, diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index 5d65a745c97..26e30e10074 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -48,6 +48,7 @@ REMOVED_TRANSFORMS, REVERSE_TRANSFORM_REGISTRY, ) +from ax.utils.common.constants import Keys from ax.utils.common.kwargs import warn_on_kwargs from ax.utils.common.logger import get_logger from ax.utils.common.typeutils_torch import torch_type_from_str @@ -158,7 +159,9 @@ def batch_trial_from_json( # the SQ at the end of this function. ) batch._index = index - batch._trial_type = trial_type + batch._trial_type = ( + trial_type if trial_type is not None else Keys.DEFAULT_TRIAL_TYPE.value + ) batch._time_created = time_created batch._time_completed = time_completed batch._time_staged = time_staged @@ -219,7 +222,9 @@ def trial_from_json( experiment=experiment, generator_run=generator_run, ttl_seconds=ttl_seconds ) trial._index = index - trial._trial_type = trial_type + trial._trial_type = ( + trial_type if trial_type is not None else Keys.DEFAULT_TRIAL_TYPE.value + ) # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly # equivalent to `RUNNING`. trial._status = status if status != TrialStatus.DISPATCHED else TrialStatus.RUNNING diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 92628b093f3..40faa2f6485 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -386,11 +386,15 @@ def experiment_from_sqa( experiment.data = data_by_trial_to_data(data_by_trial=data_by_trial) trial_type_to_runner = { - sqa_runner.trial_type: self.runner_from_sqa(sqa_runner) + ( + sqa_runner.trial_type + if sqa_runner.trial_type is not None + else Keys.DEFAULT_TRIAL_TYPE.value + ): self.runner_from_sqa(sqa_runner) for sqa_runner in experiment_sqa.runners } if len(trial_type_to_runner) == 0: - trial_type_to_runner = {None: None} + trial_type_to_runner = {Keys.DEFAULT_TRIAL_TYPE.value: None} experiment._trials = {trial.index: trial for trial in trials} experiment._arms_by_name = {} @@ -415,9 +419,24 @@ def experiment_from_sqa( # `_trial_type_to_runner` is set in _init_mt_experiment_from_sqa if subclass != "MultiTypeExperiment": experiment._trial_type_to_runner = cast( - dict[str | None, Runner | None], trial_type_to_runner + dict[str, Runner | None], trial_type_to_runner ) experiment.db_id = experiment_sqa.id + + # For non-MultiTypeExperiment, populate _metric_to_trial_type + # This is needed because the metrics were added directly to the experiment + # without going through the setters that populate this field. + if subclass != "MultiTypeExperiment": + default_trial_type = Keys.DEFAULT_TRIAL_TYPE.value + # Add OC metrics + oc = experiment.optimization_config + if oc is not None: + for metric_name in oc.metrics.keys(): + experiment._metric_to_trial_type[metric_name] = default_trial_type + # Add tracking metrics + for metric_name in experiment._tracking_metrics.keys(): + experiment._metric_to_trial_type[metric_name] = default_trial_type + return experiment def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: @@ -996,7 +1015,11 @@ def trial_from_sqa( reduced_state=reduced_state, immutable_search_space_and_opt_config=immutable_ss_and_oc, ) - trial._trial_type = trial_sqa.trial_type + trial._trial_type = ( + trial_sqa.trial_type + if trial_sqa.trial_type is not None + else Keys.DEFAULT_TRIAL_TYPE.value + ) # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly # equivalent to `RUNNING`. trial._status = ( diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 00a9568caa7..cba1364ee7d 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -226,7 +226,12 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: metric.name ] elif experiment.runner: - runners.append(self.runner_to_sqa(none_throws(experiment.runner))) + runners.append( + self.runner_to_sqa( + none_throws(experiment.runner), + trial_type=experiment.default_trial_type, + ) + ) properties = experiment._properties.copy() if ( oc := experiment.optimization_config diff --git a/ax/utils/common/constants.py b/ax/utils/common/constants.py index 2f6a6d5b4d5..a1f9fdf531b 100644 --- a/ax/utils/common/constants.py +++ b/ax/utils/common/constants.py @@ -53,6 +53,7 @@ class Keys(StrEnum): COST_INTERCEPT = "cost_intercept" CURRENT_VALUE = "current_value" DEFAULT_OBJECTIVE_NAME = "objective" + DEFAULT_TRIAL_TYPE = "default" EXPAND = "expand" EXPECTED_ACQF_VAL = "expected_acquisition_value" EXPERIMENT_TOTAL_CONCURRENT_ARMS = "total_concurrent_arms" From 67ff859564606f4d18b3c557609add3cbf880be6 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 10 Feb 2026 13:23:59 -0800 Subject: [PATCH 2/3] Remove bifrucation around MultiTypeExperiment (#4874) Summary: With recent changes to experiment we no longer need this bifructation. Next diff will remove places where we construct MultiTypeExperiment, and the one after will deprecate the class entirely Differential Revision: D91920991 --- ax/orchestration/orchestrator.py | 24 ++++----------------- ax/orchestration/tests/test_orchestrator.py | 16 ++++++-------- ax/service/ax_client.py | 15 +++++-------- ax/service/tests/test_report_utils.py | 4 ++-- ax/service/utils/report_utils.py | 8 ++++--- 5 files changed, 23 insertions(+), 44 deletions(-) diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index a54e5d56496..a7a3401c215 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -27,7 +27,6 @@ from ax.core.multi_type_experiment import ( filter_trials_by_type, get_trial_indices_for_statuses, - MultiTypeExperiment, ) from ax.core.runner import Runner from ax.core.trial import Trial @@ -367,17 +366,11 @@ def options(self, options: OrchestratorOptions) -> None: def trial_type(self) -> str: """Trial type for the experiment this Orchestrator is running. - This returns None if the experiment is not a MultitypeExperiment - Returns: - Trial type for the experiment this Orchestrator is running if the - experiment is a MultiTypeExperiment and None otherwise. + Trial type for the experiment this Orchestrator is running. + Defaults to Keys.DEFAULT_TRIAL_TYPE if not specified. """ - if isinstance(self.experiment, MultiTypeExperiment): - return ( - self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value - ) - return Keys.DEFAULT_TRIAL_TYPE.value + return self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value @property def running_trials(self) -> list[BaseTrial]: @@ -1619,11 +1612,7 @@ def _validate_options(self, options: OrchestratorOptions) -> None: "will be unable to fetch intermediate results with which to " "evaluate early stopping criteria." ) - if isinstance(self.experiment, MultiTypeExperiment): - if options.mt_experiment_trial_type is None: - raise UserInputError( - "Must specify `mt_experiment_trial_type` for MultiTypeExperiment." - ) + if options.mt_experiment_trial_type is not None: if not self.experiment.supports_trial_type( options.mt_experiment_trial_type ): @@ -1631,11 +1620,6 @@ def _validate_options(self, options: OrchestratorOptions) -> None: "Experiment does not support trial type " f"{options.mt_experiment_trial_type}." ) - elif options.mt_experiment_trial_type is not None: - raise UserInputError( - "`mt_experiment_trial_type` must be None unless the experiment is a " - "MultiTypeExperiment." - ) def _get_max_pending_trials(self) -> int: """Returns the maximum number of pending trials specified in the options, or diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 67f982ca1a1..8417fd031d7 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -2736,12 +2736,9 @@ def test_validate_options_not_none_mt_trial_type( self, msg: str | None = None ) -> None: # test that error is raised if `mt_experiment_trial_type` is not - # compatible with the type of experiment (single or multi-type) + # a supported trial type for this experiment if msg is None: - msg = ( - "`mt_experiment_trial_type` must be None unless the experiment is a " - "MultiTypeExperiment." - ) + msg = "Experiment does not support trial type type1." options = OrchestratorOptions( init_seconds_between_polls=0, # No wait bw polls so test is fast. batch_size=10, @@ -2752,7 +2749,7 @@ def test_validate_options_not_none_mt_trial_type( ), ) gs = self.two_sobol_steps_GS - with self.assertRaisesRegex(UserInputError, msg): + with self.assertRaisesRegex(ValueError, msg): Orchestrator( experiment=self.branin_experiment, generation_strategy=gs, @@ -3010,10 +3007,11 @@ def test_fetch_and_process_trials_data_results_failed_non_objective( def test_validate_options_not_none_mt_trial_type( self, msg: str | None = None ) -> None: - # test if a MultiTypeExperiment with `mt_experiment_trial_type=None` - self.orchestrator_options_kwargs["mt_experiment_trial_type"] = None + # test that error is raised if `mt_experiment_trial_type` is not + # a supported trial type for this experiment (using an invalid type) + self.orchestrator_options_kwargs["mt_experiment_trial_type"] = "invalid_type" super().test_validate_options_not_none_mt_trial_type( - msg="Must specify `mt_experiment_trial_type` for MultiTypeExperiment." + msg="Experiment does not support trial type invalid_type." ) def test_run_n_trials_single_step_existing_experiment( diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 592c316f037..fd1d4794886 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -24,7 +24,6 @@ from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective from ax.core.observation import ObservationFeatures from ax.core.runner import Runner @@ -458,15 +457,11 @@ def add_tracking_metrics( for metric_name in metric_names ] - if isinstance(self.experiment, MultiTypeExperiment): - experiment = assert_is_instance(self.experiment, MultiTypeExperiment) - experiment.add_tracking_metrics( - metrics=metric_objects, - metrics_to_trial_types=metrics_to_trial_types, - canonical_names=canonical_names, - ) - else: - self.experiment.add_tracking_metrics(metrics=metric_objects) + self.experiment.add_tracking_metrics( + metrics=metric_objects, + metrics_to_trial_types=metrics_to_trial_types, + **({"canonical_names": canonical_names} if canonical_names else {}), + ) @copy_doc(Experiment.remove_tracking_metric) def remove_tracking_metric(self, metric_name: str) -> None: diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index e857215be50..044c9fca78c 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -199,9 +199,9 @@ def test_exp_to_df_with_failure(self) -> None: self.assertEqual(f"{fail_reason}...", df["reason"].iloc[0]) def test_exp_to_df(self) -> None: - # MultiTypeExperiment should fail + # Experiments with multiple trial types should fail exp = get_multi_type_experiment() - with self.assertRaisesRegex(ValueError, "MultiTypeExperiment"): + with self.assertRaisesRegex(ValueError, "multiple trial types"): exp_to_df(exp=exp) # exp with no trials should return empty results diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 4ddabe4c211..3f434504041 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -32,7 +32,6 @@ from ax.core.generator_run import GeneratorRunType from ax.core.map_metric import MapMetric from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -787,8 +786,11 @@ def exp_to_df( ) # Accept Experiment and SimpleExperiment - if isinstance(exp, MultiTypeExperiment): - raise ValueError("Cannot transform MultiTypeExperiments to DataFrames.") + # Reject experiments with multiple trial types as they need special handling + if len(exp._trial_type_to_runner) > 1: + raise ValueError( + "Cannot transform experiments with multiple trial types to DataFrames." + ) key_components = ["trial_index", "arm_name"] From c56f900495a35cebbfdb12a15320fa97889093ec Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 10 Feb 2026 13:23:59 -0800 Subject: [PATCH 3/3] Remove MultiTypeExperiment (#4875) Summary: Guts MultiTypeExperiment class down to a deprecation warning, replaces all places which initialize a MultiTypeExperiment with a base Experiment, and updated type annotations. Updated storage accordingly; previously stored MultiTypeExperiments will be correctly decoded as Experiments. As previously discussed, also deprecated metric_to_cannonical_name mapping as we intend to reexamine this design in the context of the metric_signature field in Data we added in H2 2025 Differential Revision: D92089176 --- ax/core/experiment.py | 9 + ax/core/multi_type_experiment.py | 329 +------------------ ax/core/tests/test_multi_type_experiment.py | 41 +-- ax/orchestration/tests/test_orchestrator.py | 12 +- ax/service/tests/test_ax_client.py | 6 +- ax/service/tests/test_instantiation_utils.py | 2 +- ax/service/utils/instantiation.py | 28 +- ax/storage/json_store/decoder.py | 96 +++--- ax/storage/json_store/encoders.py | 30 +- ax/storage/sqa_store/decoder.py | 160 +++++---- ax/storage/sqa_store/encoder.py | 7 +- ax/storage/sqa_store/tests/test_sqa_store.py | 4 - ax/utils/testing/core_stubs.py | 10 +- 13 files changed, 201 insertions(+), 533 deletions(-) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 39b466faf02..8edcbba241d 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -1225,6 +1225,15 @@ def trial_indices_with_data( return trials_with_data + @property + def default_trials(self) -> set[int]: + """Return the indicies for trials of the default type.""" + return { + idx + for idx, trial in self.trials.items() + if trial.trial_type == self.default_trial_type + } + def new_trial( self, generator_run: GeneratorRun | None = None, diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index b6c552d2bcb..754a0d626ab 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -6,334 +6,27 @@ # pyre-strict -from collections.abc import Iterable, Sequence -from typing import Any, Self +from collections.abc import Sequence -from ax.core.arm import Arm -from ax.core.base_trial import BaseTrial, TrialStatus -from ax.core.data import Data +from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment -from ax.core.metric import Metric, MetricFetchResult -from ax.core.optimization_config import OptimizationConfig -from ax.core.runner import Runner -from ax.core.search_space import SearchSpace -from ax.utils.common.docutils import copy_doc -from pyre_extensions import none_throws +from ax.core.trial_status import TrialStatus class MultiTypeExperiment(Experiment): """Class for experiment with multiple trial types. - A canonical use case for this is tuning a large production system - with limited evaluation budget and a simulator which approximates - evaluations on the main system. Trial deployment and data fetching - is separate for the two systems, but the final data is combined and - fed into multi-task models. + .. deprecated:: + The `MultiTypeExperiment` class is deprecated. Use `Experiment` with + `default_trial_type` parameter instead. All multi-type experiment + functionality has been moved to the base `Experiment` class. - See the Multi-Task Modeling tutorial for more details. - - Attributes: - name: Name of the experiment. - description: Description of the experiment. """ - def __init__( - self, - name: str, - search_space: SearchSpace, - default_trial_type: str, - default_runner: Runner | None, - optimization_config: OptimizationConfig | None = None, - tracking_metrics: list[Metric] | None = None, - status_quo: Arm | None = None, - description: str | None = None, - is_test: bool = False, - experiment_type: str | None = None, - properties: dict[str, Any] | None = None, - default_data_type: Any = None, - ) -> None: - """Inits Experiment. - - Args: - name: Name of the experiment. - search_space: Search space of the experiment. - default_trial_type: Default type for trials on this experiment. - default_runner: Default runner for trials of the default type. - optimization_config: Optimization config of the experiment. - tracking_metrics: Additional tracking metrics not used for optimization. - These are associated with the default trial type. - runner: Default runner used for trials on this experiment. - status_quo: Arm representing existing "control" arm. - description: Description of the experiment. - is_test: Convenience metadata tracker for the user to mark test experiments. - experiment_type: The class of experiments this one belongs to. - properties: Dictionary of this experiment's properties. - default_data_type: Deprecated and ignored. - """ - - # Specifies which trial type each metric belongs to - self._metric_to_trial_type: dict[str, str] = {} - - # Maps certain metric names to a canonical name. Useful for ancillary trial - # types' metrics, to specify which primary metrics they correspond to - # (e.g. 'comment_prediction' => 'comment') - self._metric_to_canonical_name: dict[str, str] = {} - - # call super.__init__() after defining fields above, because we need - # them to be populated before optimization config is set - super().__init__( - name=name, - search_space=search_space, - optimization_config=optimization_config, - status_quo=status_quo, - description=description, - is_test=is_test, - experiment_type=experiment_type, - properties=properties, - tracking_metrics=tracking_metrics, - runner=default_runner, - default_trial_type=default_trial_type, - default_data_type=default_data_type, - ) - - def add_trial_type(self, trial_type: str, runner: Runner) -> Self: - """Add a new trial_type to be supported by this experiment. - - Args: - trial_type: The new trial_type to be added. - runner: The default runner for trials of this type. - """ - if self.supports_trial_type(trial_type): - raise ValueError(f"Experiment already contains trial_type `{trial_type}`") - - self._trial_type_to_runner[trial_type] = runner - return self - - # pyre-fixme [56]: Pyre was not able to infer the type of the decorator - # `Experiment.optimization_config.setter`. - @Experiment.optimization_config.setter - def optimization_config(self, optimization_config: OptimizationConfig) -> None: - # pyre-fixme [16]: `Optional` has no attribute `fset`. - Experiment.optimization_config.fset(self, optimization_config) - for metric_name in optimization_config.metrics.keys(): - # Optimization config metrics are required to be the default trial type - # currently. TODO: remove that restriction (T202797235) - self._metric_to_trial_type[metric_name] = none_throws( - self.default_trial_type - ) - - def update_runner(self, trial_type: str, runner: Runner) -> Self: - """Update the default runner for an existing trial_type. - - Args: - trial_type: The new trial_type to be added. - runner: The new runner for trials of this type. - """ - if not self.supports_trial_type(trial_type): - raise ValueError(f"Experiment does not contain trial_type `{trial_type}`") - - self._trial_type_to_runner[trial_type] = runner - self._runner = runner - return self - - def add_tracking_metric( - self, - metric: Metric, - trial_type: str | None = None, - canonical_name: str | None = None, - ) -> Self: - """Add a new metric to the experiment. - - Args: - metric: The metric to add. - trial_type: The trial type for which this metric is used. - canonical_name: The default metric for which this metric is a proxy. - """ - if trial_type is None: - trial_type = self._default_trial_type - if not self.supports_trial_type(trial_type): - raise ValueError(f"`{trial_type}` is not a supported trial type.") - - super().add_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = none_throws(trial_type) - if canonical_name is not None: - self._metric_to_canonical_name[metric.name] = canonical_name - return self - - def add_tracking_metrics( - self, - metrics: list[Metric], - metrics_to_trial_types: dict[str, str] | None = None, - canonical_names: dict[str, str] | None = None, - ) -> Experiment: - """Add a list of new metrics to the experiment. - - If any of the metrics are already defined on the experiment, - we raise an error and don't add any of them to the experiment - - Args: - metrics: Metrics to be added. - metrics_to_trial_types: The mapping from metric names to corresponding - trial types for each metric. If provided, the metrics will be - added to their trial types. If not provided, then the default - trial type will be used. - canonical_names: A mapping of metric names to their - canonical names(The default metrics for which the metrics are - proxies.) - - Returns: - The experiment with the added metrics. - """ - metrics_to_trial_types = metrics_to_trial_types or {} - canonical_name = None - for metric in metrics: - if canonical_names is not None: - canonical_name = none_throws(canonical_names).get(metric.name, None) - - self.add_tracking_metric( - metric=metric, - trial_type=metrics_to_trial_types.get( - metric.name, self._default_trial_type - ), - canonical_name=canonical_name, - ) - return self - - def update_tracking_metric( - self, - metric: Metric, - trial_type: str | None = None, - canonical_name: str | None = None, - ) -> Self: - """Update an existing metric on the experiment. - - Args: - metric: The metric to add. - trial_type: The trial type for which this metric is used. Defaults to - the current trial type of the metric (if set), or the default trial - type otherwise. - canonical_name: The default metric for which this metric is a proxy. - """ - # Default to the existing trial type if not specified - if trial_type is None: - trial_type = self._metric_to_trial_type.get( - metric.name, self._default_trial_type - ) - oc = self.optimization_config - oc_metrics = oc.metrics if oc else [] - if metric.name in oc_metrics and trial_type != self._default_trial_type: - raise ValueError( - f"Metric `{metric.name}` must remain a " - f"`{self._default_trial_type}` metric because it is part of the " - "optimization_config." - ) - elif not self.supports_trial_type(trial_type): - raise ValueError(f"`{trial_type}` is not a supported trial type.") - - super().update_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = none_throws(trial_type) - if canonical_name is not None: - self._metric_to_canonical_name[metric.name] = canonical_name - return self - - @copy_doc(Experiment.remove_tracking_metric) - def remove_tracking_metric(self, metric_name: str) -> Self: - if metric_name not in self._tracking_metrics: - raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") - - # Required fields - del self._tracking_metrics[metric_name] - del self._metric_to_trial_type[metric_name] - - # Optional - if metric_name in self._metric_to_canonical_name: - del self._metric_to_canonical_name[metric_name] - return self - - @copy_doc(Experiment.fetch_data) - def fetch_data( - self, - trial_indices: Iterable[int] | None = None, - metrics: list[Metric] | None = None, - **kwargs: Any, - ) -> Data: - # TODO: make this more efficient for fetching - # data for multiple trials of the same type - # by overriding Experiment._lookup_or_fetch_trials_results - return Data.from_multiple_data( - [ - ( - trial.fetch_data(**kwargs, metrics=metrics) - if trial.status.expecting_data - else Data() - ) - for trial in self.trials.values() - ] - ) - - @copy_doc(Experiment._fetch_trial_data) - def _fetch_trial_data( - self, trial_index: int, metrics: list[Metric] | None = None, **kwargs: Any - ) -> dict[str, MetricFetchResult]: - trial = self.trials[trial_index] - metrics = [ - metric - for metric in (metrics or self.metrics.values()) - if self.metric_to_trial_type[metric.name] == trial.trial_type - ] - # Invoke parent's fetch method using only metrics for this trial_type - return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs) - - @property - def default_trials(self) -> set[int]: - """Return the indicies for trials of the default type.""" - return { - idx - for idx, trial in self.trials.items() - if trial.trial_type == self.default_trial_type - } - - @property - def metric_to_trial_type(self) -> dict[str, str]: - """Map metrics to trial types. - - Adds in default trial type for OC metrics to custom defined trial types.. - """ - opt_config_types = { - metric_name: self.default_trial_type - for metric_name in self.optimization_config.metrics.keys() - } - return {**opt_config_types, **self._metric_to_trial_type} - - # -- Overridden functions from Base Experiment Class -- - @property - def default_trial_type(self) -> str | None: - """Default trial type assigned to trials in this experiment.""" - return self._default_trial_type - - def metrics_for_trial_type(self, trial_type: str) -> list[Metric]: - """The default runner to use for a given trial type. - - Looks up the appropriate runner for this trial type in the trial_type_to_runner. - """ - if not self.supports_trial_type(trial_type): - raise ValueError(f"Trial type `{trial_type}` is not supported.") - return [ - self.metrics[metric_name] - for metric_name, metric_trial_type in self._metric_to_trial_type.items() - if metric_trial_type == trial_type - ] - - def supports_trial_type(self, trial_type: str | None) -> bool: - """Whether this experiment allows trials of the given type. - - Only trial types defined in the trial_type_to_runner are allowed. - """ - return trial_type in self._trial_type_to_runner.keys() - def filter_trials_by_type( - trials: Sequence[BaseTrial], trial_type: str | None + trials: Sequence[BaseTrial], + trial_type: str | None, ) -> list[BaseTrial]: """Filter trials by trial type if provided. @@ -352,7 +45,9 @@ def filter_trials_by_type( def get_trial_indices_for_statuses( - experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None + experiment: Experiment, + statuses: set[TrialStatus], + trial_type: str | None = None, ) -> set[int]: """Get trial indices for a set of statuses. diff --git a/ax/core/tests/test_multi_type_experiment.py b/ax/core/tests/test_multi_type_experiment.py index d8a28e60ac2..e18b800985e 100644 --- a/ax/core/tests/test_multi_type_experiment.py +++ b/ax/core/tests/test_multi_type_experiment.py @@ -48,18 +48,8 @@ def test_MTExperimentFlow(self) -> None: self.assertEqual(b2.run_metadata["dummy_metadata"], "dummy3") df = self.experiment.fetch_data().df - for _, row in df.iterrows(): - # Make sure proper metric present for each batch only - self.assertEqual( - row["metric_name"], "m1" if row["trial_index"] == 0 else "m2" - ) - arm_0_slice = df.loc[df["arm_name"] == "0_0"] - self.assertNotEqual( - float(arm_0_slice[df["trial_index"] == 0]["mean"].item()), - float(arm_0_slice[df["trial_index"] == 1]["mean"].item()), - ) - self.assertEqual(len(df), 2 * n) + self.assertEqual(len(df), 4 * n) self.assertEqual(self.experiment.default_trials, {0}) # Set 2 metrics to be equal self.experiment.update_tracking_metric( @@ -68,13 +58,21 @@ def test_MTExperimentFlow(self) -> None: df = self.experiment.fetch_data().df arm_0_slice = df.loc[df["arm_name"] == "0_0"] self.assertAlmostEqual( - float(arm_0_slice[df["trial_index"] == 0]["mean"].item()), - float(arm_0_slice[df["trial_index"] == 1]["mean"].item()), + float( + arm_0_slice[(df["trial_index"] == 0) & (df["metric_name"] == "m2")][ + "mean" + ].item() + ), + float( + arm_0_slice[(df["trial_index"] == 1) & (df["metric_name"] == "m2")][ + "mean" + ].item() + ), places=10, ) def test_Repr(self) -> None: - self.assertEqual(str(self.experiment), "MultiTypeExperiment(test_exp)") + self.assertEqual(str(self.experiment), "Experiment(test_exp)") def test_Eq(self) -> None: exp2 = get_multi_type_experiment() @@ -83,24 +81,19 @@ def test_Eq(self) -> None: self.assertTrue(self.experiment == exp2) self.experiment.add_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m4" + BraninMetric("m3", ["x2", "x1"]), + trial_type="type1", ) # Test different set of metrics self.assertFalse(self.experiment == exp2) exp2.add_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m5" - ) - - # Test different metric definitions - self.assertFalse(self.experiment == exp2) - - exp2.update_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m4" + BraninMetric("m3", ["x2", "x1"]), + trial_type="type1", ) - # Should be the same + # Both have the same metrics now, should be equal self.assertTrue(self.experiment == exp2) exp2.remove_tracking_metric("m3") diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 8417fd031d7..7d04acb154f 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -1315,9 +1315,9 @@ def _should_stop_trials_early( len(res_list[1]["trials_early_stopped_so_far"]), ) - looked_up_data = orchestrator.experiment.lookup_data() fetched_data = orchestrator.experiment.fetch_data() - num_metrics = 2 + looked_up_data = orchestrator.experiment.lookup_data() + num_metrics = len(orchestrator.experiment.metrics) expected_num_rows = num_metrics * total_trials # There are 3 trials and two metrics for "type1" for MT experiments self.assertEqual(len(looked_up_data.df), expected_num_rows) @@ -1329,7 +1329,7 @@ def _should_stop_trials_early( # longer and gets results for an extra timestamp. # For MultiTypeExperiment there are two metrics # for trial type "type1" - expected_num_rows = 7 + expected_num_rows = expected_num_rows + 1 self.assertEqual(len(looked_up_data.full_df), expected_num_rows) self.assertEqual(len(fetched_data.full_df), expected_num_rows) ess = orchestrator.options.early_stopping_strategy @@ -2842,7 +2842,7 @@ def test_terminate_if_status_quo_infeasible(self) -> None: class TestAxOrchestratorMultiTypeExperiment(TestAxOrchestrator): # After D80128678, choose_generation_strategy_legacy returns node-based GS. EXPECTED_orchestrator_REPR: str = ( - "Orchestrator(experiment=MultiTypeExperiment(branin_test_experiment), " + "Orchestrator(experiment=Experiment(branin_test_experiment), " "generation_strategy=GenerationStrategy(" "name='GenerationStep_0_Sobol+GenerationStep_1_BoTorch', " "nodes=[GenerationNode(name='GenerationStep_0_Sobol', " @@ -2894,13 +2894,13 @@ def setUp(self) -> None: trial_type="type1", runner=RunnerToAllowMultipleMapMetricFetches() ) - self.branin_experiment_no_impl_runner_or_metrics = MultiTypeExperiment( + self.branin_experiment_no_impl_runner_or_metrics = Experiment( search_space=get_branin_search_space(), optimization_config=OptimizationConfig( Objective(Metric(name="branin"), minimize=True) ), default_trial_type="type1", - default_runner=None, + runner=None, name="branin_experiment_no_impl_runner_or_metrics", ) self.sobol_MBM_GS = choose_generation_strategy_legacy( diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index afe81f05d01..836a5b08bc5 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -21,9 +21,9 @@ from ax.adapter.registry import Cont_X_trans, Generators from ax.core.arm import Arm from ax.core.data import Data, MAP_KEY +from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.optimization_config import MultiObjectiveOptimizationConfig from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint from ax.core.parameter import ( @@ -974,8 +974,8 @@ def test_create_multitype_experiment(self) -> None: default_runner=SyntheticRunner(), ) - self.assertEqual(ax_client.experiment.__class__.__name__, "MultiTypeExperiment") - experiment = assert_is_instance(ax_client.experiment, MultiTypeExperiment) + self.assertEqual(ax_client.experiment.__class__.__name__, "Experiment") + experiment = assert_is_instance(ax_client.experiment, Experiment) self.assertEqual( experiment._trial_type_to_runner["test_trial_type"].__class__.__name__, "SyntheticRunner", diff --git a/ax/service/tests/test_instantiation_utils.py b/ax/service/tests/test_instantiation_utils.py index 160518ec457..c75c3ca3f5c 100644 --- a/ax/service/tests/test_instantiation_utils.py +++ b/ax/service/tests/test_instantiation_utils.py @@ -372,7 +372,7 @@ def test_make_multitype_experiment_with_default_trial_type(self) -> None: default_trial_type="test_trial_type", default_runner=SyntheticRunner(), ) - self.assertEqual(experiment.__class__.__name__, "MultiTypeExperiment") + self.assertEqual(experiment.__class__.__name__, "Experiment") def test_make_single_type_experiment_with_no_default_trial_type(self) -> None: experiment = InstantiationBase.make_experiment( diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index d08849d6f0d..f00bbc4f57f 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -17,7 +17,6 @@ from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.experiment import Experiment from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective from ax.core.observation import ObservationFeatures from ax.core.optimization_config import ( @@ -850,12 +849,9 @@ def make_experiment( auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for different use cases (e.g., transfer learning). default_trial_type: The default trial type if multiple - trial types are intended to be used in the experiment. If specified, - a MultiTypeExperiment will be created. Otherwise, a single-type - Experiment will be created. + trial types are intended to be used in the experiment. default_runner: The default runner in this experiment. - This only applies to MultiTypeExperiment (when default_trial_type - is specified). + This is required if default_trial_type is specified. is_test: Whether this experiment will be a test experiment (useful for marking test experiments in storage etc). Defaults to False. @@ -863,7 +859,7 @@ def make_experiment( if (default_trial_type is None) != (default_runner is None): raise ValueError( "Must specify both default_trial_type and default_runner if " - "using a MultiTypeExperiment." + "using multiple trial types." ) status_quo_arm = None if status_quo is None else Arm(parameters=status_quo) @@ -900,23 +896,6 @@ def make_experiment( if owners is not None: properties["owners"] = owners - if default_trial_type is not None: - return MultiTypeExperiment( - name=none_throws(name), - search_space=cls.make_search_space( - parameters=parameters, parameter_constraints=parameter_constraints - ), - default_trial_type=none_throws(default_trial_type), - default_runner=none_throws(default_runner), - optimization_config=optimization_config, - tracking_metrics=tracking_metrics, - status_quo=status_quo_arm, - description=description, - is_test=is_test, - experiment_type=experiment_type, - properties=properties, - ) - return Experiment( name=name, description=description, @@ -929,6 +908,7 @@ def make_experiment( auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, is_test=is_test, runner=default_runner, + default_trial_type=default_trial_type or Keys.DEFAULT_TRIAL_TYPE.value, ) @classmethod diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 213994b6f06..371277667be 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -654,49 +654,17 @@ def multi_type_experiment_from_json( object_json: dict[str, Any], decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, -) -> MultiTypeExperiment: - """Load AE MultiTypeExperiment from JSON.""" - experiment_info = _get_experiment_info(object_json) - - _metric_to_canonical_name = object_json.pop("_metric_to_canonical_name") - _metric_to_trial_type = object_json.pop("_metric_to_trial_type") - _trial_type_to_runner = object_from_json( - object_json.pop("_trial_type_to_runner"), - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - tracking_metrics = object_from_json( - object_json.pop("tracking_metrics"), - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - # not relevant to multi type experiment - del object_json["runner"] - - kwargs = { - k: object_from_json( - v, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for k, v in object_json.items() - } - kwargs["default_runner"] = _trial_type_to_runner[object_json["default_trial_type"]] - - experiment = MultiTypeExperiment(**kwargs) - for metric in tracking_metrics: - experiment._tracking_metrics[metric.name] = metric - experiment._metric_to_canonical_name = _metric_to_canonical_name - experiment._metric_to_trial_type = _metric_to_trial_type - experiment._trial_type_to_runner = _trial_type_to_runner +) -> Experiment: + """Load AE MultiTypeExperiment from JSON. - _load_experiment_info( - exp=experiment, - exp_info=experiment_info, + This function is kept for backwards compatibility. It delegates to the + unified experiment_from_json which handles all experiments uniformly. + """ + return experiment_from_json( + object_json=object_json, decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) - return experiment def experiment_from_json( @@ -704,8 +672,17 @@ def experiment_from_json( decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, ) -> Experiment: - """Load Ax Experiment from JSON.""" + """Load Ax Experiment from JSON. + + This function handles all experiments uniformly, including those with + multiple trial types (formerly MultiTypeExperiment). + """ experiment_info = _get_experiment_info(object_json) + + # Handle _metric_to_trial_type (may or may not be present) + _metric_to_trial_type = object_json.pop("_metric_to_trial_type", {}) + + # Handle _trial_type_to_runner _trial_type_to_runner_json = object_json.pop("_trial_type_to_runner", None) _trial_type_to_runner = ( object_from_json( @@ -717,6 +694,20 @@ def experiment_from_json( else None ) + # Handle tracking_metrics separately for multi-type experiments + tracking_metrics = object_from_json( + object_json.pop("tracking_metrics"), + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + + # Handle default_trial_type (may or may not be present) + default_trial_type = object_json.pop("default_trial_type", None) + + # For multi-type experiments, runner in JSON is not relevant + # (we use _trial_type_to_runner instead) + runner_json = object_json.pop("runner", None) + experiment = Experiment( **{ k: object_from_json( @@ -729,18 +720,33 @@ def experiment_from_json( ) experiment._arms_by_name = {} - # Handle backwards compatibility issue where some Experiments support None - # trial types. + # Add tracking metrics + for metric in tracking_metrics: + experiment._tracking_metrics[metric.name] = metric + + # Set up _metric_to_trial_type + experiment._metric_to_trial_type = _metric_to_trial_type + + # Set up _trial_type_to_runner if ( _trial_type_to_runner is not None and len(_trial_type_to_runner) > 0 and ({*_trial_type_to_runner.keys()} != {None}) ): experiment._trial_type_to_runner = _trial_type_to_runner + # Set the runner from _trial_type_to_runner if we have a default_trial_type + if default_trial_type is not None: + experiment._runner = _trial_type_to_runner.get(default_trial_type) + experiment._default_trial_type = default_trial_type else: - experiment._trial_type_to_runner = { - Keys.DEFAULT_TRIAL_TYPE.value: experiment.runner - } + # Decode and set the runner for non-multi-type experiments + runner = object_from_json( + runner_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + experiment._runner = runner + experiment._trial_type_to_runner = {Keys.DEFAULT_TRIAL_TYPE.value: runner} _load_experiment_info( exp=experiment, diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index bfd6157f129..8a7e1123cef 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -18,7 +18,6 @@ from ax.core.data import Data from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -75,7 +74,12 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: - """Convert Ax experiment to a dictionary.""" + """Convert Ax experiment to a dictionary. + + This encoder handles all experiments uniformly, including those with + multiple trial types. The _metric_to_trial_type and _trial_type_to_runner + fields are always included. + """ return { "__type": experiment.__class__.__name__, "name": experiment._name, @@ -92,19 +96,21 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: "data_by_trial": data_to_data_by_trial(data=experiment.data), "properties": experiment._properties, "_trial_type_to_runner": experiment._trial_type_to_runner, + "_metric_to_trial_type": experiment._metric_to_trial_type, + "default_trial_type": experiment._default_trial_type, } -def multi_type_experiment_to_dict(experiment: MultiTypeExperiment) -> dict[str, Any]: - """Convert AE multitype experiment to a dictionary.""" - multi_type_dict = { - "default_trial_type": experiment._default_trial_type, - "_metric_to_canonical_name": experiment._metric_to_canonical_name, - "_metric_to_trial_type": experiment._metric_to_trial_type, - "_trial_type_to_runner": experiment._trial_type_to_runner, - } - multi_type_dict.update(experiment_to_dict(experiment)) - return multi_type_dict +def multi_type_experiment_to_dict(experiment: Experiment) -> dict[str, Any]: + """Convert AE multitype experiment to a dictionary. + + This encoder is kept for backwards compatibility. It uses the same + unified encoding as experiment_to_dict but sets __type to + MultiTypeExperiment for older loaders. + """ + result = experiment_to_dict(experiment) + result["__type"] = "MultiTypeExperiment" + return result def batch_to_dict(batch: BatchTrial) -> dict[str, Any]: diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 40faa2f6485..4d25c33d0e6 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -31,7 +31,6 @@ from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -214,10 +213,16 @@ def _init_experiment_from_sqa( load_auxiliary_experiments: bool = True, reduced_state: bool = False, ) -> Experiment: - """First step of conversion within experiment_from_sqa.""" + """Convert SQAExperiment to Experiment. + + This method handles all experiments uniformly, including those with + multiple trial types (formerly MultiTypeExperiment). + """ # `experiment_sqa.properties` is `sqlalchemy.ext.mutable.MutableDict` # so need to convert it to regular dict. properties = dict(experiment_sqa.properties or {}) + is_multi_type = properties.get(Keys.SUBCLASS) == "MultiTypeExperiment" + opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa( metrics_sqa=experiment_sqa.metrics, pruning_target_parameterization=( @@ -240,14 +245,34 @@ def _init_experiment_from_sqa( if experiment_sqa.status_quo_parameters is not None else None ) + + # Build trial_type_to_runner from runners + trial_type_to_runner: dict[str, Runner | None] = {} if len(experiment_sqa.runners) == 0: runner = None - elif len(experiment_sqa.runners) == 1: + elif len(experiment_sqa.runners) == 1 and not is_multi_type: runner = self.runner_from_sqa(runner_sqa=experiment_sqa.runners[0]) else: - raise ValueError( - "Multiple runners on experiment only supported for MultiTypeExperiment." - ) + # Multiple runners or multi-type experiment + runner = None + for sqa_runner in experiment_sqa.runners: + trial_type = sqa_runner.trial_type + trial_type_to_runner[none_throws(trial_type)] = self.runner_from_sqa( + sqa_runner + ) + + # Handle default_trial_type for multi-type experiments + default_trial_type = experiment_sqa.default_trial_type + if is_multi_type and default_trial_type is not None: + if len(trial_type_to_runner) == 0: + trial_type_to_runner = {default_trial_type: None} + trial_types_with_metrics = { + metric.trial_type + for metric in experiment_sqa.metrics + if metric.trial_type + } + trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics)) + runner = trial_type_to_runner.get(default_trial_type) auxiliary_experiments_by_purpose = ( ( @@ -260,88 +285,52 @@ def _init_experiment_from_sqa( else {} ) - return Experiment( + experiment = Experiment( name=experiment_sqa.name, description=experiment_sqa.description, search_space=search_space, optimization_config=opt_config, - tracking_metrics=tracking_metrics, + tracking_metrics=( + tracking_metrics if not is_multi_type else [] + ), # Add later for multi-type runner=runner, status_quo=status_quo, is_test=experiment_sqa.is_test, properties=properties, auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, + default_trial_type=default_trial_type + if default_trial_type is not None + else Keys.DEFAULT_TRIAL_TYPE.value, ) + # For multi-type experiments, set up trial_type_to_runner and add + # tracking metrics with their trial types + if is_multi_type: + experiment._trial_type_to_runner = trial_type_to_runner + sqa_metric_dict = {metric.name: metric for metric in experiment_sqa.metrics} + for tracking_metric in tracking_metrics: + sqa_metric = sqa_metric_dict[tracking_metric.name] + experiment.add_tracking_metric( + tracking_metric, + trial_type=none_throws(sqa_metric.trial_type), + ) + + return experiment + def _init_mt_experiment_from_sqa( self, experiment_sqa: SQAExperiment, - ) -> MultiTypeExperiment: - """First step of conversion within experiment_from_sqa.""" - properties = dict(experiment_sqa.properties or {}) - opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa( - metrics_sqa=experiment_sqa.metrics, - pruning_target_parameterization=( - self._get_pruning_target_parameterization_from_experiment_properties( - properties=properties - ) - ), - ) - search_space = self.search_space_from_sqa( - parameters_sqa=experiment_sqa.parameters, - parameter_constraints_sqa=experiment_sqa.parameter_constraints, - ) - if search_space is None: - raise SQADecodeError("Experiment SearchSpace cannot be None.") - status_quo = ( - Arm( - parameters=experiment_sqa.status_quo_parameters, - name=experiment_sqa.status_quo_name, - ) - if experiment_sqa.status_quo_parameters is not None - else None - ) - - default_trial_type = none_throws(experiment_sqa.default_trial_type) - trial_type_to_runner = { - none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner) - for sqa_runner in experiment_sqa.runners - } - if len(trial_type_to_runner) == 0: - trial_type_to_runner = {default_trial_type: None} - trial_types_with_metrics = { - metric.trial_type - for metric in experiment_sqa.metrics - if metric.trial_type - } - # trial_type_to_runner is instantiated to map all trial types to None, - # so the trial types are associated with the experiment. This is - # important for adding metrics. - trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics)) + ) -> Experiment: + """First step of conversion within experiment_from_sqa. - experiment = MultiTypeExperiment( - name=experiment_sqa.name, - description=experiment_sqa.description, - search_space=search_space, - default_trial_type=default_trial_type, - default_runner=trial_type_to_runner.get(default_trial_type), - optimization_config=opt_config, - status_quo=status_quo, - properties=properties, + This method is kept for backwards compatibility. It delegates to the + unified _init_experiment_from_sqa. + """ + return self._init_experiment_from_sqa( + experiment_sqa=experiment_sqa, + load_auxiliary_experiments=False, + reduced_state=False, ) - # pyre-ignore Imcompatible attribute type [8]: attribute _trial_type_to_runner - # has type Dict[str, Optional[Runner]] but is used as type - # Uniont[Dict[str, Optional[Runner]], Dict[str, None]] - experiment._trial_type_to_runner = trial_type_to_runner - sqa_metric_dict = {metric.name: metric for metric in experiment_sqa.metrics} - for tracking_metric in tracking_metrics: - sqa_metric = sqa_metric_dict[tracking_metric.name] - experiment.add_tracking_metric( - tracking_metric, - trial_type=none_throws(sqa_metric.trial_type), - canonical_name=sqa_metric.canonical_name, - ) - return experiment def experiment_from_sqa( self, @@ -359,14 +348,14 @@ def experiment_from_sqa( load_auxiliary_experiment: whether to load auxiliary experiments. """ subclass = (experiment_sqa.properties or {}).get(Keys.SUBCLASS) - if subclass == "MultiTypeExperiment": - experiment = self._init_mt_experiment_from_sqa(experiment_sqa) - else: - experiment = self._init_experiment_from_sqa( - experiment_sqa, - load_auxiliary_experiments=load_auxiliary_experiments, - reduced_state=reduced_state, - ) + is_multi_type = subclass == "MultiTypeExperiment" + + experiment = self._init_experiment_from_sqa( + experiment_sqa, + load_auxiliary_experiments=load_auxiliary_experiments, + reduced_state=reduced_state, + ) + trials = [ self.trial_from_sqa( trial_sqa=trial, @@ -416,17 +405,16 @@ def experiment_from_sqa( experiment._experiment_type = self.get_enum_name( value=experiment_sqa.experiment_type, enum=self.config.experiment_type_enum ) - # `_trial_type_to_runner` is set in _init_mt_experiment_from_sqa - if subclass != "MultiTypeExperiment": + # `_trial_type_to_runner` is set in _init_experiment_from_sqa for multi-type + if not is_multi_type: experiment._trial_type_to_runner = cast( dict[str, Runner | None], trial_type_to_runner ) experiment.db_id = experiment_sqa.id - # For non-MultiTypeExperiment, populate _metric_to_trial_type - # This is needed because the metrics were added directly to the experiment - # without going through the setters that populate this field. - if subclass != "MultiTypeExperiment": + # Populate _metric_to_trial_type for all experiments + # For multi-type experiments, this was already done in _init_experiment_from_sqa + if not is_multi_type: default_trial_type = Keys.DEFAULT_TRIAL_TYPE.value # Add OC metrics oc = experiment.optimization_config diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index cba1364ee7d..636ec80425f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -29,7 +29,6 @@ from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -213,7 +212,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: ] auxiliary_experiments_by_purpose[aux_exp_type] = aux_exp_jsons runners = [] - if isinstance(experiment, MultiTypeExperiment): + if experiment.is_multi_type: experiment._properties[Keys.SUBCLASS] = "MultiTypeExperiment" for trial_type, runner in experiment._trial_type_to_runner.items(): runner_sqa = self.runner_to_sqa(none_throws(runner), trial_type) @@ -221,10 +220,6 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: for metric in tracking_metrics: metric.trial_type = experiment._metric_to_trial_type[metric.name] - if metric.name in experiment._metric_to_canonical_name: - metric.canonical_name = experiment._metric_to_canonical_name[ - metric.name - ] elif experiment.runner: runners.append( self.runner_to_sqa( diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index d0a9df77ce4..228517a42da 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -903,8 +903,6 @@ def test_mt_experiment_save_and_load(self) -> None: # pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`. self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1") self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2") - # pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`. - self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1") self.assertEqual(len(loaded_experiment.trials), 2) def test_mt_experiment_save_and_load_skip_runners_and_metrics(self) -> None: @@ -919,8 +917,6 @@ def test_mt_experiment_save_and_load_skip_runners_and_metrics(self) -> None: # pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`. self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1") self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2") - # pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`. - self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1") self.assertEqual(len(loaded_experiment.trials), 2) def test_experiment_new_trial(self) -> None: diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 62b36197304..7b1d51cf17e 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -34,7 +34,6 @@ from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -609,13 +608,13 @@ def get_test_map_data_experiment( def get_multi_type_experiment( add_trial_type: bool = True, add_trials: bool = False, num_arms: int = 10 -) -> MultiTypeExperiment: +) -> Experiment: oc = OptimizationConfig(Objective(BraninMetric("m1", ["x1", "x2"]), minimize=True)) - experiment = MultiTypeExperiment( + experiment = Experiment( name="test_exp", search_space=get_branin_search_space(), default_trial_type="type1", - default_runner=SyntheticRunner(dummy_metadata="dummy1"), + runner=SyntheticRunner(dummy_metadata="dummy1"), optimization_config=oc, status_quo=Arm(parameters={"x1": 0.0, "x2": 0.0}), ) @@ -625,7 +624,8 @@ def get_multi_type_experiment( ) # Switch the order of variables so metric gives different results experiment.add_tracking_metric( - BraninMetric("m2", ["x2", "x1"]), trial_type="type2", canonical_name="m1" + BraninMetric("m2", ["x2", "x1"]), + trial_type="type2", ) if add_trials and add_trial_type: