diff --git a/ax/core/arm.py b/ax/core/arm.py index af2fc40b924..7389b39b2c3 100644 --- a/ax/core/arm.py +++ b/ax/core/arm.py @@ -9,6 +9,7 @@ import hashlib import json from collections.abc import Mapping +from typing import Self from ax.core.types import TParameterization, TParamValue from ax.utils.common.base import SortableBase @@ -93,7 +94,7 @@ def md5hash(parameters: Mapping[str, TParamValue]) -> str: parameters_str = json.dumps(parameters, sort_keys=True) return hashlib.md5(parameters_str.encode("utf-8")).hexdigest() - def clone(self, clear_name: bool = False) -> "Arm": + def clone(self, clear_name: bool = False) -> Self: """Create a copy of this arm. Args: @@ -102,7 +103,7 @@ def clone(self, clear_name: bool = False) -> "Arm": Defaults to False. """ clear_name = clear_name or not self.has_name - return Arm( + return self.__class__( parameters=self.parameters.copy(), name=None if clear_name else self.name ) diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index c06fc2ce7d5..ee201603067 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -324,7 +324,7 @@ def update_stop_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: self._stop_metadata.update(metadata) return self._stop_metadata - def run(self) -> BaseTrial: + def run(self) -> Self: """Deploys the trial according to the behavior on the runner. The runner returns a `run_metadata` dict containining metadata @@ -349,7 +349,7 @@ def run(self) -> BaseTrial: self.mark_running() return self - def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial: + def stop(self, new_status: TrialStatus, reason: str | None = None) -> Self: """Stops the trial according to the behavior on the runner. The runner returns a `stop_metadata` dict containining metadata @@ -384,7 +384,7 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial: self.mark_as(new_status) return self - def complete(self, reason: str | None = None) -> BaseTrial: + def complete(self, reason: str | None = None) -> Self: """Stops the trial if functionality is defined on runner and marks trial completed. @@ -524,7 +524,7 @@ def status_reason(self) -> str | None: """Reason string for the trial status (failed, abandoned, or early stopped).""" return self._status_reason - def mark_staged(self, unsafe: bool = False) -> BaseTrial: + def mark_staged(self, unsafe: bool = False) -> Self: """Mark the trial as being staged for running. Args: @@ -542,7 +542,7 @@ def mark_staged(self, unsafe: bool = False) -> BaseTrial: def mark_running( self, no_runner_required: bool = False, unsafe: bool = False - ) -> BaseTrial: + ) -> Self: """Mark trial has started running. Args: @@ -572,7 +572,7 @@ def mark_running( def mark_completed( self, unsafe: bool = False, time_completed: str | None = None - ) -> BaseTrial: + ) -> Self: """Mark trial as completed. Args: @@ -596,9 +596,7 @@ def mark_completed( ) return self - def mark_abandoned( - self, reason: str | None = None, unsafe: bool = False - ) -> BaseTrial: + def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Self: """Mark trial as abandoned. NOTE: Arms in abandoned trials are considered to be 'pending points' @@ -624,7 +622,7 @@ def mark_abandoned( self._time_completed = datetime.now() return self - def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> BaseTrial: + def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> Self: """Mark trial as failed. Args: @@ -644,7 +642,7 @@ def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> BaseTr def mark_early_stopped( self, reason: str | None = None, unsafe: bool = False - ) -> BaseTrial: + ) -> Self: """Mark trial as early stopped. Args: @@ -670,7 +668,7 @@ def mark_early_stopped( self._time_completed = datetime.now() return self - def mark_stale(self, unsafe: bool = False) -> BaseTrial: + def mark_stale(self, unsafe: bool = False) -> Self: """Mark trial as stale. Args: @@ -691,9 +689,7 @@ def mark_stale(self, unsafe: bool = False) -> BaseTrial: self._time_completed = datetime.now() return self - def mark_as( - self, status: TrialStatus, unsafe: bool = False, **kwargs: Any - ) -> BaseTrial: + def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> Self: """Mark trial with a new TrialStatus. Args: @@ -724,7 +720,7 @@ def mark_as( raise TrialMutationError(f"Cannot mark trial as {status}.") return self - def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> BaseTrial: + def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> Self: raise NotImplementedError( "Abandoning arms is only supported for `BatchTrial`. " "Use `trial.mark_abandoned` if applicable." diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index ec51065d45d..f84e8c1ab1e 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from datetime import datetime from logging import Logger -from typing import Any, TYPE_CHECKING +from typing import Any, Self, TYPE_CHECKING import numpy as np from ax.core.arm import Arm @@ -467,9 +467,7 @@ def normalized_arm_weights( weights = weights * (total / np.sum(weights)) return OrderedDict(zip(self.arms, weights)) - def mark_arm_abandoned( - self, arm_name: str, reason: str | None = None - ) -> BatchTrial: + def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> Self: """Mark a arm abandoned. Usually done after deployment when one arm causes issues but diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 9cdae91e135..071fe83f802 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -15,7 +15,7 @@ from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import datetime from functools import partial, reduce -from typing import Any, cast, Union +from typing import Any, cast, Self, Union import ax.core.observation as observation import pandas as pd @@ -553,7 +553,7 @@ def immutable_search_space_and_opt_config(self) -> bool: def tracking_metrics(self) -> list[Metric]: return list(self._tracking_metrics.values()) - def add_tracking_metric(self, metric: Metric) -> Experiment: + def add_tracking_metric(self, metric: Metric) -> Self: """Add a new metric to the experiment. Args: @@ -576,7 +576,7 @@ def add_tracking_metric(self, metric: Metric) -> Experiment: self._tracking_metrics[metric.name] = metric return self - def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment: + def add_tracking_metrics(self, metrics: list[Metric]) -> Self: """Add a list of new metrics to the experiment. If any of the metrics are already defined on the experiment, @@ -591,7 +591,7 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment: self.add_tracking_metric(metric) return self - def update_tracking_metric(self, metric: Metric) -> Experiment: + def update_tracking_metric(self, metric: Metric) -> Self: """Redefine a metric that already exists on the experiment. Args: @@ -603,7 +603,7 @@ def update_tracking_metric(self, metric: Metric) -> Experiment: self._tracking_metrics[metric.name] = metric return self - def remove_tracking_metric(self, metric_name: str) -> Experiment: + def remove_tracking_metric(self, metric_name: str) -> Self: """Remove a metric that already exists on the experiment. Args: diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index d40884ae93d..b6c552d2bcb 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -7,7 +7,7 @@ # pyre-strict from collections.abc import Iterable, Sequence -from typing import Any +from typing import Any, Self from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus @@ -96,7 +96,7 @@ def __init__( default_data_type=default_data_type, ) - def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment": + def add_trial_type(self, trial_type: str, runner: Runner) -> Self: """Add a new trial_type to be supported by this experiment. Args: @@ -122,7 +122,7 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: self.default_trial_type ) - def update_runner(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment": + def update_runner(self, trial_type: str, runner: Runner) -> Self: """Update the default runner for an existing trial_type. Args: @@ -141,7 +141,7 @@ def add_tracking_metric( metric: Metric, trial_type: str | None = None, canonical_name: str | None = None, - ) -> "MultiTypeExperiment": + ) -> Self: """Add a new metric to the experiment. Args: @@ -199,18 +199,26 @@ def add_tracking_metrics( ) return self - # pyre-fixme[14]: `update_tracking_metric` overrides method defined in - # `Experiment` inconsistently. def update_tracking_metric( - self, metric: Metric, trial_type: str, canonical_name: str | None = None - ) -> "MultiTypeExperiment": + self, + metric: Metric, + trial_type: str | None = None, + canonical_name: str | None = None, + ) -> Self: """Update an existing metric on the experiment. Args: metric: The metric to add. - trial_type: The trial type for which this metric is used. + trial_type: The trial type for which this metric is used. Defaults to + the current trial type of the metric (if set), or the default trial + type otherwise. canonical_name: The default metric for which this metric is a proxy. """ + # Default to the existing trial type if not specified + if trial_type is None: + trial_type = self._metric_to_trial_type.get( + metric.name, self._default_trial_type + ) oc = self.optimization_config oc_metrics = oc.metrics if oc else [] if metric.name in oc_metrics and trial_type != self._default_trial_type: @@ -223,13 +231,13 @@ def update_tracking_metric( raise ValueError(f"`{trial_type}` is not a supported trial type.") super().update_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = trial_type + self._metric_to_trial_type[metric.name] = none_throws(trial_type) if canonical_name is not None: self._metric_to_canonical_name[metric.name] = canonical_name return self @copy_doc(Experiment.remove_tracking_metric) - def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment": + def remove_tracking_metric(self, metric_name: str) -> Self: if metric_name not in self._tracking_metrics: raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") diff --git a/ax/core/objective.py b/ax/core/objective.py index 30b59b8e54b..8036ff57ef8 100644 --- a/ax/core/objective.py +++ b/ax/core/objective.py @@ -9,6 +9,7 @@ from __future__ import annotations from collections.abc import Iterable +from typing import Self from ax.core.metric import Metric from ax.exceptions.core import UserInputError @@ -72,9 +73,9 @@ def metric_signatures(self) -> list[str]: """Get a list of objective metric signatures.""" return [m.signature for m in self.metrics] - def clone(self) -> Objective: + def clone(self) -> Self: """Create a copy of the objective.""" - return Objective(self.metric.clone(), self.minimize) + return self.__class__(self.metric.clone(), self.minimize) def __repr__(self) -> str: return 'Objective(metric_name="{}", minimize={})'.format( @@ -129,9 +130,9 @@ def objectives(self) -> list[Objective]: """Get the objectives.""" return self._objectives - def clone(self) -> MultiObjective: + def clone(self) -> Self: """Create a copy of the objective.""" - return MultiObjective(objectives=[o.clone() for o in self.objectives]) + return self.__class__(objectives=[o.clone() for o in self.objectives]) def __repr__(self) -> str: return f"MultiObjective(objectives={self.objectives})" @@ -219,9 +220,9 @@ def expression(self) -> str: return " + ".join(parts).replace(" + -", " - ") - def clone(self) -> ScalarizedObjective: + def clone(self) -> Self: """Create a copy of the objective.""" - return ScalarizedObjective( + return self.__class__( metrics=[m.clone() for m in self.metrics], weights=self.weights.copy(), minimize=self.minimize, diff --git a/ax/core/optimization_config.py b/ax/core/optimization_config.py index 5e08520aa56..ba5d1138f95 100644 --- a/ax/core/optimization_config.py +++ b/ax/core/optimization_config.py @@ -9,6 +9,7 @@ from __future__ import annotations from itertools import groupby +from typing import Self from ax.core.arm import Arm from ax.core.metric import Metric @@ -80,7 +81,7 @@ def __init__( self._outcome_constraints: list[OutcomeConstraint] = constraints self.pruning_target_parameterization = pruning_target_parameterization - def clone(self) -> "OptimizationConfig": + def clone(self) -> Self: """Make a copy of this optimization config.""" return self.clone_with_args() @@ -90,7 +91,7 @@ def clone_with_args( outcome_constraints: None | (list[OutcomeConstraint]) = _NO_OUTCOME_CONSTRAINTS, pruning_target_parameterization: Arm | None = _NO_PRUNING_TARGET_PARAMETERIZATION, - ) -> "OptimizationConfig": + ) -> Self: """Make a copy of this optimization config.""" objective = self.objective.clone() if objective is None else objective outcome_constraints = ( @@ -104,7 +105,7 @@ def clone_with_args( else pruning_target_parameterization ) - return OptimizationConfig( + return self.__class__( objective=objective, outcome_constraints=outcome_constraints, pruning_target_parameterization=pruning_target_parameterization, diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 7591c525a35..523c7224475 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -14,7 +14,7 @@ from enum import Enum from logging import Logger from math import inf -from typing import Any, cast, Union +from typing import Any, cast, Self, Union from warnings import warn import numpy as np @@ -237,7 +237,7 @@ def dependents(self) -> dict[TParamValue, list[str]]: ) # pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`. - def clone(self) -> Parameter: + def clone(self) -> Self: pass def disable(self, default_value: TParamValue) -> None: diff --git a/ax/core/runner.py b/ax/core/runner.py index 0fcf8abff20..39d71ed99f8 100644 --- a/ax/core/runner.py +++ b/ax/core/runner.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Any, TYPE_CHECKING +from typing import Any, Self, TYPE_CHECKING from ax.utils.common.base import Base from ax.utils.common.serialization import SerializationMixin @@ -145,7 +145,7 @@ def stop( f"{self.__class__.__name__} does not implement a `stop` method." ) - def clone(self) -> Runner: + def clone(self) -> Self: """Create a copy of this Runner.""" cls = type(self) # pyre-ignore[45]: Cannot instantiate abstract class `Runner`.