Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ax/core/arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import hashlib
import json
from collections.abc import Mapping
from typing import Self

from ax.core.types import TParameterization, TParamValue
from ax.utils.common.base import SortableBase
Expand Down Expand Up @@ -93,7 +94,7 @@ def md5hash(parameters: Mapping[str, TParamValue]) -> str:
parameters_str = json.dumps(parameters, sort_keys=True)
return hashlib.md5(parameters_str.encode("utf-8")).hexdigest()

def clone(self, clear_name: bool = False) -> "Arm":
def clone(self, clear_name: bool = False) -> Self:
"""Create a copy of this arm.

Args:
Expand All @@ -102,7 +103,7 @@ def clone(self, clear_name: bool = False) -> "Arm":
Defaults to False.
"""
clear_name = clear_name or not self.has_name
return Arm(
return self.__class__(
parameters=self.parameters.copy(), name=None if clear_name else self.name
)

Expand Down
28 changes: 12 additions & 16 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def update_stop_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
self._stop_metadata.update(metadata)
return self._stop_metadata

def run(self) -> BaseTrial:
def run(self) -> Self:
"""Deploys the trial according to the behavior on the runner.

The runner returns a `run_metadata` dict containining metadata
Expand All @@ -349,7 +349,7 @@ def run(self) -> BaseTrial:
self.mark_running()
return self

def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial:
def stop(self, new_status: TrialStatus, reason: str | None = None) -> Self:
"""Stops the trial according to the behavior on the runner.

The runner returns a `stop_metadata` dict containining metadata
Expand Down Expand Up @@ -384,7 +384,7 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial:
self.mark_as(new_status)
return self

def complete(self, reason: str | None = None) -> BaseTrial:
def complete(self, reason: str | None = None) -> Self:
"""Stops the trial if functionality is defined on runner
and marks trial completed.

Expand Down Expand Up @@ -524,7 +524,7 @@ def status_reason(self) -> str | None:
"""Reason string for the trial status (failed, abandoned, or early stopped)."""
return self._status_reason

def mark_staged(self, unsafe: bool = False) -> BaseTrial:
def mark_staged(self, unsafe: bool = False) -> Self:
"""Mark the trial as being staged for running.

Args:
Expand All @@ -542,7 +542,7 @@ def mark_staged(self, unsafe: bool = False) -> BaseTrial:

def mark_running(
self, no_runner_required: bool = False, unsafe: bool = False
) -> BaseTrial:
) -> Self:
"""Mark trial has started running.

Args:
Expand Down Expand Up @@ -572,7 +572,7 @@ def mark_running(

def mark_completed(
self, unsafe: bool = False, time_completed: str | None = None
) -> BaseTrial:
) -> Self:
"""Mark trial as completed.

Args:
Expand All @@ -596,9 +596,7 @@ def mark_completed(
)
return self

def mark_abandoned(
self, reason: str | None = None, unsafe: bool = False
) -> BaseTrial:
def mark_abandoned(self, reason: str | None = None, unsafe: bool = False) -> Self:
"""Mark trial as abandoned.

NOTE: Arms in abandoned trials are considered to be 'pending points'
Expand All @@ -624,7 +622,7 @@ def mark_abandoned(
self._time_completed = datetime.now()
return self

def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> BaseTrial:
def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> Self:
"""Mark trial as failed.

Args:
Expand All @@ -644,7 +642,7 @@ def mark_failed(self, reason: str | None = None, unsafe: bool = False) -> BaseTr

def mark_early_stopped(
self, reason: str | None = None, unsafe: bool = False
) -> BaseTrial:
) -> Self:
"""Mark trial as early stopped.

Args:
Expand All @@ -670,7 +668,7 @@ def mark_early_stopped(
self._time_completed = datetime.now()
return self

def mark_stale(self, unsafe: bool = False) -> BaseTrial:
def mark_stale(self, unsafe: bool = False) -> Self:
"""Mark trial as stale.

Args:
Expand All @@ -691,9 +689,7 @@ def mark_stale(self, unsafe: bool = False) -> BaseTrial:
self._time_completed = datetime.now()
return self

def mark_as(
self, status: TrialStatus, unsafe: bool = False, **kwargs: Any
) -> BaseTrial:
def mark_as(self, status: TrialStatus, unsafe: bool = False, **kwargs: Any) -> Self:
"""Mark trial with a new TrialStatus.

Args:
Expand Down Expand Up @@ -724,7 +720,7 @@ def mark_as(
raise TrialMutationError(f"Cannot mark trial as {status}.")
return self

def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> BaseTrial:
def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> Self:
raise NotImplementedError(
"Abandoning arms is only supported for `BatchTrial`. "
"Use `trial.mark_abandoned` if applicable."
Expand Down
6 changes: 2 additions & 4 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dataclasses import dataclass
from datetime import datetime
from logging import Logger
from typing import Any, TYPE_CHECKING
from typing import Any, Self, TYPE_CHECKING

import numpy as np
from ax.core.arm import Arm
Expand Down Expand Up @@ -467,9 +467,7 @@ def normalized_arm_weights(
weights = weights * (total / np.sum(weights))
return OrderedDict(zip(self.arms, weights))

def mark_arm_abandoned(
self, arm_name: str, reason: str | None = None
) -> BatchTrial:
def mark_arm_abandoned(self, arm_name: str, reason: str | None = None) -> Self:
"""Mark a arm abandoned.

Usually done after deployment when one arm causes issues but
Expand Down
10 changes: 5 additions & 5 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Hashable, Iterable, Mapping, Sequence
from datetime import datetime
from functools import partial, reduce
from typing import Any, cast, Union
from typing import Any, cast, Self, Union

import ax.core.observation as observation
import pandas as pd
Expand Down Expand Up @@ -553,7 +553,7 @@ def immutable_search_space_and_opt_config(self) -> bool:
def tracking_metrics(self) -> list[Metric]:
return list(self._tracking_metrics.values())

def add_tracking_metric(self, metric: Metric) -> Experiment:
def add_tracking_metric(self, metric: Metric) -> Self:
"""Add a new metric to the experiment.

Args:
Expand All @@ -576,7 +576,7 @@ def add_tracking_metric(self, metric: Metric) -> Experiment:
self._tracking_metrics[metric.name] = metric
return self

def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment:
def add_tracking_metrics(self, metrics: list[Metric]) -> Self:
"""Add a list of new metrics to the experiment.

If any of the metrics are already defined on the experiment,
Expand All @@ -591,7 +591,7 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment:
self.add_tracking_metric(metric)
return self

def update_tracking_metric(self, metric: Metric) -> Experiment:
def update_tracking_metric(self, metric: Metric) -> Self:
"""Redefine a metric that already exists on the experiment.

Args:
Expand All @@ -603,7 +603,7 @@ def update_tracking_metric(self, metric: Metric) -> Experiment:
self._tracking_metrics[metric.name] = metric
return self

def remove_tracking_metric(self, metric_name: str) -> Experiment:
def remove_tracking_metric(self, metric_name: str) -> Self:
"""Remove a metric that already exists on the experiment.

Args:
Expand Down
30 changes: 19 additions & 11 deletions ax/core/multi_type_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

from collections.abc import Iterable, Sequence
from typing import Any
from typing import Any, Self

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
default_data_type=default_data_type,
)

def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment":
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
"""Add a new trial_type to be supported by this experiment.

Args:
Expand All @@ -122,7 +122,7 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
self.default_trial_type
)

def update_runner(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment":
def update_runner(self, trial_type: str, runner: Runner) -> Self:
"""Update the default runner for an existing trial_type.

Args:
Expand All @@ -141,7 +141,7 @@ def add_tracking_metric(
metric: Metric,
trial_type: str | None = None,
canonical_name: str | None = None,
) -> "MultiTypeExperiment":
) -> Self:
"""Add a new metric to the experiment.

Args:
Expand Down Expand Up @@ -199,18 +199,26 @@ def add_tracking_metrics(
)
return self

# pyre-fixme[14]: `update_tracking_metric` overrides method defined in
# `Experiment` inconsistently.
def update_tracking_metric(
self, metric: Metric, trial_type: str, canonical_name: str | None = None
) -> "MultiTypeExperiment":
self,
metric: Metric,
trial_type: str | None = None,
canonical_name: str | None = None,
) -> Self:
"""Update an existing metric on the experiment.

Args:
metric: The metric to add.
trial_type: The trial type for which this metric is used.
trial_type: The trial type for which this metric is used. Defaults to
the current trial type of the metric (if set), or the default trial
type otherwise.
canonical_name: The default metric for which this metric is a proxy.
"""
# Default to the existing trial type if not specified
if trial_type is None:
trial_type = self._metric_to_trial_type.get(
metric.name, self._default_trial_type
)
oc = self.optimization_config
oc_metrics = oc.metrics if oc else []
if metric.name in oc_metrics and trial_type != self._default_trial_type:
Expand All @@ -223,13 +231,13 @@ def update_tracking_metric(
raise ValueError(f"`{trial_type}` is not a supported trial type.")

super().update_tracking_metric(metric)
self._metric_to_trial_type[metric.name] = trial_type
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
if canonical_name is not None:
self._metric_to_canonical_name[metric.name] = canonical_name
return self

@copy_doc(Experiment.remove_tracking_metric)
def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment":
def remove_tracking_metric(self, metric_name: str) -> Self:
if metric_name not in self._tracking_metrics:
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")

Expand Down
13 changes: 7 additions & 6 deletions ax/core/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Self

from ax.core.metric import Metric
from ax.exceptions.core import UserInputError
Expand Down Expand Up @@ -72,9 +73,9 @@ def metric_signatures(self) -> list[str]:
"""Get a list of objective metric signatures."""
return [m.signature for m in self.metrics]

def clone(self) -> Objective:
def clone(self) -> Self:
"""Create a copy of the objective."""
return Objective(self.metric.clone(), self.minimize)
return self.__class__(self.metric.clone(), self.minimize)

def __repr__(self) -> str:
return 'Objective(metric_name="{}", minimize={})'.format(
Expand Down Expand Up @@ -129,9 +130,9 @@ def objectives(self) -> list[Objective]:
"""Get the objectives."""
return self._objectives

def clone(self) -> MultiObjective:
def clone(self) -> Self:
"""Create a copy of the objective."""
return MultiObjective(objectives=[o.clone() for o in self.objectives])
return self.__class__(objectives=[o.clone() for o in self.objectives])

def __repr__(self) -> str:
return f"MultiObjective(objectives={self.objectives})"
Expand Down Expand Up @@ -219,9 +220,9 @@ def expression(self) -> str:

return " + ".join(parts).replace(" + -", " - ")

def clone(self) -> ScalarizedObjective:
def clone(self) -> Self:
"""Create a copy of the objective."""
return ScalarizedObjective(
return self.__class__(
metrics=[m.clone() for m in self.metrics],
weights=self.weights.copy(),
minimize=self.minimize,
Expand Down
7 changes: 4 additions & 3 deletions ax/core/optimization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

from itertools import groupby
from typing import Self

from ax.core.arm import Arm
from ax.core.metric import Metric
Expand Down Expand Up @@ -80,7 +81,7 @@ def __init__(
self._outcome_constraints: list[OutcomeConstraint] = constraints
self.pruning_target_parameterization = pruning_target_parameterization

def clone(self) -> "OptimizationConfig":
def clone(self) -> Self:
"""Make a copy of this optimization config."""
return self.clone_with_args()

Expand All @@ -90,7 +91,7 @@ def clone_with_args(
outcome_constraints: None | (list[OutcomeConstraint]) = _NO_OUTCOME_CONSTRAINTS,
pruning_target_parameterization: Arm
| None = _NO_PRUNING_TARGET_PARAMETERIZATION,
) -> "OptimizationConfig":
) -> Self:
"""Make a copy of this optimization config."""
objective = self.objective.clone() if objective is None else objective
outcome_constraints = (
Expand All @@ -104,7 +105,7 @@ def clone_with_args(
else pruning_target_parameterization
)

return OptimizationConfig(
return self.__class__(
objective=objective,
outcome_constraints=outcome_constraints,
pruning_target_parameterization=pruning_target_parameterization,
Expand Down
4 changes: 2 additions & 2 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from enum import Enum
from logging import Logger
from math import inf
from typing import Any, cast, Union
from typing import Any, cast, Self, Union
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -237,7 +237,7 @@ def dependents(self) -> dict[TParamValue, list[str]]:
)

# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
def clone(self) -> Parameter:
def clone(self) -> Self:
pass

def disable(self, default_value: TParamValue) -> None:
Expand Down
Loading