From 1ededc9b51c3cc14a8d687e4268cfd3681256da0 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Mon, 9 Feb 2026 09:36:59 -0800 Subject: [PATCH] Phase 3: Add Self type annotations to fluent/clone methods Summary: Add Self type annotation from typing module to methods that return self or cloned instances of the same class. This enables better type inference for subclasses and improves IDE support. Files changed: - ax/core/base_trial.py - run(), stop(), complete(), mark_*() methods - ax/core/batch_trial.py - mark_arm_abandoned() - ax/core/experiment.py - add_tracking_metric(), update_tracking_metric(), etc. - ax/core/multi_type_experiment.py - add_trial_type(), update_runner(), metric methods - ax/core/arm.py - clone() - ax/core/objective.py - clone() on Objective, MultiObjective, ScalarizedObjective - ax/core/optimization_config.py - clone() - ax/core/runner.py - clone() - ax/core/parameter.py - clone() abstract method Reviewed By: saitcakmak Differential Revision: D91648885 --- ax/core/arm.py | 5 +++-- ax/core/base_trial.py | 28 ++++++++++++---------------- ax/core/batch_trial.py | 6 ++---- ax/core/experiment.py | 10 +++++----- ax/core/multi_type_experiment.py | 30 +++++++++++++++++++----------- ax/core/objective.py | 13 +++++++------ ax/core/optimization_config.py | 7 ++++--- ax/core/parameter.py | 4 ++-- ax/core/runner.py | 4 ++-- 9 files changed, 56 insertions(+), 51 deletions(-) diff --git a/ax/core/arm.py b/ax/core/arm.py index af2fc40b924..7389b39b2c3 100644 --- a/ax/core/arm.py +++ b/ax/core/arm.py @@ -9,6 +9,7 @@ import hashlib import json from collections.abc import Mapping +from typing import Self from ax.core.types import TParameterization, TParamValue from ax.utils.common.base import SortableBase @@ -93,7 +94,7 @@ def md5hash(parameters: Mapping[str, TParamValue]) -> str: parameters_str = json.dumps(parameters, sort_keys=True) return hashlib.md5(parameters_str.encode("utf-8")).hexdigest() - def clone(self, clear_name: bool = False) -> "Arm": + def clone(self, clear_name: bool = False) -> Self: """Create a copy of this arm. Args: @@ -102,7 +103,7 @@ def clone(self, clear_name: bool = False) -> "Arm": Defaults to False. """ clear_name = clear_name or not self.has_name - return Arm( + return self.__class__( parameters=self.parameters.copy(), name=None if clear_name else self.name ) diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index c06fc2ce7d5..ee201603067 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -324,7 +324,7 @@ def update_stop_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: self._stop_metadata.update(metadata) return self._stop_metadata - def run(self) -> BaseTrial: + def run(self) -> Self: """Deploys the trial according to the behavior on the runner. The runner returns a `run_metadata` dict containining metadata @@ -349,7 +349,7 @@ def run(self) -> BaseTrial: self.mark_running() return self - def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial: + def stop(self, new_status: TrialStatus, reason: str | None = None) -> Self: """Stops the trial according to the behavior on the runner. The runner returns a `stop_metadata` dict containining metadata @@ -384,7 +384,7 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial: self.mark_as(new_status) return self - def complete(self, reason: str | None = None) -> BaseTrial: + def complete(self, reason: str | None = None) -> Self: """Stops the trial if functionality is defined on runner and marks trial completed. @@ -524,7 +524,7 @@ def status_reason(self) -> str | None: """Reason string for the trial status (failed, abandoned, or early stopped).""" return self._status_reason - def mark_staged(self, unsafe: bool = False) -> BaseTrial: + def mark_staged(self, unsafe: bool = False) -> Self: """Mark the trial as being staged for running. Args: @@ -542,7 +542,7 @@ def mark_staged(self, unsafe: bool = False) -> BaseTrial: def mark_running( self, no_runner_required: bool = False, unsafe: bool = False - ) -> BaseTrial: + ) -> Self: """Mark trial has started running. Args: @@ -572,7 +572,7 @@ def mark_running( def mark_completed( self, unsafe: bool = False, time_completed: str | None = None - ) -> BaseTrial: + ) -> Self: """Mark trial as completed. Args: @@ -596,9 +596,7 @@ def mark_completed( ) return self - def mark_abandoned( - self, reason: str | None = None, unsafe: bool = False - ) -> BaseTrial: + def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Self: """Mark trial as abandoned. NOTE: Arms in abandoned trials are considered to be 'pending points' @@ -624,7 +622,7 @@ def mark_abandoned( self._time_completed = datetime.now() return self - def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> BaseTrial: + def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> Self: """Mark trial as failed. Args: @@ -644,7 +642,7 @@ def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> BaseTr def mark_early_stopped( self, reason: str | None = None, unsafe: bool = False - ) -> BaseTrial: + ) -> Self: """Mark trial as early stopped. Args: @@ -670,7 +668,7 @@ def mark_early_stopped( self._time_completed = datetime.now() return self - def mark_stale(self, unsafe: bool = False) -> BaseTrial: + def mark_stale(self, unsafe: bool = False) -> Self: """Mark trial as stale. Args: @@ -691,9 +689,7 @@ def mark_stale(self, unsafe: bool = False) -> BaseTrial: self._time_completed = datetime.now() return self - def mark_as( - self, status: TrialStatus, unsafe: bool = False, **kwargs: Any - ) -> BaseTrial: + def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> Self: """Mark trial with a new TrialStatus. Args: @@ -724,7 +720,7 @@ def mark_as( raise TrialMutationError(f"Cannot mark trial as {status}.") return self - def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> BaseTrial: + def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> Self: raise NotImplementedError( "Abandoning arms is only supported for `BatchTrial`. " "Use `trial.mark_abandoned` if applicable." diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index ec51065d45d..f84e8c1ab1e 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from datetime import datetime from logging import Logger -from typing import Any, TYPE_CHECKING +from typing import Any, Self, TYPE_CHECKING import numpy as np from ax.core.arm import Arm @@ -467,9 +467,7 @@ def normalized_arm_weights( weights = weights * (total / np.sum(weights)) return OrderedDict(zip(self.arms, weights)) - def mark_arm_abandoned( - self, arm_name: str, reason: str | None = None - ) -> BatchTrial: + def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> Self: """Mark a arm abandoned. Usually done after deployment when one arm causes issues but diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 9cdae91e135..071fe83f802 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -15,7 +15,7 @@ from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import datetime from functools import partial, reduce -from typing import Any, cast, Union +from typing import Any, cast, Self, Union import ax.core.observation as observation import pandas as pd @@ -553,7 +553,7 @@ def immutable_search_space_and_opt_config(self) -> bool: def tracking_metrics(self) -> list[Metric]: return list(self._tracking_metrics.values()) - def add_tracking_metric(self, metric: Metric) -> Experiment: + def add_tracking_metric(self, metric: Metric) -> Self: """Add a new metric to the experiment. Args: @@ -576,7 +576,7 @@ def add_tracking_metric(self, metric: Metric) -> Experiment: self._tracking_metrics[metric.name] = metric return self - def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment: + def add_tracking_metrics(self, metrics: list[Metric]) -> Self: """Add a list of new metrics to the experiment. If any of the metrics are already defined on the experiment, @@ -591,7 +591,7 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment: self.add_tracking_metric(metric) return self - def update_tracking_metric(self, metric: Metric) -> Experiment: + def update_tracking_metric(self, metric: Metric) -> Self: """Redefine a metric that already exists on the experiment. Args: @@ -603,7 +603,7 @@ def update_tracking_metric(self, metric: Metric) -> Experiment: self._tracking_metrics[metric.name] = metric return self - def remove_tracking_metric(self, metric_name: str) -> Experiment: + def remove_tracking_metric(self, metric_name: str) -> Self: """Remove a metric that already exists on the experiment. Args: diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index d40884ae93d..b6c552d2bcb 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -7,7 +7,7 @@ # pyre-strict from collections.abc import Iterable, Sequence -from typing import Any +from typing import Any, Self from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus @@ -96,7 +96,7 @@ def __init__( default_data_type=default_data_type, ) - def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment": + def add_trial_type(self, trial_type: str, runner: Runner) -> Self: """Add a new trial_type to be supported by this experiment. Args: @@ -122,7 +122,7 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: self.default_trial_type ) - def update_runner(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment": + def update_runner(self, trial_type: str, runner: Runner) -> Self: """Update the default runner for an existing trial_type. Args: @@ -141,7 +141,7 @@ def add_tracking_metric( metric: Metric, trial_type: str | None = None, canonical_name: str | None = None, - ) -> "MultiTypeExperiment": + ) -> Self: """Add a new metric to the experiment. Args: @@ -199,18 +199,26 @@ def add_tracking_metrics( ) return self - # pyre-fixme[14]: `update_tracking_metric` overrides method defined in - # `Experiment` inconsistently. def update_tracking_metric( - self, metric: Metric, trial_type: str, canonical_name: str | None = None - ) -> "MultiTypeExperiment": + self, + metric: Metric, + trial_type: str | None = None, + canonical_name: str | None = None, + ) -> Self: """Update an existing metric on the experiment. Args: metric: The metric to add. - trial_type: The trial type for which this metric is used. + trial_type: The trial type for which this metric is used. Defaults to + the current trial type of the metric (if set), or the default trial + type otherwise. canonical_name: The default metric for which this metric is a proxy. """ + # Default to the existing trial type if not specified + if trial_type is None: + trial_type = self._metric_to_trial_type.get( + metric.name, self._default_trial_type + ) oc = self.optimization_config oc_metrics = oc.metrics if oc else [] if metric.name in oc_metrics and trial_type != self._default_trial_type: @@ -223,13 +231,13 @@ def update_tracking_metric( raise ValueError(f"`{trial_type}` is not a supported trial type.") super().update_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = trial_type + self._metric_to_trial_type[metric.name] = none_throws(trial_type) if canonical_name is not None: self._metric_to_canonical_name[metric.name] = canonical_name return self @copy_doc(Experiment.remove_tracking_metric) - def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment": + def remove_tracking_metric(self, metric_name: str) -> Self: if metric_name not in self._tracking_metrics: raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") diff --git a/ax/core/objective.py b/ax/core/objective.py index 30b59b8e54b..8036ff57ef8 100644 --- a/ax/core/objective.py +++ b/ax/core/objective.py @@ -9,6 +9,7 @@ from __future__ import annotations from collections.abc import Iterable +from typing import Self from ax.core.metric import Metric from ax.exceptions.core import UserInputError @@ -72,9 +73,9 @@ def metric_signatures(self) -> list[str]: """Get a list of objective metric signatures.""" return [m.signature for m in self.metrics] - def clone(self) -> Objective: + def clone(self) -> Self: """Create a copy of the objective.""" - return Objective(self.metric.clone(), self.minimize) + return self.__class__(self.metric.clone(), self.minimize) def __repr__(self) -> str: return 'Objective(metric_name="{}", minimize={})'.format( @@ -129,9 +130,9 @@ def objectives(self) -> list[Objective]: """Get the objectives.""" return self._objectives - def clone(self) -> MultiObjective: + def clone(self) -> Self: """Create a copy of the objective.""" - return MultiObjective(objectives=[o.clone() for o in self.objectives]) + return self.__class__(objectives=[o.clone() for o in self.objectives]) def __repr__(self) -> str: return f"MultiObjective(objectives={self.objectives})" @@ -219,9 +220,9 @@ def expression(self) -> str: return " + ".join(parts).replace(" + -", " - ") - def clone(self) -> ScalarizedObjective: + def clone(self) -> Self: """Create a copy of the objective.""" - return ScalarizedObjective( + return self.__class__( metrics=[m.clone() for m in self.metrics], weights=self.weights.copy(), minimize=self.minimize, diff --git a/ax/core/optimization_config.py b/ax/core/optimization_config.py index 5e08520aa56..ba5d1138f95 100644 --- a/ax/core/optimization_config.py +++ b/ax/core/optimization_config.py @@ -9,6 +9,7 @@ from __future__ import annotations from itertools import groupby +from typing import Self from ax.core.arm import Arm from ax.core.metric import Metric @@ -80,7 +81,7 @@ def __init__( self._outcome_constraints: list[OutcomeConstraint] = constraints self.pruning_target_parameterization = pruning_target_parameterization - def clone(self) -> "OptimizationConfig": + def clone(self) -> Self: """Make a copy of this optimization config.""" return self.clone_with_args() @@ -90,7 +91,7 @@ def clone_with_args( outcome_constraints: None | (list[OutcomeConstraint]) = _NO_OUTCOME_CONSTRAINTS, pruning_target_parameterization: Arm | None = _NO_PRUNING_TARGET_PARAMETERIZATION, - ) -> "OptimizationConfig": + ) -> Self: """Make a copy of this optimization config.""" objective = self.objective.clone() if objective is None else objective outcome_constraints = ( @@ -104,7 +105,7 @@ def clone_with_args( else pruning_target_parameterization ) - return OptimizationConfig( + return self.__class__( objective=objective, outcome_constraints=outcome_constraints, pruning_target_parameterization=pruning_target_parameterization, diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 7591c525a35..523c7224475 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -14,7 +14,7 @@ from enum import Enum from logging import Logger from math import inf -from typing import Any, cast, Union +from typing import Any, cast, Self, Union from warnings import warn import numpy as np @@ -237,7 +237,7 @@ def dependents(self) -> dict[TParamValue, list[str]]: ) # pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`. - def clone(self) -> Parameter: + def clone(self) -> Self: pass def disable(self, default_value: TParamValue) -> None: diff --git a/ax/core/runner.py b/ax/core/runner.py index 0fcf8abff20..39d71ed99f8 100644 --- a/ax/core/runner.py +++ b/ax/core/runner.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Any, TYPE_CHECKING +from typing import Any, Self, TYPE_CHECKING from ax.utils.common.base import Base from ax.utils.common.serialization import SerializationMixin @@ -145,7 +145,7 @@ def stop( f"{self.__class__.__name__} does not implement a `stop` method." ) - def clone(self) -> Runner: + def clone(self) -> Self: """Create a copy of this Runner.""" cls = type(self) # pyre-ignore[45]: Cannot instantiate abstract class `Runner`.