12 changes: 5 additions & 7 deletions ax/adapter/adapter_utils.py
@@ -228,7 +228,7 @@ def extract_objective_thresholds(
# Check that all thresholds correspond to a metric.
if set(objective_threshold_dict.keys()).difference(set(objective.metric_names)):
raise ValueError(
"Some objective thresholds do not have corresponding metrics."
"Some objective thresholds do not have corresponding metrics. "
f"Got {objective_thresholds=} and {objective=}."
)

@@ -566,12 +566,10 @@ def get_pareto_frontier_and_configs(
"`observation_data` will not be used.",
stacklevel=2,
)
else:
if observation_data is None:
raise ValueError(
"`observation_data` must not be None when `use_model_predictions` is "
"True."
)
elif observation_data is None:
raise ValueError(
"`observation_data` must not be None when `use_model_predictions` is False."
)

array_to_tensor = adapter._array_to_tensor
if use_model_predictions:
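Note on the hunk above: besides flattening `else: if` into `elif`, this fix corrects the error message, since the `None` check only fires when `use_model_predictions` is False, not True as the old message claimed. A minimal sketch of the corrected control flow (a simplified stand-in, not the full function):

```python
import warnings

def _check_observation_data(use_model_predictions, observation_data):
    # Sketch of the guard in get_pareto_frontier_and_configs above.
    if use_model_predictions:
        if observation_data is not None:
            warnings.warn("`observation_data` will not be used.", stacklevel=2)
    elif observation_data is None:
        # Reached only when use_model_predictions is False, hence the
        # corrected message text.
        raise ValueError(
            "`observation_data` must not be None when "
            "`use_model_predictions` is False."
        )
```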
2 changes: 1 addition & 1 deletion ax/adapter/torch.py
@@ -693,7 +693,7 @@ def _update_w_aux_exp_datasets(
if pe_data.df.empty:
raise DataRequiredError(
"No data found in the auxiliary preference exploration "
"experiment. Play the preference game first or use another"
"experiment. Play the preference game first or use another "
"preference profile with recorded preference data."
)

7 changes: 1 addition & 6 deletions ax/adapter/transforms/int_range_to_choice.py
@@ -67,18 +67,13 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
and p.cardinality() <= self.max_choices
):
values = list(range(int(p.lower), int(p.upper) + 1))
target_value = (
None
if p.target_value is None
else next(i for i, v in enumerate(values) if v == p.target_value)
)
transformed_parameters[p_name] = ChoiceParameter(
name=p_name,
parameter_type=p.parameter_type,
values=values, # pyre-fixme[6]
is_ordered=True,
is_fidelity=p.is_fidelity,
target_value=target_value,
target_value=p.target_value,
)
else:
transformed_parameters[p.name] = p
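Note on the hunk above: the removed block converted `target_value` to its *index* within `values`, but `ChoiceParameter` takes the target as an actual value, so the two disagree whenever the integer range does not start at zero. A small illustration (bounds made up):

```python
values = list(range(1, 6))  # [1, 2, 3, 4, 5] for lower=1, upper=5
target_value = 2

# Old behavior: stored the index of the target within `values`.
old_target = next(i for i, v in enumerate(values) if v == target_value)  # 1

# New behavior: passes the value through unchanged.
new_target = target_value  # 2

assert old_target != new_target and values[old_target] == new_target
```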
2 changes: 1 addition & 1 deletion ax/adapter/transforms/relativize.py
@@ -335,7 +335,7 @@ def get_metric_index(data: ObservationData, metric_signature: str) -> int:
"""Get the index of a metric in the ObservationData."""
try:
return data.metric_signatures.index(metric_signature)
except (IndexError, StopIteration):
except ValueError:
raise ValueError(
"Relativization cannot be performed because "
"ObservationData for status quo is missing metrics"
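Note on the hunk above: `list.index` raises `ValueError` on a missing element, not `IndexError` (and `StopIteration` would only apply to a `next(...)`-based lookup), so the old `except` clause could never catch the failure:

```python
metric_signatures = ["m1", "m2"]
try:
    metric_signatures.index("m3")
except ValueError as e:
    print(e)  # "'m3' is not in list"; IndexError/StopIteration never fire here
```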
16 changes: 12 additions & 4 deletions ax/adapter/transforms/tests/test_int_range_to_choice_transform.py
@@ -13,14 +13,21 @@
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.utils.common.testutils import TestCase
from pyre_extensions import assert_is_instance


class IntRangeToChoiceTransformTest(TestCase):
def setUp(self) -> None:
super().setUp()
self.search_space = SearchSpace(
parameters=[
RangeParameter("a", lower=1, upper=5, parameter_type=ParameterType.INT),
RangeParameter(
"a",
lower=1,
upper=5,
parameter_type=ParameterType.INT,
target_value=2,
),
ChoiceParameter(
"b", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
),
@@ -43,9 +50,10 @@ def test_TransformObservationFeatures(self) -> None:
def test_TransformSearchSpace(self) -> None:
ss2 = deepcopy(self.search_space)
ss2 = self.t.transform_search_space(ss2)
self.assertTrue(isinstance(ss2.parameters["a"], ChoiceParameter))
# pyre-fixme[16]: `Parameter` has no attribute `values`.
self.assertTrue(ss2.parameters["a"].values, [1, 2, 3, 4, 5])
new_a = assert_is_instance(ss2.parameters["a"], ChoiceParameter)
self.assertEqual(new_a.values, [1, 2, 3, 4, 5])
self.assertEqual(new_a.target_value, 2)

def test_num_choices(self) -> None:
parameters = {
4 changes: 2 additions & 2 deletions ax/analysis/healthcheck/regression_detection_utils.py
@@ -93,7 +93,7 @@ def compute_regression_probabilities_single_trial(

if len(metric_names) == 0:
raise ValueError(
"No common metrics between the provided data and the size thresholds."
"No common metrics between the provided data and the size thresholds. "
"Need to provide both data and the size thresholds for metrics of interest."
)

@@ -159,7 +159,7 @@ def detect_regressions_single_trial(

if len(metric_names) == 0:
raise ValueError(
"No common metrics between the provided data and the thresholds."
"No common metrics between the provided data and the thresholds. "
"Need to provide both data and the size thresholds for metrics of interest."
)

2 changes: 1 addition & 1 deletion ax/analysis/plotly/arm_effects.py
@@ -523,7 +523,7 @@ def _get_subtitle(
f"{'Modeled' if use_model_predictions else 'Observed'} effects on "
f"{metric_label}"
)
trial_clause = f" for Trial {trial_index}." if trial_index is not None else ""
trial_clause = f" for Trial {trial_index}" if trial_index is not None else ""
first_sentence = f"{first_clause}{trial_clause}."

if use_model_predictions:
2 changes: 1 addition & 1 deletion ax/analysis/plotly/scatter.py
@@ -525,7 +525,7 @@ def _prepare_figure(
if "status_quo" in df["arm_name"].values:
x = df[df["arm_name"] == "status_quo"][f"{x_metric_name}_mean"].iloc[0]
y = df[df["arm_name"] == "status_quo"][f"{y_metric_name}_mean"].iloc[0]
if not np.isnan(x) or not np.isnan(y):
if not np.isnan(x) and not np.isnan(y):
figure.add_shape(
type="line",
yref="paper",
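Note on the hunk above: with `or`, the status-quo reference line was drawn whenever at least *one* coordinate was valid, producing a shape anchored at NaN; `and` requires both coordinates to be present:

```python
import numpy as np

x, y = 1.0, float("nan")
assert not np.isnan(x) or not np.isnan(y)         # old condition: still draws
assert not (not np.isnan(x) and not np.isnan(y))  # new condition: skips the shape
```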
2 changes: 1 addition & 1 deletion ax/analysis/utils.py
@@ -210,7 +210,7 @@ def prepare_arm_data(
else:
if additional_arms is not None:
raise UserInputError(
"Cannot provide additional arms when use_model_predictions=False since"
"Cannot provide additional arms when use_model_predictions=False since "
"there is no observed raw data for the additional arms that are not "
"part of the Experiment."
)
2 changes: 1 addition & 1 deletion ax/api/client.py
@@ -1350,7 +1350,7 @@ def _from_json_snapshot(
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
if "generation_strategy" in snapshot
if snapshot.get("generation_strategy") is not None
else None
)

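Note on the hunk above: membership testing treats an explicit `null` in the snapshot as present, so the decoder would be handed `None`; checking the value covers both a missing key and an explicit null. A minimal illustration:

```python
snapshot = {"generation_strategy": None}  # e.g. serialized before a GS was set
assert "generation_strategy" in snapshot            # old check: would try to decode None
assert snapshot.get("generation_strategy") is None  # new check: falls through to None
```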
4 changes: 3 additions & 1 deletion ax/benchmark/benchmark.py
@@ -698,7 +698,9 @@ def compute_baseline_value_from_sobol(
n_repeats: Number of times to repeat the five Sobol trials.
"""
method = get_sobol_benchmark_method()
target_fidelity_and_task = {} if target_fidelity_and_task is None else {}
target_fidelity_and_task = (
{} if target_fidelity_and_task is None else target_fidelity_and_task
)

# set up a dummy problem so we can use `benchmark_replication`
# MOO problems are always higher-is-better because they use hypervolume
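Note on the hunk above: the old ternary evaluated to `{}` on *both* branches, silently discarding any `target_fidelity_and_task` passed by the caller:

```python
target_fidelity_and_task = {"fidelity": 1.0}  # illustrative value
old = {} if target_fidelity_and_task is None else {}                        # always {}
new = {} if target_fidelity_and_task is None else target_fidelity_and_task  # preserved
assert old == {} and new == {"fidelity": 1.0}
```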
2 changes: 1 addition & 1 deletion ax/core/arm.py
@@ -90,7 +90,7 @@ def md5hash(parameters: Mapping[str, TParamValue]) -> str:
new_parameters = {}
for k, v in parameters.items():
new_parameters[k] = numpy_type_to_python_type(v)
parameters_str = json.dumps(parameters, sort_keys=True)
parameters_str = json.dumps(new_parameters, sort_keys=True)
return hashlib.md5(parameters_str.encode("utf-8")).hexdigest()

def clone(self, clear_name: bool = False) -> "Arm":
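Note on the hunk above: the loop built `new_parameters` with numpy scalars converted to native Python types, but then hashed the *unconverted* dict, so parameterizations containing e.g. `np.int64` raised inside `json.dumps`. Roughly, using `int()` as a stand-in for `numpy_type_to_python_type`:

```python
import hashlib
import json

import numpy as np

parameters = {"x": np.int64(3)}
try:
    json.dumps(parameters, sort_keys=True)
except TypeError:
    pass  # np.int64 is not JSON-serializable

new_parameters = {k: int(v) for k, v in parameters.items()}
digest = hashlib.md5(
    json.dumps(new_parameters, sort_keys=True).encode("utf-8")
).hexdigest()
```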
2 changes: 1 addition & 1 deletion ax/core/experiment.py
@@ -727,7 +727,7 @@ def bulk_configure_metrics_of_class(
UserInputError(
f"Metric class {metric_class} does not contain the requested "
"attributes to update. Requested updates to attributes: "
f"{set(attributes_to_update.keys())} but metric class defines"
f"{set(attributes_to_update.keys())} but metric class defines "
f"{metric_attributes}."
)
)
8 changes: 8 additions & 0 deletions ax/core/generator_run.py
@@ -343,6 +343,14 @@ def clone(self) -> GeneratorRun:
if self._generator_state_after_gen is not None
else None
)
generator_run._gen_metadata = (
self._gen_metadata.copy() if self._gen_metadata is not None else None
)
generator_run._candidate_metadata_by_arm_signature = (
self._candidate_metadata_by_arm_signature.copy()
if self._candidate_metadata_by_arm_signature is not None
else None
)
return generator_run

def add_arm(self, arm: Arm, weight: float = 1.0) -> None:
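Note on the hunk above: copying the metadata dicts keeps the clone independent of the original; without the `.copy()`, both generator runs would share one dict and mutations on either side would leak to the other:

```python
original_meta = {"seed": 7}
aliased = original_meta        # pre-fix behavior: clone shares the reference
copied = original_meta.copy()  # post-fix behavior: independent (shallow) copy
aliased["seed"] = 8
assert original_meta["seed"] == 8 and copied["seed"] == 7
```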
2 changes: 1 addition & 1 deletion ax/core/metric.py
@@ -165,7 +165,7 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta:
return timedelta(0)

@classmethod
def is_reconverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
def is_recoverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
"""Checks whether the given MetricFetchE is recoverable for this metric class
in ``orchestrator._fetch_and_process_trials_data_results``.
"""
6 changes: 3 additions & 3 deletions ax/core/outcome_constraint.py
@@ -28,7 +28,7 @@

CONSTRAINT_THRESHOLD_WARNING_MESSAGE: str = (
"Constraint threshold on {name} appears invalid: {bound} bound on metric "
+ "for which {is_better} values is better."
+ "for which {is_better} values are better."
)
UPPER_BOUND_THRESHOLD: dict[str, str] = {"bound": "Positive", "is_better": "lower"}
LOWER_BOUND_THRESHOLD: dict[str, str] = {"bound": "Negative", "is_better": "higher"}
@@ -120,7 +120,7 @@ def _validate_metric_constraint_op(
if op == ComparisonOp.LEQ and not metric.lower_is_better:
fmt_data = UPPER_BOUND_MISMATCH
if fmt_data is not None:
fmt_data["name"] = metric.name
fmt_data = {**fmt_data, "name": metric.name}
msg = CONSTRAINT_WARNING_MESSAGE.format(**fmt_data)
logger.debug(msg)
return False, msg
@@ -156,7 +156,7 @@ def _validate_constraint(self) -> tuple[bool, str]:
if self.bound > 0 and not self.metric.lower_is_better:
fmt_data = LOWER_BOUND_THRESHOLD
if fmt_data is not None:
fmt_data["name"] = self.metric.name
fmt_data = {**fmt_data, "name": self.metric.name}
msg += CONSTRAINT_THRESHOLD_WARNING_MESSAGE.format(**fmt_data)
logger.debug(msg)
return False, msg
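Note on the hunks above: `fmt_data` aliased a module-level dict (`UPPER_BOUND_MISMATCH` and friends), so assigning `fmt_data["name"]` mutated shared state and the last metric name leaked into every later call; spreading into a fresh dict leaves the constant untouched. A minimal illustration (dict values here are illustrative):

```python
UPPER_BOUND_MISMATCH = {"bound": "upper", "is_better": "lower"}

def old_style(metric_name):
    fmt_data = UPPER_BOUND_MISMATCH
    fmt_data["name"] = metric_name  # mutates the module-level constant
    return fmt_data

def new_style(metric_name):
    return {**UPPER_BOUND_MISMATCH, "name": metric_name}  # fresh dict per call

new_style("foo")
assert "name" not in UPPER_BOUND_MISMATCH  # shared dict untouched
old_style("foo")
assert UPPER_BOUND_MISMATCH["name"] == "foo"  # state now leaks across calls
```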
4 changes: 2 additions & 2 deletions ax/core/parameter.py
@@ -423,13 +423,13 @@ def _validate_range_param(
ParameterType.FLOAT,
):
raise UserInputError(
f"RangeParameter {self.name}type must be int or float."
f"RangeParameter {self.name} type must be int or float."
)

upper = float(upper)
if lower >= upper:
raise UserInputError(
f"Upper bound of {self.name} must be strictly larger than lower."
f"Upper bound of {self.name} must be strictly larger than lower. "
f"Got: ({lower}, {upper})."
)
width: float = upper - lower
2 changes: 1 addition & 1 deletion ax/core/parameter_constraint.py
@@ -129,7 +129,7 @@ def validate_constraint_parameters(parameters: Sequence[Parameter]) -> None:
for parameter in parameters:
if not isinstance(parameter, RangeParameter):
raise ValueError(
"All parameters in a parameter constraint must be RangeParameters."
"All parameters in a parameter constraint must be RangeParameters. "
f"Found {parameter}"
)

8 changes: 4 additions & 4 deletions ax/core/tests/test_outcome_constraint.py
@@ -63,14 +63,14 @@ def test_OutcomeConstraintFail(self) -> None:
metric=self.minimize_metric, op=ComparisonOp.GEQ, bound=self.bound
)
mock_warning.debug.assert_called_once_with(
CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH)
CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH, name="bar")
)
with mock.patch(logger_name) as mock_warning:
OutcomeConstraint(
metric=self.maximize_metric, op=ComparisonOp.LEQ, bound=self.bound
)
mock_warning.debug.assert_called_once_with(
CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH)
CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH, name="baz")
)

def test_Sortable(self) -> None:
@@ -144,14 +144,14 @@ def test_ObjectiveThresholdFail(self) -> None:
metric=self.minimize_metric, op=ComparisonOp.GEQ, bound=self.bound
)
mock_warning.debug.assert_called_once_with(
CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH)
CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH, name="bar")
)
with mock.patch(logger_name) as mock_warning:
ObjectiveThreshold(
metric=self.maximize_metric, op=ComparisonOp.LEQ, bound=self.bound
)
mock_warning.debug.assert_called_once_with(
CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH)
CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH, name="baz")
)

def test_Relativize(self) -> None:
18 changes: 1 addition & 17 deletions ax/early_stopping/strategies/base.py
@@ -8,7 +8,7 @@

import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from collections.abc import Iterable
from logging import Logger
from typing import cast

@@ -665,19 +665,3 @@ def _lookup_and_validate_data(
full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling]
map_data = Data(df=full_df)
return map_data

def get_training_data(
self,
experiment: Experiment,
map_data: Data,
max_training_size: int | None = None,
outcomes: Sequence[str] | None = None,
parameters: list[str] | None = None,
) -> None:
# Deprecated in Ax 1.1.0, so should be removed in Ax 1.2.0+.
raise DeprecationWarning(
"`ModelBasedEarlyStoppingStrategy.get_training_data` is deprecated. "
"Subclasses should either extract the training data manually, "
"or rely on the fitted surrogates available in the current generation "
"node that is passed into `should_stop_trials_early`."
)
12 changes: 8 additions & 4 deletions ax/generation_strategy/dispatch_utils.py
@@ -58,7 +58,9 @@ def _make_sobol_step(
generator=Generators.SOBOL,
num_trials=num_trials,
# NOTE: ceil(-1 / 2) = 0, so this is safe to do when num trials is -1.
min_trials_observed=min_trials_observed or ceil(num_trials / 2),
min_trials_observed=(
ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed
),
enforce_num_trials=enforce_num_trials,
max_parallelism=max_parallelism,
generator_kwargs={"deduplicate": True, "seed": seed},
@@ -124,7 +126,9 @@ def _make_botorch_step(
generator=generator,
num_trials=num_trials,
# NOTE: ceil(-1 / 2) = 0, so this is safe to do when num trials is -1.
min_trials_observed=min_trials_observed or ceil(num_trials / 2),
min_trials_observed=(
ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed
),
enforce_num_trials=enforce_num_trials,
max_parallelism=max_parallelism,
generator_kwargs=generator_kwargs,
@@ -432,13 +436,13 @@ def choose_generation_strategy_legacy(

if not force_random_search and suggested_model is not None:
if not enforce_sequential_optimization and (
max_parallelism_override or max_parallelism_cap
max_parallelism_override is not None or max_parallelism_cap is not None
):
logger.info(
"If `enforce_sequential_optimization` is False, max parallelism is "
"not enforced and other max parallelism settings will be ignored."
)
if max_parallelism_override and max_parallelism_cap:
if max_parallelism_override is not None and max_parallelism_cap is not None:
raise ValueError(
"If `max_parallelism_override` specified, cannot also apply "
"`max_parallelism_cap`."
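Note on the hunks above: all three changes address the same falsy-value pitfall, since `x or default` and bare truthiness checks treat an explicit `0` the same as `None`, so a caller passing `min_trials_observed=0` (or `max_parallelism_override=0`) was silently overridden:

```python
from math import ceil

num_trials = 5
min_trials_observed = 0  # caller explicitly wants no observation requirement

old = min_trials_observed or ceil(num_trials / 2)  # 3: the explicit 0 is discarded
new = (
    ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed
)  # 0: honored
assert old == 3 and new == 0
```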
2 changes: 1 addition & 1 deletion ax/generation_strategy/external_generation_node.py
@@ -196,7 +196,7 @@ def _gen(
if pending_observations:
for obs in pending_observations.values():
for o in obs:
if o not in pending_parameters:
if o.parameters not in pending_parameters:
pending_parameters.append(o.parameters)
generated_params: list[TParameterization] = []
for _ in range(n):
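Note on the hunk above: the old membership test compared an `ObservationFeatures` object against a list of parameter dicts, so it never matched and every pending observation was appended, duplicates included. A stand-in illustration:

```python
class Obs:  # minimal stand-in for ObservationFeatures
    def __init__(self, parameters):
        self.parameters = parameters

pending_parameters = [{"x": 1.0}]
o = Obs({"x": 1.0})
assert o not in pending_parameters          # old check: never matches, always appends
assert o.parameters in pending_parameters   # new check: deduplicates correctly
```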
8 changes: 4 additions & 4 deletions ax/generation_strategy/generation_strategy.py
@@ -240,9 +240,9 @@ def gen_single_trial(
# list of GeneratorRuns for the first (and only) trial.
if len(grs_for_trials) != 1 or len(grs := grs_for_trials[0]) != 1:
raise AxError( # Unexpected state of the GS; raise informatively.
"By calling into GenerationStrategy.gen_single_trial(), you are should"
" be expecting a single `Trial` with only one `GeneratorRun`. However,"
"the underlying GenerationStrategy returned the following list "
"By calling into GenerationStrategy.gen_single_trial(), you should "
"be expecting a single `Trial` with only one `GeneratorRun`. However, "
"the underlying GenerationStrategy returned the following list"
f" of `GeneratorRun`-s: {grs_for_trials}."
)
return grs[0]
@@ -543,7 +543,7 @@ def _gen_with_multiple_nodes(
)
node_to_gen_from = self.nodes_by_name[node_to_gen_from_name]
if should_transition:
node_to_gen_from._previous_node_name = node_to_gen_from_name
node_to_gen_from._previous_node_name = self._curr.name
# reset should skip as conditions may have changed, do not reset
# until now so node properties can be as up to date as possible
node_to_gen_from._should_skip = False
Expand Down