diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index ee201603067..77b8160cfae 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -98,16 +98,15 @@ def __init__( self._ttl_seconds: int | None = ttl_seconds self._index: int = self._experiment._attach_trial(self, index=index) - trial_type = ( + self._trial_type: str = ( trial_type if trial_type is not None else self._experiment.default_trial_type ) - if not self._experiment.supports_trial_type(trial_type): + if not self._experiment.supports_trial_type(self._trial_type): raise ValueError( - f"Trial type {trial_type} is not supported by the experiment." + f"Trial type {self._trial_type} is not supported by the experiment." ) - self._trial_type = trial_type self.__status: TrialStatus | None = None # Uses `_status` setter, which updates trial statuses to trial indices @@ -285,7 +284,7 @@ def stop_metadata(self) -> dict[str, Any]: return self._stop_metadata @property - def trial_type(self) -> str | None: + def trial_type(self) -> str: """The type of the trial. Relevant for experiments containing different kinds of trials diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 071fe83f802..8edcbba241d 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -100,7 +100,7 @@ def __init__( default_data_type: Any = None, auxiliary_experiments_by_purpose: None | (dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]) = None, - default_trial_type: str | None = None, + default_trial_type: str = Keys.DEFAULT_TRIAL_TYPE.value, ) -> None: """Inits Experiment. @@ -123,6 +123,8 @@ def __init__( default_data_type: Deprecated and ignored. auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for different purposes (e.g., transfer learning). + default_trial_type: Default trial type for trials on this experiment. + Defaults to Keys.DEFAULT_TRIAL_TYPE. """ if default_data_type is not None: warnings.warn( @@ -150,10 +152,16 @@ def __init__( self._properties: dict[str, Any] = properties or {} # Initialize trial type to runner mapping - self._default_trial_type = default_trial_type - self._trial_type_to_runner: dict[str | None, Runner | None] = { - default_trial_type: runner + self._default_trial_type: str = ( + default_trial_type or Keys.DEFAULT_TRIAL_TYPE.value + ) + self._trial_type_to_runner: dict[str, Runner | None] = { + self._default_trial_type: runner } + + # Maps metric names to their trial types. Every metric must have an entry. + self._metric_to_trial_type: dict[str, str] = {} + # Used to keep track of whether any trials on the experiment # specify a TTL. Since trials need to be checked for their TTL's # expiration often, having this attribute helps avoid unnecessary @@ -413,16 +421,46 @@ def runner(self) -> Runner | None: def runner(self, runner: Runner | None) -> None: """Set the default runner and update trial type mapping.""" self._runner = runner - if runner is not None: - self._trial_type_to_runner[self._default_trial_type] = runner - else: - self._trial_type_to_runner = {None: None} + self._trial_type_to_runner[self._default_trial_type] = runner @runner.deleter def runner(self) -> None: """Delete the runner.""" self._runner = None - self._trial_type_to_runner = {None: None} + self._trial_type_to_runner[self._default_trial_type] = None + + def add_trial_type(self, trial_type: str, runner: Runner) -> "Experiment": + """Add a new trial_type to be supported by this experiment. + + Args: + trial_type: The new trial_type to be added. + runner: The default runner for trials of this type. 
+ + Returns: + The experiment with the new trial type added. + """ + if self.supports_trial_type(trial_type): + raise ValueError(f"Experiment already contains trial_type `{trial_type}`") + + self._trial_type_to_runner[trial_type] = runner + return self + + def update_runner(self, trial_type: str, runner: Runner) -> "Experiment": + """Update the default runner for an existing trial_type. + + Args: + trial_type: The trial_type to update. + runner: The new runner for trials of this type. + + Returns: + The experiment with the updated runner. + """ + if not self.supports_trial_type(trial_type): + raise ValueError(f"Experiment does not contain trial_type `{trial_type}`") + + self._trial_type_to_runner[trial_type] = runner + self._runner = runner + return self @property def parameters(self) -> dict[str, Parameter]: @@ -489,13 +527,25 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: f"`{Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF.value}` " "property that is set to `True` on this experiment." ) + + # Remove old OC metrics from trial type mapping + prev_optimization_config = self._optimization_config + if prev_optimization_config is not None: + for metric_name in prev_optimization_config.metrics.keys(): + self._metric_to_trial_type.pop(metric_name, None) + for metric_name in optimization_config.metrics.keys(): if metric_name in self._tracking_metrics: self.remove_tracking_metric(metric_name) + # add metrics from the previous optimization config that are not in the new # optimization config as tracking metrics - prev_optimization_config = self._optimization_config self._optimization_config = optimization_config + + # Map new OC metrics to default trial type + for metric_name in optimization_config.metrics.keys(): + self._metric_to_trial_type[metric_name] = self._default_trial_type + if prev_optimization_config is not None: metrics_to_track = ( set(prev_optimization_config.metrics.keys()) @@ -505,6 +555,16 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: for metric_name in metrics_to_track: self.add_tracking_metric(prev_optimization_config.metrics[metric_name]) + # Clean up any stale entries in _metric_to_trial_type that don't correspond + # to actual metrics (can happen when same optimization_config object is + # mutated and reassigned). + current_metric_names = set(self.metrics.keys()) + stale_metric_names = ( + set(self._metric_to_trial_type.keys()) - current_metric_names + ) + for metric_name in stale_metric_names: + self._metric_to_trial_type.pop(metric_name, None) + @property def is_moo_problem(self) -> bool: """Whether the experiment's optimization config contains multiple objectives.""" @@ -553,12 +613,25 @@ def immutable_search_space_and_opt_config(self) -> bool: def tracking_metrics(self) -> list[Metric]: return list(self._tracking_metrics.values()) - def add_tracking_metric(self, metric: Metric) -> Self: + def add_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Self: """Add a new metric to the experiment. Args: metric: Metric to be added. + trial_type: The trial type for which this metric is used. If not + provided, defaults to the experiment's default trial type. 
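+
+        Example (illustrative metric and trial-type names; assumes the trial
+        type was previously registered via ``add_trial_type``):
+            >>> exp.add_tracking_metric(Metric(name="offline_auc"), trial_type="offline")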
""" + effective_trial_type = ( + trial_type if trial_type is not None else self._default_trial_type + ) + + if not self.supports_trial_type(effective_trial_type): + raise ValueError(f"`{effective_trial_type}` is not a supported trial type.") + if metric.name in self._tracking_metrics: raise ValueError( f"Metric `{metric.name}` already defined on experiment. " @@ -574,9 +647,14 @@ def add_tracking_metric(self, metric: Metric) -> Self: ) self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = effective_trial_type return self - def add_tracking_metrics(self, metrics: list[Metric]) -> Self: + def add_tracking_metrics( + self, + metrics: list[Metric], + metrics_to_trial_types: dict[str, str] | None = None, + ) -> Self: """Add a list of new metrics to the experiment. If any of the metrics are already defined on the experiment, @@ -584,23 +662,58 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Self: Args: metrics: Metrics to be added. + metrics_to_trial_types: Optional mapping from metric names to + corresponding trial types. If not provided for a metric, + the experiment's default trial type is used. """ - # Before setting any metrics, we validate none are already on - # the experiment + metrics_to_trial_types = metrics_to_trial_types or {} for metric in metrics: - self.add_tracking_metric(metric) + self.add_tracking_metric( + metric=metric, + trial_type=metrics_to_trial_types.get(metric.name), + ) return self - def update_tracking_metric(self, metric: Metric) -> Self: + def update_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Self: """Redefine a metric that already exists on the experiment. Args: metric: New metric definition. + trial_type: The trial type for which this metric is used. If not + provided, keeps the existing trial type mapping. """ if metric.name not in self._tracking_metrics: raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.") + # Validate trial type if provided + effective_trial_type = ( + trial_type + if trial_type is not None + else self._metric_to_trial_type.get(metric.name, self._default_trial_type) + ) + + # Check that optimization config metrics stay on default trial type + oc = self.optimization_config + oc_metrics = oc.metrics if oc else {} + if ( + metric.name in oc_metrics + and effective_trial_type != self._default_trial_type + ): + raise ValueError( + f"Metric `{metric.name}` must remain a " + f"`{self._default_trial_type}` metric because it is part of the " + "optimization_config." + ) + + if not self.supports_trial_type(effective_trial_type): + raise ValueError(f"`{effective_trial_type}` is not a supported trial type.") + self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = effective_trial_type return self def remove_tracking_metric(self, metric_name: str) -> Self: @@ -613,6 +726,7 @@ def remove_tracking_metric(self, metric_name: str) -> Self: raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") del self._tracking_metrics[metric_name] + self._metric_to_trial_type.pop(metric_name, None) return self @property @@ -852,8 +966,21 @@ def _fetch_trial_data( ) -> dict[str, MetricFetchResult]: trial = self.trials[trial_index] + # If metrics are not provided, fetch all metrics on the experiment for the + # relevant trial type, or the default trial type as a fallback. Otherwise, + # fetch provided metrics. 
+ if metrics is None: + resolved_metrics = [ + metric + for metric in list(self.metrics.values()) + if self._metric_to_trial_type.get(metric.name, self._default_trial_type) + == trial.trial_type + ] + else: + resolved_metrics = metrics + trial_data = self._lookup_or_fetch_trials_results( - trials=[trial], metrics=metrics, **kwargs + trials=[trial], metrics=resolved_metrics, **kwargs ) if trial_index in trial_data: @@ -1098,6 +1225,15 @@ def trial_indices_with_data( return trials_with_data + @property + def default_trials(self) -> set[int]: + """Return the indicies for trials of the default type.""" + return { + idx + for idx, trial in self.trials.items() + if trial.trial_type == self.default_trial_type + } + def new_trial( self, generator_run: GeneratorRun | None = None, @@ -1548,39 +1684,79 @@ def __repr__(self) -> str: # overridden in the MultiTypeExperiment class. @property - def default_trial_type(self) -> str | None: - """Default trial type assigned to trials in this experiment. - - In the base experiment class this is always None. For experiments - with multiple trial types, use the MultiTypeExperiment class. - """ + def default_trial_type(self) -> str: + """Default trial type assigned to trials in this experiment.""" return self._default_trial_type - def runner_for_trial_type(self, trial_type: str | None) -> Runner | None: + def runner_for_trial_type(self, trial_type: str) -> Runner | None: """The default runner to use for a given trial type. Looks up the appropriate runner for this trial type in the trial_type_to_runner. """ + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. + if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return self._trial_type_to_runner[Keys.DEFAULT_TRIAL_TYPE] + if not self.supports_trial_type(trial_type): raise ValueError(f"Trial type `{trial_type}` is not supported.") if (runner := self._trial_type_to_runner.get(trial_type)) is None: return self.runner # return the default runner return runner - def supports_trial_type(self, trial_type: str | None) -> bool: + def supports_trial_type(self, trial_type: str) -> bool: """Whether this experiment allows trials of the given type. - The base experiment class only supports None. For experiments - with multiple trial types, use the MultiTypeExperiment class. + Checks if the trial type is registered in the trial_type_to_runner mapping. """ - return ( - trial_type is None - # We temporarily allow "short run" and "long run" trial - # types in single-type experiments during development of - # a new ``GenerationStrategy`` that needs them. - or trial_type == Keys.SHORT_RUN - or trial_type == Keys.LONG_RUN - ) + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. 
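+        # In effect, Keys.SHORT_RUN and Keys.LONG_RUN are reported as supported
+        # whenever the default trial type is registered, even if they have no
+        # entry of their own in _trial_type_to_runner.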
+ if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return True + + return trial_type in self._trial_type_to_runner + + @property + def is_multi_type(self) -> bool: + """Returns True if this experiment has multiple trial types registered.""" + return len(self._trial_type_to_runner) > 1 + + @property + def metric_to_trial_type(self) -> dict[str, str]: + """Read-only mapping of metric names to trial types.""" + return self._metric_to_trial_type.copy() + + def metrics_for_trial_type(self, trial_type: str) -> list[Metric]: + """Returns metrics associated with a specific trial type. + + Args: + trial_type: The trial type to get metrics for. + + Returns: + List of metrics associated with the given trial type. + """ + # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default" + # trial types for deployment. + if ( + trial_type == Keys.SHORT_RUN or trial_type == Keys.LONG_RUN + ) and self.supports_trial_type(trial_type=Keys.DEFAULT_TRIAL_TYPE): + return [ + self.metrics[metric_name] + for metric_name, metric_trial_type in self._metric_to_trial_type.items() + if metric_trial_type == Keys.DEFAULT_TRIAL_TYPE + ] + + if not self.supports_trial_type(trial_type): + raise ValueError(f"Trial type `{trial_type}` is not supported.") + return [ + self.metrics[metric_name] + for metric_name, metric_trial_type in self._metric_to_trial_type.items() + if metric_trial_type == trial_type + ] def attach_trial( self, diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index b6c552d2bcb..754a0d626ab 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -6,334 +6,27 @@ # pyre-strict -from collections.abc import Iterable, Sequence -from typing import Any, Self +from collections.abc import Sequence -from ax.core.arm import Arm -from ax.core.base_trial import BaseTrial, TrialStatus -from ax.core.data import Data +from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment -from ax.core.metric import Metric, MetricFetchResult -from ax.core.optimization_config import OptimizationConfig -from ax.core.runner import Runner -from ax.core.search_space import SearchSpace -from ax.utils.common.docutils import copy_doc -from pyre_extensions import none_throws +from ax.core.trial_status import TrialStatus class MultiTypeExperiment(Experiment): """Class for experiment with multiple trial types. - A canonical use case for this is tuning a large production system - with limited evaluation budget and a simulator which approximates - evaluations on the main system. Trial deployment and data fetching - is separate for the two systems, but the final data is combined and - fed into multi-task models. + .. deprecated:: + The `MultiTypeExperiment` class is deprecated. Use `Experiment` with + `default_trial_type` parameter instead. All multi-type experiment + functionality has been moved to the base `Experiment` class. - See the Multi-Task Modeling tutorial for more details. - - Attributes: - name: Name of the experiment. - description: Description of the experiment. 
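+
+    Example migration (illustrative arguments):
+        ``MultiTypeExperiment(..., default_trial_type="online", default_runner=r)``
+        becomes
+        ``Experiment(..., default_trial_type="online", runner=r)``.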
""" - def __init__( - self, - name: str, - search_space: SearchSpace, - default_trial_type: str, - default_runner: Runner | None, - optimization_config: OptimizationConfig | None = None, - tracking_metrics: list[Metric] | None = None, - status_quo: Arm | None = None, - description: str | None = None, - is_test: bool = False, - experiment_type: str | None = None, - properties: dict[str, Any] | None = None, - default_data_type: Any = None, - ) -> None: - """Inits Experiment. - - Args: - name: Name of the experiment. - search_space: Search space of the experiment. - default_trial_type: Default type for trials on this experiment. - default_runner: Default runner for trials of the default type. - optimization_config: Optimization config of the experiment. - tracking_metrics: Additional tracking metrics not used for optimization. - These are associated with the default trial type. - runner: Default runner used for trials on this experiment. - status_quo: Arm representing existing "control" arm. - description: Description of the experiment. - is_test: Convenience metadata tracker for the user to mark test experiments. - experiment_type: The class of experiments this one belongs to. - properties: Dictionary of this experiment's properties. - default_data_type: Deprecated and ignored. - """ - - # Specifies which trial type each metric belongs to - self._metric_to_trial_type: dict[str, str] = {} - - # Maps certain metric names to a canonical name. Useful for ancillary trial - # types' metrics, to specify which primary metrics they correspond to - # (e.g. 'comment_prediction' => 'comment') - self._metric_to_canonical_name: dict[str, str] = {} - - # call super.__init__() after defining fields above, because we need - # them to be populated before optimization config is set - super().__init__( - name=name, - search_space=search_space, - optimization_config=optimization_config, - status_quo=status_quo, - description=description, - is_test=is_test, - experiment_type=experiment_type, - properties=properties, - tracking_metrics=tracking_metrics, - runner=default_runner, - default_trial_type=default_trial_type, - default_data_type=default_data_type, - ) - - def add_trial_type(self, trial_type: str, runner: Runner) -> Self: - """Add a new trial_type to be supported by this experiment. - - Args: - trial_type: The new trial_type to be added. - runner: The default runner for trials of this type. - """ - if self.supports_trial_type(trial_type): - raise ValueError(f"Experiment already contains trial_type `{trial_type}`") - - self._trial_type_to_runner[trial_type] = runner - return self - - # pyre-fixme [56]: Pyre was not able to infer the type of the decorator - # `Experiment.optimization_config.setter`. - @Experiment.optimization_config.setter - def optimization_config(self, optimization_config: OptimizationConfig) -> None: - # pyre-fixme [16]: `Optional` has no attribute `fset`. - Experiment.optimization_config.fset(self, optimization_config) - for metric_name in optimization_config.metrics.keys(): - # Optimization config metrics are required to be the default trial type - # currently. TODO: remove that restriction (T202797235) - self._metric_to_trial_type[metric_name] = none_throws( - self.default_trial_type - ) - - def update_runner(self, trial_type: str, runner: Runner) -> Self: - """Update the default runner for an existing trial_type. - - Args: - trial_type: The new trial_type to be added. - runner: The new runner for trials of this type. 
- """ - if not self.supports_trial_type(trial_type): - raise ValueError(f"Experiment does not contain trial_type `{trial_type}`") - - self._trial_type_to_runner[trial_type] = runner - self._runner = runner - return self - - def add_tracking_metric( - self, - metric: Metric, - trial_type: str | None = None, - canonical_name: str | None = None, - ) -> Self: - """Add a new metric to the experiment. - - Args: - metric: The metric to add. - trial_type: The trial type for which this metric is used. - canonical_name: The default metric for which this metric is a proxy. - """ - if trial_type is None: - trial_type = self._default_trial_type - if not self.supports_trial_type(trial_type): - raise ValueError(f"`{trial_type}` is not a supported trial type.") - - super().add_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = none_throws(trial_type) - if canonical_name is not None: - self._metric_to_canonical_name[metric.name] = canonical_name - return self - - def add_tracking_metrics( - self, - metrics: list[Metric], - metrics_to_trial_types: dict[str, str] | None = None, - canonical_names: dict[str, str] | None = None, - ) -> Experiment: - """Add a list of new metrics to the experiment. - - If any of the metrics are already defined on the experiment, - we raise an error and don't add any of them to the experiment - - Args: - metrics: Metrics to be added. - metrics_to_trial_types: The mapping from metric names to corresponding - trial types for each metric. If provided, the metrics will be - added to their trial types. If not provided, then the default - trial type will be used. - canonical_names: A mapping of metric names to their - canonical names(The default metrics for which the metrics are - proxies.) - - Returns: - The experiment with the added metrics. - """ - metrics_to_trial_types = metrics_to_trial_types or {} - canonical_name = None - for metric in metrics: - if canonical_names is not None: - canonical_name = none_throws(canonical_names).get(metric.name, None) - - self.add_tracking_metric( - metric=metric, - trial_type=metrics_to_trial_types.get( - metric.name, self._default_trial_type - ), - canonical_name=canonical_name, - ) - return self - - def update_tracking_metric( - self, - metric: Metric, - trial_type: str | None = None, - canonical_name: str | None = None, - ) -> Self: - """Update an existing metric on the experiment. - - Args: - metric: The metric to add. - trial_type: The trial type for which this metric is used. Defaults to - the current trial type of the metric (if set), or the default trial - type otherwise. - canonical_name: The default metric for which this metric is a proxy. - """ - # Default to the existing trial type if not specified - if trial_type is None: - trial_type = self._metric_to_trial_type.get( - metric.name, self._default_trial_type - ) - oc = self.optimization_config - oc_metrics = oc.metrics if oc else [] - if metric.name in oc_metrics and trial_type != self._default_trial_type: - raise ValueError( - f"Metric `{metric.name}` must remain a " - f"`{self._default_trial_type}` metric because it is part of the " - "optimization_config." 
- ) - elif not self.supports_trial_type(trial_type): - raise ValueError(f"`{trial_type}` is not a supported trial type.") - - super().update_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = none_throws(trial_type) - if canonical_name is not None: - self._metric_to_canonical_name[metric.name] = canonical_name - return self - - @copy_doc(Experiment.remove_tracking_metric) - def remove_tracking_metric(self, metric_name: str) -> Self: - if metric_name not in self._tracking_metrics: - raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") - - # Required fields - del self._tracking_metrics[metric_name] - del self._metric_to_trial_type[metric_name] - - # Optional - if metric_name in self._metric_to_canonical_name: - del self._metric_to_canonical_name[metric_name] - return self - - @copy_doc(Experiment.fetch_data) - def fetch_data( - self, - trial_indices: Iterable[int] | None = None, - metrics: list[Metric] | None = None, - **kwargs: Any, - ) -> Data: - # TODO: make this more efficient for fetching - # data for multiple trials of the same type - # by overriding Experiment._lookup_or_fetch_trials_results - return Data.from_multiple_data( - [ - ( - trial.fetch_data(**kwargs, metrics=metrics) - if trial.status.expecting_data - else Data() - ) - for trial in self.trials.values() - ] - ) - - @copy_doc(Experiment._fetch_trial_data) - def _fetch_trial_data( - self, trial_index: int, metrics: list[Metric] | None = None, **kwargs: Any - ) -> dict[str, MetricFetchResult]: - trial = self.trials[trial_index] - metrics = [ - metric - for metric in (metrics or self.metrics.values()) - if self.metric_to_trial_type[metric.name] == trial.trial_type - ] - # Invoke parent's fetch method using only metrics for this trial_type - return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs) - - @property - def default_trials(self) -> set[int]: - """Return the indicies for trials of the default type.""" - return { - idx - for idx, trial in self.trials.items() - if trial.trial_type == self.default_trial_type - } - - @property - def metric_to_trial_type(self) -> dict[str, str]: - """Map metrics to trial types. - - Adds in default trial type for OC metrics to custom defined trial types.. - """ - opt_config_types = { - metric_name: self.default_trial_type - for metric_name in self.optimization_config.metrics.keys() - } - return {**opt_config_types, **self._metric_to_trial_type} - - # -- Overridden functions from Base Experiment Class -- - @property - def default_trial_type(self) -> str | None: - """Default trial type assigned to trials in this experiment.""" - return self._default_trial_type - - def metrics_for_trial_type(self, trial_type: str) -> list[Metric]: - """The default runner to use for a given trial type. - - Looks up the appropriate runner for this trial type in the trial_type_to_runner. - """ - if not self.supports_trial_type(trial_type): - raise ValueError(f"Trial type `{trial_type}` is not supported.") - return [ - self.metrics[metric_name] - for metric_name, metric_trial_type in self._metric_to_trial_type.items() - if metric_trial_type == trial_type - ] - - def supports_trial_type(self, trial_type: str | None) -> bool: - """Whether this experiment allows trials of the given type. - - Only trial types defined in the trial_type_to_runner are allowed. 
- """ - return trial_type in self._trial_type_to_runner.keys() - def filter_trials_by_type( - trials: Sequence[BaseTrial], trial_type: str | None + trials: Sequence[BaseTrial], + trial_type: str | None, ) -> list[BaseTrial]: """Filter trials by trial type if provided. @@ -352,7 +45,9 @@ def filter_trials_by_type( def get_trial_indices_for_statuses( - experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None + experiment: Experiment, + statuses: set[TrialStatus], + trial_type: str | None = None, ) -> set[int]: """Get trial indices for a set of statuses. diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index 8a2c201f4b9..97901dd3be9 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -410,8 +410,9 @@ def test_clone_to(self, _) -> None: cloned_batch._time_created = batch._time_created self.assertEqual(cloned_batch, batch) # test cloning with clear_trial_type=True + # When clear_trial_type=True, uses experiment's default_trial_type cloned_batch = batch.clone_to(clear_trial_type=True) - self.assertIsNone(cloned_batch.trial_type) + self.assertEqual(cloned_batch.trial_type, self.experiment.default_trial_type) self.assertEqual( cloned_batch.generation_method_str, f"{MANUAL_GENERATION_METHOD_STR}, {STATUS_QUO_GENERATION_METHOD_STR}", diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 9903fbeb578..c87ec407bb5 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -1417,7 +1417,7 @@ def test_clone_with(self) -> None: cloned_experiment._time_created = experiment._time_created self.assertEqual(cloned_experiment, experiment) - # test clear_trial_type + # test clear_trial_type - uses experiment's default_trial_type experiment = get_branin_experiment( with_batch=True, num_batch_trial=1, @@ -1427,7 +1427,10 @@ def test_clone_with(self) -> None: with self.assertRaisesRegex(ValueError, ".* foo is not supported by the exp"): experiment.clone_with() cloned_experiment = experiment.clone_with(clear_trial_type=True) - self.assertIsNone(cloned_experiment.trials[0].trial_type) + self.assertEqual( + cloned_experiment.trials[0].trial_type, + cloned_experiment.default_trial_type, + ) # Test cloning with specific properties to keep experiment_w_props = get_branin_experiment() diff --git a/ax/core/tests/test_multi_type_experiment.py b/ax/core/tests/test_multi_type_experiment.py index d8a28e60ac2..e18b800985e 100644 --- a/ax/core/tests/test_multi_type_experiment.py +++ b/ax/core/tests/test_multi_type_experiment.py @@ -48,18 +48,8 @@ def test_MTExperimentFlow(self) -> None: self.assertEqual(b2.run_metadata["dummy_metadata"], "dummy3") df = self.experiment.fetch_data().df - for _, row in df.iterrows(): - # Make sure proper metric present for each batch only - self.assertEqual( - row["metric_name"], "m1" if row["trial_index"] == 0 else "m2" - ) - arm_0_slice = df.loc[df["arm_name"] == "0_0"] - self.assertNotEqual( - float(arm_0_slice[df["trial_index"] == 0]["mean"].item()), - float(arm_0_slice[df["trial_index"] == 1]["mean"].item()), - ) - self.assertEqual(len(df), 2 * n) + self.assertEqual(len(df), 4 * n) self.assertEqual(self.experiment.default_trials, {0}) # Set 2 metrics to be equal self.experiment.update_tracking_metric( @@ -68,13 +58,21 @@ def test_MTExperimentFlow(self) -> None: df = self.experiment.fetch_data().df arm_0_slice = df.loc[df["arm_name"] == "0_0"] self.assertAlmostEqual( - float(arm_0_slice[df["trial_index"] == 0]["mean"].item()), - 
float(arm_0_slice[df["trial_index"] == 1]["mean"].item()), + float( + arm_0_slice[(df["trial_index"] == 0) & (df["metric_name"] == "m2")][ + "mean" + ].item() + ), + float( + arm_0_slice[(df["trial_index"] == 1) & (df["metric_name"] == "m2")][ + "mean" + ].item() + ), places=10, ) def test_Repr(self) -> None: - self.assertEqual(str(self.experiment), "MultiTypeExperiment(test_exp)") + self.assertEqual(str(self.experiment), "Experiment(test_exp)") def test_Eq(self) -> None: exp2 = get_multi_type_experiment() @@ -83,24 +81,19 @@ def test_Eq(self) -> None: self.assertTrue(self.experiment == exp2) self.experiment.add_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m4" + BraninMetric("m3", ["x2", "x1"]), + trial_type="type1", ) # Test different set of metrics self.assertFalse(self.experiment == exp2) exp2.add_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m5" - ) - - # Test different metric definitions - self.assertFalse(self.experiment == exp2) - - exp2.update_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m4" + BraninMetric("m3", ["x2", "x1"]), + trial_type="type1", ) - # Should be the same + # Both have the same metrics now, should be equal self.assertTrue(self.experiment == exp2) exp2.remove_tracking_metric("m3") diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index c18ccf25deb..0817cac286a 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -268,6 +268,8 @@ def test_ObservationsFromData(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) @@ -525,6 +527,8 @@ def test_ObservationsFromDataAbandoned(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: ( Trial(experiment, GeneratorRun(arms=[arms[obs["arm_name"]]])) @@ -637,6 +641,8 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) @@ -744,6 +750,8 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial( } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { 0: BatchTrial(experiment, GeneratorRun(arms=list(arms_by_name.values()))) } @@ -885,6 +893,8 @@ def test_ObservationsWithCandidateMetadata(self) -> None: } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} + experiment.default_trial_type = "default" + experiment.supports_trial_type = Mock(return_value=True) trials = { obs["trial_index"]: Trial( experiment, diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 33eef439201..0fa7d3aecfe 100644 --- a/ax/core/tests/test_trial.py +++ 
b/ax/core/tests/test_trial.py @@ -460,9 +460,9 @@ def test_clone_to(self) -> None: # check that trial_type is cloned correctly self.assertEqual(new_trial.trial_type, "foo") - # test clear_trial_type + # test clear_trial_type - uses experiment's default_trial_type new_trial = self.trial.clone_to(clear_trial_type=True) - self.assertIsNone(new_trial.trial_type) + self.assertEqual(new_trial.trial_type, new_experiment.default_trial_type) def test_update_trial_status_on_clone(self) -> None: for status in [ diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index 8d4804893b7..a7a3401c215 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -27,7 +27,6 @@ from ax.core.multi_type_experiment import ( filter_trials_by_type, get_trial_indices_for_statuses, - MultiTypeExperiment, ) from ax.core.runner import Runner from ax.core.trial import Trial @@ -58,7 +57,7 @@ set_ax_logger_levels, ) from ax.utils.common.timeutils import current_timestamp_in_millis -from pyre_extensions import assert_is_instance, none_throws +from pyre_extensions import none_throws NOT_IMPLEMENTED_IN_BASE_CLASS_MSG = """ \ @@ -364,18 +363,14 @@ def options(self, options: OrchestratorOptions) -> None: self._validate_runner_and_implemented_metrics(experiment=self.experiment) @property - def trial_type(self) -> str | None: + def trial_type(self) -> str: """Trial type for the experiment this Orchestrator is running. - This returns None if the experiment is not a MultitypeExperiment - Returns: - Trial type for the experiment this Orchestrator is running if the - experiment is a MultiTypeExperiment and None otherwise. + Trial type for the experiment this Orchestrator is running. + Defaults to Keys.DEFAULT_TRIAL_TYPE if not specified. """ - if isinstance(self.experiment, MultiTypeExperiment): - return self.options.mt_experiment_trial_type - return None + return self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value @property def running_trials(self) -> list[BaseTrial]: @@ -487,12 +482,8 @@ def runner(self) -> Runner: """``Runner`` specified on the experiment associated with this ``Orchestrator`` instance. """ - if self.trial_type is not None: - runner = assert_is_instance( - self.experiment, MultiTypeExperiment - ).runner_for_trial_type(trial_type=none_throws(self.trial_type)) - else: - runner = self.experiment.runner + runner = self.experiment.runner_for_trial_type(trial_type=self.trial_type) + if runner is None: raise UnsupportedError( "`Orchestrator` requires that experiment specifies a `Runner`." @@ -1621,11 +1612,7 @@ def _validate_options(self, options: OrchestratorOptions) -> None: "will be unable to fetch intermediate results with which to " "evaluate early stopping criteria." ) - if isinstance(self.experiment, MultiTypeExperiment): - if options.mt_experiment_trial_type is None: - raise UserInputError( - "Must specify `mt_experiment_trial_type` for MultiTypeExperiment." - ) + if options.mt_experiment_trial_type is not None: if not self.experiment.supports_trial_type( options.mt_experiment_trial_type ): @@ -1633,11 +1620,6 @@ def _validate_options(self, options: OrchestratorOptions) -> None: "Experiment does not support trial type " f"{options.mt_experiment_trial_type}." ) - elif options.mt_experiment_trial_type is not None: - raise UserInputError( - "`mt_experiment_trial_type` must be None unless the experiment is a " - "MultiTypeExperiment." 
- ) def _get_max_pending_trials(self) -> int: """Returns the maximum number of pending trials specified in the options, or @@ -2034,11 +2016,11 @@ def _fetch_and_process_trials_data_results( try: kwargs = deepcopy(self.options.fetch_kwargs) - if self.trial_type is not None: - metrics = assert_is_instance( - self.experiment, MultiTypeExperiment - ).metrics_for_trial_type(trial_type=none_throws(self.trial_type)) - kwargs["metrics"] = metrics + metrics = self.experiment.metrics_for_trial_type( + trial_type=none_throws(self.trial_type) + ) + kwargs["metrics"] = metrics + results = self.experiment.fetch_trials_data_results( trial_indices=trial_indices, **kwargs, diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 67f982ca1a1..7d04acb154f 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -1315,9 +1315,9 @@ def _should_stop_trials_early( len(res_list[1]["trials_early_stopped_so_far"]), ) - looked_up_data = orchestrator.experiment.lookup_data() fetched_data = orchestrator.experiment.fetch_data() - num_metrics = 2 + looked_up_data = orchestrator.experiment.lookup_data() + num_metrics = len(orchestrator.experiment.metrics) expected_num_rows = num_metrics * total_trials # There are 3 trials and two metrics for "type1" for MT experiments self.assertEqual(len(looked_up_data.df), expected_num_rows) @@ -1329,7 +1329,7 @@ def _should_stop_trials_early( # longer and gets results for an extra timestamp. # For MultiTypeExperiment there are two metrics # for trial type "type1" - expected_num_rows = 7 + expected_num_rows = expected_num_rows + 1 self.assertEqual(len(looked_up_data.full_df), expected_num_rows) self.assertEqual(len(fetched_data.full_df), expected_num_rows) ess = orchestrator.options.early_stopping_strategy @@ -2736,12 +2736,9 @@ def test_validate_options_not_none_mt_trial_type( self, msg: str | None = None ) -> None: # test that error is raised if `mt_experiment_trial_type` is not - # compatible with the type of experiment (single or multi-type) + # a supported trial type for this experiment if msg is None: - msg = ( - "`mt_experiment_trial_type` must be None unless the experiment is a " - "MultiTypeExperiment." - ) + msg = "Experiment does not support trial type type1." options = OrchestratorOptions( init_seconds_between_polls=0, # No wait bw polls so test is fast. batch_size=10, @@ -2752,7 +2749,7 @@ def test_validate_options_not_none_mt_trial_type( ), ) gs = self.two_sobol_steps_GS - with self.assertRaisesRegex(UserInputError, msg): + with self.assertRaisesRegex(ValueError, msg): Orchestrator( experiment=self.branin_experiment, generation_strategy=gs, @@ -2845,7 +2842,7 @@ def test_terminate_if_status_quo_infeasible(self) -> None: class TestAxOrchestratorMultiTypeExperiment(TestAxOrchestrator): # After D80128678, choose_generation_strategy_legacy returns node-based GS. 
EXPECTED_orchestrator_REPR: str = ( - "Orchestrator(experiment=MultiTypeExperiment(branin_test_experiment), " + "Orchestrator(experiment=Experiment(branin_test_experiment), " "generation_strategy=GenerationStrategy(" "name='GenerationStep_0_Sobol+GenerationStep_1_BoTorch', " "nodes=[GenerationNode(name='GenerationStep_0_Sobol', " @@ -2897,13 +2894,13 @@ def setUp(self) -> None: trial_type="type1", runner=RunnerToAllowMultipleMapMetricFetches() ) - self.branin_experiment_no_impl_runner_or_metrics = MultiTypeExperiment( + self.branin_experiment_no_impl_runner_or_metrics = Experiment( search_space=get_branin_search_space(), optimization_config=OptimizationConfig( Objective(Metric(name="branin"), minimize=True) ), default_trial_type="type1", - default_runner=None, + runner=None, name="branin_experiment_no_impl_runner_or_metrics", ) self.sobol_MBM_GS = choose_generation_strategy_legacy( @@ -3010,10 +3007,11 @@ def test_fetch_and_process_trials_data_results_failed_non_objective( def test_validate_options_not_none_mt_trial_type( self, msg: str | None = None ) -> None: - # test if a MultiTypeExperiment with `mt_experiment_trial_type=None` - self.orchestrator_options_kwargs["mt_experiment_trial_type"] = None + # test that error is raised if `mt_experiment_trial_type` is not + # a supported trial type for this experiment (using an invalid type) + self.orchestrator_options_kwargs["mt_experiment_trial_type"] = "invalid_type" super().test_validate_options_not_none_mt_trial_type( - msg="Must specify `mt_experiment_trial_type` for MultiTypeExperiment." + msg="Experiment does not support trial type invalid_type." ) def test_run_n_trials_single_step_existing_experiment( diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 592c316f037..fd1d4794886 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -24,7 +24,6 @@ from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective from ax.core.observation import ObservationFeatures from ax.core.runner import Runner @@ -458,15 +457,11 @@ def add_tracking_metrics( for metric_name in metric_names ] - if isinstance(self.experiment, MultiTypeExperiment): - experiment = assert_is_instance(self.experiment, MultiTypeExperiment) - experiment.add_tracking_metrics( - metrics=metric_objects, - metrics_to_trial_types=metrics_to_trial_types, - canonical_names=canonical_names, - ) - else: - self.experiment.add_tracking_metrics(metrics=metric_objects) + self.experiment.add_tracking_metrics( + metrics=metric_objects, + metrics_to_trial_types=metrics_to_trial_types, + **({"canonical_names": canonical_names} if canonical_names else {}), + ) @copy_doc(Experiment.remove_tracking_metric) def remove_tracking_metric(self, metric_name: str) -> None: diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index afe81f05d01..836a5b08bc5 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -21,9 +21,9 @@ from ax.adapter.registry import Cont_X_trans, Generators from ax.core.arm import Arm from ax.core.data import Data, MAP_KEY +from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.optimization_config import MultiObjectiveOptimizationConfig from 
ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint from ax.core.parameter import ( @@ -974,8 +974,8 @@ def test_create_multitype_experiment(self) -> None: default_runner=SyntheticRunner(), ) - self.assertEqual(ax_client.experiment.__class__.__name__, "MultiTypeExperiment") - experiment = assert_is_instance(ax_client.experiment, MultiTypeExperiment) + self.assertEqual(ax_client.experiment.__class__.__name__, "Experiment") + experiment = assert_is_instance(ax_client.experiment, Experiment) self.assertEqual( experiment._trial_type_to_runner["test_trial_type"].__class__.__name__, "SyntheticRunner", diff --git a/ax/service/tests/test_instantiation_utils.py b/ax/service/tests/test_instantiation_utils.py index 160518ec457..c75c3ca3f5c 100644 --- a/ax/service/tests/test_instantiation_utils.py +++ b/ax/service/tests/test_instantiation_utils.py @@ -372,7 +372,7 @@ def test_make_multitype_experiment_with_default_trial_type(self) -> None: default_trial_type="test_trial_type", default_runner=SyntheticRunner(), ) - self.assertEqual(experiment.__class__.__name__, "MultiTypeExperiment") + self.assertEqual(experiment.__class__.__name__, "Experiment") def test_make_single_type_experiment_with_no_default_trial_type(self) -> None: experiment = InstantiationBase.make_experiment( diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index e857215be50..044c9fca78c 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -199,9 +199,9 @@ def test_exp_to_df_with_failure(self) -> None: self.assertEqual(f"{fail_reason}...", df["reason"].iloc[0]) def test_exp_to_df(self) -> None: - # MultiTypeExperiment should fail + # Experiments with multiple trial types should fail exp = get_multi_type_experiment() - with self.assertRaisesRegex(ValueError, "MultiTypeExperiment"): + with self.assertRaisesRegex(ValueError, "multiple trial types"): exp_to_df(exp=exp) # exp with no trials should return empty results diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index d08849d6f0d..f00bbc4f57f 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -17,7 +17,6 @@ from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.experiment import Experiment from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective from ax.core.observation import ObservationFeatures from ax.core.optimization_config import ( @@ -850,12 +849,9 @@ def make_experiment( auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for different use cases (e.g., transfer learning). default_trial_type: The default trial type if multiple - trial types are intended to be used in the experiment. If specified, - a MultiTypeExperiment will be created. Otherwise, a single-type - Experiment will be created. + trial types are intended to be used in the experiment. default_runner: The default runner in this experiment. - This only applies to MultiTypeExperiment (when default_trial_type - is specified). + This is required if default_trial_type is specified. is_test: Whether this experiment will be a test experiment (useful for marking test experiments in storage etc). Defaults to False. 
@@ -863,7 +859,7 @@ def make_experiment( if (default_trial_type is None) != (default_runner is None): raise ValueError( "Must specify both default_trial_type and default_runner if " - "using a MultiTypeExperiment." + "using multiple trial types." ) status_quo_arm = None if status_quo is None else Arm(parameters=status_quo) @@ -900,23 +896,6 @@ def make_experiment( if owners is not None: properties["owners"] = owners - if default_trial_type is not None: - return MultiTypeExperiment( - name=none_throws(name), - search_space=cls.make_search_space( - parameters=parameters, parameter_constraints=parameter_constraints - ), - default_trial_type=none_throws(default_trial_type), - default_runner=none_throws(default_runner), - optimization_config=optimization_config, - tracking_metrics=tracking_metrics, - status_quo=status_quo_arm, - description=description, - is_test=is_test, - experiment_type=experiment_type, - properties=properties, - ) - return Experiment( name=name, description=description, @@ -929,6 +908,7 @@ def make_experiment( auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, is_test=is_test, runner=default_runner, + default_trial_type=default_trial_type or Keys.DEFAULT_TRIAL_TYPE.value, ) @classmethod diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 4ddabe4c211..3f434504041 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -32,7 +32,6 @@ from ax.core.generator_run import GeneratorRunType from ax.core.map_metric import MapMetric from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -787,8 +786,11 @@ def exp_to_df( ) # Accept Experiment and SimpleExperiment - if isinstance(exp, MultiTypeExperiment): - raise ValueError("Cannot transform MultiTypeExperiments to DataFrames.") + # Reject experiments with multiple trial types as they need special handling + if len(exp._trial_type_to_runner) > 1: + raise ValueError( + "Cannot transform experiments with multiple trial types to DataFrames." 
+ ) key_components = ["trial_index", "arm_name"] diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..371277667be 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -59,6 +59,7 @@ CORE_DECODER_REGISTRY, ) from ax.storage.utils import data_by_trial_to_data +from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.serialization import ( extract_init_args, @@ -653,49 +654,17 @@ def multi_type_experiment_from_json( object_json: dict[str, Any], decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, -) -> MultiTypeExperiment: - """Load AE MultiTypeExperiment from JSON.""" - experiment_info = _get_experiment_info(object_json) - - _metric_to_canonical_name = object_json.pop("_metric_to_canonical_name") - _metric_to_trial_type = object_json.pop("_metric_to_trial_type") - _trial_type_to_runner = object_from_json( - object_json.pop("_trial_type_to_runner"), - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - tracking_metrics = object_from_json( - object_json.pop("tracking_metrics"), - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - # not relevant to multi type experiment - del object_json["runner"] - - kwargs = { - k: object_from_json( - v, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for k, v in object_json.items() - } - kwargs["default_runner"] = _trial_type_to_runner[object_json["default_trial_type"]] - - experiment = MultiTypeExperiment(**kwargs) - for metric in tracking_metrics: - experiment._tracking_metrics[metric.name] = metric - experiment._metric_to_canonical_name = _metric_to_canonical_name - experiment._metric_to_trial_type = _metric_to_trial_type - experiment._trial_type_to_runner = _trial_type_to_runner +) -> Experiment: + """Load AE MultiTypeExperiment from JSON. - _load_experiment_info( - exp=experiment, - exp_info=experiment_info, + This function is kept for backwards compatibility. It delegates to the + unified experiment_from_json which handles all experiments uniformly. + """ + return experiment_from_json( + object_json=object_json, decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) - return experiment def experiment_from_json( @@ -703,8 +672,17 @@ def experiment_from_json( decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, ) -> Experiment: - """Load Ax Experiment from JSON.""" + """Load Ax Experiment from JSON. + + This function handles all experiments uniformly, including those with + multiple trial types (formerly MultiTypeExperiment). 
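+
+    For payloads serialized as multi-type experiments, the `_metric_to_trial_type`,
+    `_trial_type_to_runner`, and `default_trial_type` fields are restored below.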
+ """ experiment_info = _get_experiment_info(object_json) + + # Handle _metric_to_trial_type (may or may not be present) + _metric_to_trial_type = object_json.pop("_metric_to_trial_type", {}) + + # Handle _trial_type_to_runner _trial_type_to_runner_json = object_json.pop("_trial_type_to_runner", None) _trial_type_to_runner = ( object_from_json( @@ -716,6 +694,20 @@ def experiment_from_json( else None ) + # Handle tracking_metrics separately for multi-type experiments + tracking_metrics = object_from_json( + object_json.pop("tracking_metrics"), + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + + # Handle default_trial_type (may or may not be present) + default_trial_type = object_json.pop("default_trial_type", None) + + # For multi-type experiments, runner in JSON is not relevant + # (we use _trial_type_to_runner instead) + runner_json = object_json.pop("runner", None) + experiment = Experiment( **{ k: object_from_json( @@ -727,8 +719,34 @@ def experiment_from_json( } ) experiment._arms_by_name = {} - if _trial_type_to_runner is not None: + + # Add tracking metrics + for metric in tracking_metrics: + experiment._tracking_metrics[metric.name] = metric + + # Set up _metric_to_trial_type + experiment._metric_to_trial_type = _metric_to_trial_type + + # Set up _trial_type_to_runner + if ( + _trial_type_to_runner is not None + and len(_trial_type_to_runner) > 0 + and ({*_trial_type_to_runner.keys()} != {None}) + ): experiment._trial_type_to_runner = _trial_type_to_runner + # Set the runner from _trial_type_to_runner if we have a default_trial_type + if default_trial_type is not None: + experiment._runner = _trial_type_to_runner.get(default_trial_type) + experiment._default_trial_type = default_trial_type + else: + # Decode and set the runner for non-multi-type experiments + runner = object_from_json( + runner_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + experiment._runner = runner + experiment._trial_type_to_runner = {Keys.DEFAULT_TRIAL_TYPE.value: runner} _load_experiment_info( exp=experiment, diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index 5d65a745c97..26e30e10074 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -48,6 +48,7 @@ REMOVED_TRANSFORMS, REVERSE_TRANSFORM_REGISTRY, ) +from ax.utils.common.constants import Keys from ax.utils.common.kwargs import warn_on_kwargs from ax.utils.common.logger import get_logger from ax.utils.common.typeutils_torch import torch_type_from_str @@ -158,7 +159,9 @@ def batch_trial_from_json( # the SQ at the end of this function. ) batch._index = index - batch._trial_type = trial_type + batch._trial_type = ( + trial_type if trial_type is not None else Keys.DEFAULT_TRIAL_TYPE.value + ) batch._time_created = time_created batch._time_completed = time_completed batch._time_staged = time_staged @@ -219,7 +222,9 @@ def trial_from_json( experiment=experiment, generator_run=generator_run, ttl_seconds=ttl_seconds ) trial._index = index - trial._trial_type = trial_type + trial._trial_type = ( + trial_type if trial_type is not None else Keys.DEFAULT_TRIAL_TYPE.value + ) # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly # equivalent to `RUNNING`. 
trial._status = status if status != TrialStatus.DISPATCHED else TrialStatus.RUNNING diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index bfd6157f129..8a7e1123cef 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -18,7 +18,6 @@ from ax.core.data import Data from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -75,7 +74,12 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: - """Convert Ax experiment to a dictionary.""" + """Convert Ax experiment to a dictionary. + + This encoder handles all experiments uniformly, including those with + multiple trial types. The _metric_to_trial_type and _trial_type_to_runner + fields are always included. + """ return { "__type": experiment.__class__.__name__, "name": experiment._name, @@ -92,19 +96,21 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: "data_by_trial": data_to_data_by_trial(data=experiment.data), "properties": experiment._properties, "_trial_type_to_runner": experiment._trial_type_to_runner, + "_metric_to_trial_type": experiment._metric_to_trial_type, + "default_trial_type": experiment._default_trial_type, } -def multi_type_experiment_to_dict(experiment: MultiTypeExperiment) -> dict[str, Any]: - """Convert AE multitype experiment to a dictionary.""" - multi_type_dict = { - "default_trial_type": experiment._default_trial_type, - "_metric_to_canonical_name": experiment._metric_to_canonical_name, - "_metric_to_trial_type": experiment._metric_to_trial_type, - "_trial_type_to_runner": experiment._trial_type_to_runner, - } - multi_type_dict.update(experiment_to_dict(experiment)) - return multi_type_dict +def multi_type_experiment_to_dict(experiment: Experiment) -> dict[str, Any]: + """Convert AE multitype experiment to a dictionary. + + This encoder is kept for backwards compatibility. It uses the same + unified encoding as experiment_to_dict but sets __type to + MultiTypeExperiment for older loaders. + """ + result = experiment_to_dict(experiment) + result["__type"] = "MultiTypeExperiment" + return result def batch_to_dict(batch: BatchTrial) -> dict[str, Any]: diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 92628b093f3..4d25c33d0e6 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -31,7 +31,6 @@ from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric -from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -214,10 +213,16 @@ def _init_experiment_from_sqa( load_auxiliary_experiments: bool = True, reduced_state: bool = False, ) -> Experiment: - """First step of conversion within experiment_from_sqa.""" + """Convert SQAExperiment to Experiment. + + This method handles all experiments uniformly, including those with + multiple trial types (formerly MultiTypeExperiment). + """ # `experiment_sqa.properties` is `sqlalchemy.ext.mutable.MutableDict` # so need to convert it to regular dict. 
diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py
index 92628b093f3..4d25c33d0e6 100644
--- a/ax/storage/sqa_store/decoder.py
+++ b/ax/storage/sqa_store/decoder.py
@@ -31,7 +31,6 @@
 from ax.core.experiment import Experiment
 from ax.core.generator_run import GeneratorRun
 from ax.core.metric import Metric
-from ax.core.multi_type_experiment import MultiTypeExperiment
 from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
 from ax.core.optimization_config import (
     MultiObjectiveOptimizationConfig,
@@ -214,10 +213,16 @@ def _init_experiment_from_sqa(
         load_auxiliary_experiments: bool = True,
         reduced_state: bool = False,
     ) -> Experiment:
-        """First step of conversion within experiment_from_sqa."""
+        """Convert SQAExperiment to Experiment.
+
+        This method handles all experiments uniformly, including those with
+        multiple trial types (formerly MultiTypeExperiment).
+        """
         # `experiment_sqa.properties` is `sqlalchemy.ext.mutable.MutableDict`
         # so need to convert it to regular dict.
         properties = dict(experiment_sqa.properties or {})
+        is_multi_type = properties.get(Keys.SUBCLASS) == "MultiTypeExperiment"
+
         opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
             metrics_sqa=experiment_sqa.metrics,
             pruning_target_parameterization=(
@@ -240,14 +245,34 @@
             if experiment_sqa.status_quo_parameters is not None
             else None
         )
+
+        # Build trial_type_to_runner from runners
+        trial_type_to_runner: dict[str, Runner | None] = {}
         if len(experiment_sqa.runners) == 0:
             runner = None
-        elif len(experiment_sqa.runners) == 1:
+        elif len(experiment_sqa.runners) == 1 and not is_multi_type:
             runner = self.runner_from_sqa(runner_sqa=experiment_sqa.runners[0])
         else:
-            raise ValueError(
-                "Multiple runners on experiment only supported for MultiTypeExperiment."
-            )
+            # Multiple runners or multi-type experiment
+            runner = None
+            for sqa_runner in experiment_sqa.runners:
+                trial_type = sqa_runner.trial_type
+                trial_type_to_runner[none_throws(trial_type)] = self.runner_from_sqa(
+                    sqa_runner
+                )
+
+        # Handle default_trial_type for multi-type experiments
+        default_trial_type = experiment_sqa.default_trial_type
+        if is_multi_type and default_trial_type is not None:
+            if len(trial_type_to_runner) == 0:
+                trial_type_to_runner = {default_trial_type: None}
+            trial_types_with_metrics = {
+                metric.trial_type
+                for metric in experiment_sqa.metrics
+                if metric.trial_type
+            }
+            trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics))
+            runner = trial_type_to_runner.get(default_trial_type)

         auxiliary_experiments_by_purpose = (
             (
@@ -260,88 +285,52 @@
             else {}
         )

-        return Experiment(
+        experiment = Experiment(
             name=experiment_sqa.name,
             description=experiment_sqa.description,
             search_space=search_space,
             optimization_config=opt_config,
-            tracking_metrics=tracking_metrics,
+            tracking_metrics=(
+                tracking_metrics if not is_multi_type else []
+            ),  # Add later for multi-type
             runner=runner,
             status_quo=status_quo,
             is_test=experiment_sqa.is_test,
             properties=properties,
             auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose,
+            default_trial_type=default_trial_type
+            if default_trial_type is not None
+            else Keys.DEFAULT_TRIAL_TYPE.value,
         )
+
+        # For multi-type experiments, set up trial_type_to_runner and add
+        # tracking metrics with their trial types
+        if is_multi_type:
+            experiment._trial_type_to_runner = trial_type_to_runner
+            sqa_metric_dict = {metric.name: metric for metric in experiment_sqa.metrics}
+            for tracking_metric in tracking_metrics:
+                sqa_metric = sqa_metric_dict[tracking_metric.name]
+                experiment.add_tracking_metric(
+                    tracking_metric,
+                    trial_type=none_throws(sqa_metric.trial_type),
+                )
+
+        return experiment
+
     def _init_mt_experiment_from_sqa(
         self,
         experiment_sqa: SQAExperiment,
-    ) -> MultiTypeExperiment:
-        """First step of conversion within experiment_from_sqa."""
-        properties = dict(experiment_sqa.properties or {})
-        opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
-            metrics_sqa=experiment_sqa.metrics,
-            pruning_target_parameterization=(
-                self._get_pruning_target_parameterization_from_experiment_properties(
-                    properties=properties
-                )
-            ),
-        )
-        search_space = self.search_space_from_sqa(
-            parameters_sqa=experiment_sqa.parameters,
-            parameter_constraints_sqa=experiment_sqa.parameter_constraints,
-        )
-        if search_space is None:
-            raise SQADecodeError("Experiment SearchSpace cannot be None.")
-        status_quo = (
-            Arm(
-                parameters=experiment_sqa.status_quo_parameters,
-                name=experiment_sqa.status_quo_name,
-            )
-            if experiment_sqa.status_quo_parameters is not None
-            else None
-        )
-
-        default_trial_type = none_throws(experiment_sqa.default_trial_type)
-        trial_type_to_runner = {
-            none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner)
-            for sqa_runner in experiment_sqa.runners
-        }
-        if len(trial_type_to_runner) == 0:
-            trial_type_to_runner = {default_trial_type: None}
-        trial_types_with_metrics = {
-            metric.trial_type
-            for metric in experiment_sqa.metrics
-            if metric.trial_type
-        }
-        # trial_type_to_runner is instantiated to map all trial types to None,
-        # so the trial types are associated with the experiment. This is
-        # important for adding metrics.
-        trial_type_to_runner.update(dict.fromkeys(trial_types_with_metrics))
+    ) -> Experiment:
+        """First step of conversion within experiment_from_sqa.
-
-        experiment = MultiTypeExperiment(
-            name=experiment_sqa.name,
-            description=experiment_sqa.description,
-            search_space=search_space,
-            default_trial_type=default_trial_type,
-            default_runner=trial_type_to_runner.get(default_trial_type),
-            optimization_config=opt_config,
-            status_quo=status_quo,
-            properties=properties,
+        This method is kept for backwards compatibility. It delegates to the
+        unified _init_experiment_from_sqa.
+        """
+        return self._init_experiment_from_sqa(
+            experiment_sqa=experiment_sqa,
+            load_auxiliary_experiments=False,
+            reduced_state=False,
         )
-        # pyre-ignore Imcompatible attribute type [8]: attribute _trial_type_to_runner
-        # has type Dict[str, Optional[Runner]] but is used as type
-        # Uniont[Dict[str, Optional[Runner]], Dict[str, None]]
-        experiment._trial_type_to_runner = trial_type_to_runner
-        sqa_metric_dict = {metric.name: metric for metric in experiment_sqa.metrics}
-        for tracking_metric in tracking_metrics:
-            sqa_metric = sqa_metric_dict[tracking_metric.name]
-            experiment.add_tracking_metric(
-                tracking_metric,
-                trial_type=none_throws(sqa_metric.trial_type),
-                canonical_name=sqa_metric.canonical_name,
-            )
-        return experiment

     def experiment_from_sqa(
         self,
@@ -359,14 +348,14 @@
         load_auxiliary_experiment: whether to load auxiliary experiments.
         """
         subclass = (experiment_sqa.properties or {}).get(Keys.SUBCLASS)
-        if subclass == "MultiTypeExperiment":
-            experiment = self._init_mt_experiment_from_sqa(experiment_sqa)
-        else:
-            experiment = self._init_experiment_from_sqa(
-                experiment_sqa,
-                load_auxiliary_experiments=load_auxiliary_experiments,
-                reduced_state=reduced_state,
-            )
+        is_multi_type = subclass == "MultiTypeExperiment"
+
+        experiment = self._init_experiment_from_sqa(
+            experiment_sqa,
+            load_auxiliary_experiments=load_auxiliary_experiments,
+            reduced_state=reduced_state,
+        )
+
         trials = [
             self.trial_from_sqa(
                 trial_sqa=trial,
@@ -386,11 +375,15 @@
             experiment.data = data_by_trial_to_data(data_by_trial=data_by_trial)

         trial_type_to_runner = {
-            sqa_runner.trial_type: self.runner_from_sqa(sqa_runner)
+            (
+                sqa_runner.trial_type
+                if sqa_runner.trial_type is not None
+                else Keys.DEFAULT_TRIAL_TYPE.value
+            ): self.runner_from_sqa(sqa_runner)
             for sqa_runner in experiment_sqa.runners
         }
         if len(trial_type_to_runner) == 0:
-            trial_type_to_runner = {None: None}
+            trial_type_to_runner = {Keys.DEFAULT_TRIAL_TYPE.value: None}

         experiment._trials = {trial.index: trial for trial in trials}
         experiment._arms_by_name = {}
@@ -412,12 +405,26 @@
         experiment._experiment_type = self.get_enum_name(
             value=experiment_sqa.experiment_type, enum=self.config.experiment_type_enum
         )
-        # `_trial_type_to_runner` is set in _init_mt_experiment_from_sqa
-        if subclass != "MultiTypeExperiment":
+        # `_trial_type_to_runner` is set in _init_experiment_from_sqa for multi-type
+        if not is_multi_type:
             experiment._trial_type_to_runner = cast(
-                dict[str | None, Runner | None], trial_type_to_runner
+                dict[str, Runner | None], trial_type_to_runner
             )
         experiment.db_id = experiment_sqa.id
+
+        # Populate _metric_to_trial_type for all experiments
+        # For multi-type experiments, this was already done in _init_experiment_from_sqa
+        if not is_multi_type:
+            default_trial_type = Keys.DEFAULT_TRIAL_TYPE.value
+            # Add OC metrics
+            oc = experiment.optimization_config
+            if oc is not None:
+                for metric_name in oc.metrics.keys():
+                    experiment._metric_to_trial_type[metric_name] = default_trial_type
+            # Add tracking metrics
+            for metric_name in experiment._tracking_metrics.keys():
+                experiment._metric_to_trial_type[metric_name] = default_trial_type
+
         return experiment

     def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
@@ -996,7 +1003,11 @@
             reduced_state=reduced_state,
             immutable_search_space_and_opt_config=immutable_ss_and_oc,
         )
-        trial._trial_type = trial_sqa.trial_type
+        trial._trial_type = (
+            trial_sqa.trial_type
+            if trial_sqa.trial_type is not None
+            else Keys.DEFAULT_TRIAL_TYPE.value
+        )
         # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly
         # equivalent to `RUNNING`.
         trial._status = (
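Sketch of the invariant the decoder establishes above (a standalone illustration under assumptions, not Ax code): after loading, every metric name maps to a trial type; single-type experiments fall back to the default trial type for all of their metrics. The helper and argument names below are hypothetical.

    # Hypothetical mirror of the mapping built for single-type experiments.
    def build_metric_to_trial_type(
        objective_metrics: list[str],
        tracking_metrics: list[str],
        default_trial_type: str = "default",
    ) -> dict[str, str]:
        # Every metric on the experiment gets an entry, all on the default type.
        return {
            name: default_trial_type
            for name in (*objective_metrics, *tracking_metrics)
        }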
diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py
index 00a9568caa7..636ec80425f 100644
--- a/ax/storage/sqa_store/encoder.py
+++ b/ax/storage/sqa_store/encoder.py
@@ -29,7 +29,6 @@
 from ax.core.experiment import Experiment
 from ax.core.generator_run import GeneratorRun
 from ax.core.metric import Metric
-from ax.core.multi_type_experiment import MultiTypeExperiment
 from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
 from ax.core.optimization_config import (
     MultiObjectiveOptimizationConfig,
@@ -213,7 +212,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
             ]
             auxiliary_experiments_by_purpose[aux_exp_type] = aux_exp_jsons
         runners = []
-        if isinstance(experiment, MultiTypeExperiment):
+        if experiment.is_multi_type:
             experiment._properties[Keys.SUBCLASS] = "MultiTypeExperiment"
             for trial_type, runner in experiment._trial_type_to_runner.items():
                 runner_sqa = self.runner_to_sqa(none_throws(runner), trial_type)
@@ -221,12 +220,13 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
             for metric in tracking_metrics:
                 metric.trial_type = experiment._metric_to_trial_type[metric.name]
-                if metric.name in experiment._metric_to_canonical_name:
-                    metric.canonical_name = experiment._metric_to_canonical_name[
-                        metric.name
-                    ]
         elif experiment.runner:
-            runners.append(self.runner_to_sqa(none_throws(experiment.runner)))
+            runners.append(
+                self.runner_to_sqa(
+                    none_throws(experiment.runner),
+                    trial_type=experiment.default_trial_type,
+                )
+            )
         properties = experiment._properties.copy()
         if (
             oc := experiment.optimization_config
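For context (a rough sketch under stated assumptions, not the encoder itself): when saving, a multi-type experiment emits one runner row per trial type, while a single-type experiment now emits a single runner tagged with its default trial type. `RunnerRow` below is a stand-in for the SQA runner object, not an Ax class.

    from dataclasses import dataclass

    @dataclass
    class RunnerRow:
        runner: object
        trial_type: str

    def runner_rows(
        trial_type_to_runner: dict[str, object],
        is_multi_type: bool,
        default_trial_type: str,
    ) -> list[RunnerRow]:
        if is_multi_type:
            # One row per trial type, mirroring the loop over _trial_type_to_runner.
            return [RunnerRow(r, t) for t, r in trial_type_to_runner.items()]
        runner = trial_type_to_runner.get(default_trial_type)
        # Single-type experiments also record their trial type on the runner row now.
        return [RunnerRow(runner, default_trial_type)] if runner is not None else []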
diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py
index d0a9df77ce4..228517a42da 100644
--- a/ax/storage/sqa_store/tests/test_sqa_store.py
+++ b/ax/storage/sqa_store/tests/test_sqa_store.py
@@ -903,8 +903,6 @@ def test_mt_experiment_save_and_load(self) -> None:
         # pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`.
         self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1")
         self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2")
-        # pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`.
-        self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1")
         self.assertEqual(len(loaded_experiment.trials), 2)

     def test_mt_experiment_save_and_load_skip_runners_and_metrics(self) -> None:
@@ -919,8 +917,6 @@ def test_mt_experiment_save_and_load_skip_runners_and_metrics(self) -> None:
         # pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`.
         self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1")
         self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2")
-        # pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`.
-        self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1")
         self.assertEqual(len(loaded_experiment.trials), 2)

     def test_experiment_new_trial(self) -> None:
diff --git a/ax/utils/common/constants.py b/ax/utils/common/constants.py
index 2f6a6d5b4d5..a1f9fdf531b 100644
--- a/ax/utils/common/constants.py
+++ b/ax/utils/common/constants.py
@@ -53,6 +53,7 @@ class Keys(StrEnum):
     COST_INTERCEPT = "cost_intercept"
     CURRENT_VALUE = "current_value"
     DEFAULT_OBJECTIVE_NAME = "objective"
+    DEFAULT_TRIAL_TYPE = "default"
     EXPAND = "expand"
     EXPECTED_ACQF_VAL = "expected_acquisition_value"
     EXPERIMENT_TOTAL_CONCURRENT_ARMS = "total_concurrent_arms"
diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py
index 62b36197304..7b1d51cf17e 100644
--- a/ax/utils/testing/core_stubs.py
+++ b/ax/utils/testing/core_stubs.py
@@ -34,7 +34,6 @@
 from ax.core.generator_run import GeneratorRun
 from ax.core.map_metric import MapMetric
 from ax.core.metric import Metric
-from ax.core.multi_type_experiment import MultiTypeExperiment
 from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
 from ax.core.optimization_config import (
     MultiObjectiveOptimizationConfig,
@@ -609,13 +608,13 @@ def get_test_map_data_experiment(

 def get_multi_type_experiment(
     add_trial_type: bool = True, add_trials: bool = False, num_arms: int = 10
-) -> MultiTypeExperiment:
+) -> Experiment:
     oc = OptimizationConfig(Objective(BraninMetric("m1", ["x1", "x2"]), minimize=True))
-    experiment = MultiTypeExperiment(
+    experiment = Experiment(
         name="test_exp",
         search_space=get_branin_search_space(),
         default_trial_type="type1",
-        default_runner=SyntheticRunner(dummy_metadata="dummy1"),
+        runner=SyntheticRunner(dummy_metadata="dummy1"),
         optimization_config=oc,
         status_quo=Arm(parameters={"x1": 0.0, "x2": 0.0}),
     )
@@ -625,7 +624,8 @@
     )  # Switch the order of variables so metric gives different results
     experiment.add_tracking_metric(
-        BraninMetric("m2", ["x2", "x1"]), trial_type="type2", canonical_name="m1"
+        BraninMetric("m2", ["x2", "x1"]),
+        trial_type="type2",
     )

     if add_trials and add_trial_type: