diff --git a/ax/adapter/adapter_utils.py b/ax/adapter/adapter_utils.py index 24a1fa52475..db7c00b64c5 100644 --- a/ax/adapter/adapter_utils.py +++ b/ax/adapter/adapter_utils.py @@ -228,7 +228,7 @@ def extract_objective_thresholds( # Check that all thresholds correspond to a metric. if set(objective_threshold_dict.keys()).difference(set(objective.metric_names)): raise ValueError( - "Some objective thresholds do not have corresponding metrics." + "Some objective thresholds do not have corresponding metrics. " f"Got {objective_thresholds=} and {objective=}." ) @@ -566,12 +566,10 @@ def get_pareto_frontier_and_configs( "`observation_data` will not be used.", stacklevel=2, ) - else: - if observation_data is None: - raise ValueError( - "`observation_data` must not be None when `use_model_predictions` is " - "True." - ) + elif observation_data is None: + raise ValueError( + "`observation_data` must not be None when `use_model_predictions` is False." + ) array_to_tensor = adapter._array_to_tensor if use_model_predictions: diff --git a/ax/adapter/torch.py b/ax/adapter/torch.py index 2c73f9971d6..a1a7bf6a2e3 100644 --- a/ax/adapter/torch.py +++ b/ax/adapter/torch.py @@ -693,7 +693,7 @@ def _update_w_aux_exp_datasets( if pe_data.df.empty: raise DataRequiredError( "No data found in the auxiliary preference exploration " - "experiment. Play the preference game first or use another" + "experiment. Play the preference game first or use another " "preference profile with recorded preference data." ) diff --git a/ax/adapter/transforms/int_range_to_choice.py b/ax/adapter/transforms/int_range_to_choice.py index 2143ce75f81..6cb3fd4fb49 100644 --- a/ax/adapter/transforms/int_range_to_choice.py +++ b/ax/adapter/transforms/int_range_to_choice.py @@ -67,18 +67,13 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: and p.cardinality() <= self.max_choices ): values = list(range(int(p.lower), int(p.upper) + 1)) - target_value = ( - None - if p.target_value is None - else next(i for i, v in enumerate(values) if v == p.target_value) - ) transformed_parameters[p_name] = ChoiceParameter( name=p_name, parameter_type=p.parameter_type, values=values, # pyre-fixme[6] is_ordered=True, is_fidelity=p.is_fidelity, - target_value=target_value, + target_value=p.target_value, ) else: transformed_parameters[p.name] = p diff --git a/ax/adapter/transforms/relativize.py b/ax/adapter/transforms/relativize.py index 4e108fa6435..d7d4c041f42 100644 --- a/ax/adapter/transforms/relativize.py +++ b/ax/adapter/transforms/relativize.py @@ -335,7 +335,7 @@ def get_metric_index(data: ObservationData, metric_signature: str) -> int: """Get the index of a metric in the ObservationData.""" try: return data.metric_signatures.index(metric_signature) - except (IndexError, StopIteration): + except ValueError: raise ValueError( "Relativization cannot be performed because " "ObservationData for status quo is missing metrics" diff --git a/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py b/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py index cf093ce2052..3c28d96bd45 100644 --- a/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py +++ b/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py @@ -13,6 +13,7 @@ from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance class IntRangeToChoiceTransformTest(TestCase): @@ 
-20,7 +21,13 @@ def setUp(self) -> None: super().setUp() self.search_space = SearchSpace( parameters=[ - RangeParameter("a", lower=1, upper=5, parameter_type=ParameterType.INT), + RangeParameter( + "a", + lower=1, + upper=5, + parameter_type=ParameterType.INT, + target_value=2, + ), ChoiceParameter( "b", parameter_type=ParameterType.STRING, values=["a", "b", "c"] ), @@ -43,9 +50,9 @@ def test_TransformObservationFeatures(self) -> None: def test_TransformSearchSpace(self) -> None: ss2 = deepcopy(self.search_space) ss2 = self.t.transform_search_space(ss2) - self.assertTrue(isinstance(ss2.parameters["a"], ChoiceParameter)) - # pyre-fixme[16]: `Parameter` has no attribute `values`. - self.assertTrue(ss2.parameters["a"].values, [1, 2, 3, 4, 5]) + new_a = assert_is_instance(ss2.parameters["a"], ChoiceParameter) + self.assertEqual(new_a.values, [1, 2, 3, 4, 5]) + self.assertEqual(new_a.target_value, 2) def test_num_choices(self) -> None: parameters = { diff --git a/ax/analysis/healthcheck/regression_detection_utils.py b/ax/analysis/healthcheck/regression_detection_utils.py index e2cc51f77d2..4db141d8807 100644 --- a/ax/analysis/healthcheck/regression_detection_utils.py +++ b/ax/analysis/healthcheck/regression_detection_utils.py @@ -93,7 +93,7 @@ def compute_regression_probabilities_single_trial( if len(metric_names) == 0: raise ValueError( - "No common metrics between the provided data and the size thresholds." + "No common metrics between the provided data and the size thresholds. " "Need to provide both data and the size thresholds for metrics of interest." ) @@ -159,7 +159,7 @@ def detect_regressions_single_trial( if len(metric_names) == 0: raise ValueError( - "No common metrics between the provided data and the thresholds." + "No common metrics between the provided data and the thresholds. " "Need to provide both data and the size thresholds for metrics of interest." ) diff --git a/ax/analysis/plotly/arm_effects.py b/ax/analysis/plotly/arm_effects.py index cd1a354a123..d1bf800959a 100644 --- a/ax/analysis/plotly/arm_effects.py +++ b/ax/analysis/plotly/arm_effects.py @@ -523,7 +523,7 @@ def _get_subtitle( f"{'Modeled' if use_model_predictions else 'Observed'} effects on " f"{metric_label}" ) - trial_clause = f" for Trial {trial_index}." if trial_index is not None else "" + trial_clause = f" for Trial {trial_index}" if trial_index is not None else "" first_sentence = f"{first_clause}{trial_clause}." if use_model_predictions: diff --git a/ax/analysis/plotly/scatter.py b/ax/analysis/plotly/scatter.py index e996a959e68..eb3903013ce 100644 --- a/ax/analysis/plotly/scatter.py +++ b/ax/analysis/plotly/scatter.py @@ -525,7 +525,7 @@ def _prepare_figure( if "status_quo" in df["arm_name"].values: x = df[df["arm_name"] == "status_quo"][f"{x_metric_name}_mean"].iloc[0] y = df[df["arm_name"] == "status_quo"][f"{y_metric_name}_mean"].iloc[0] - if not np.isnan(x) or not np.isnan(y): + if not np.isnan(x) and not np.isnan(y): figure.add_shape( type="line", yref="paper", diff --git a/ax/analysis/utils.py b/ax/analysis/utils.py index b2be7792686..194545d00c5 100--- a/ax/analysis/utils.py +++ b/ax/analysis/utils.py @@ -210,7 +210,7 @@ def prepare_arm_data( else: if additional_arms is not None: raise UserInputError( - "Cannot provide additional arms when use_model_predictions=False since" + "Cannot provide additional arms when use_model_predictions=False since " "there is no observed raw data for the additional arms that are not " "part of the Experiment."
) diff --git a/ax/api/client.py b/ax/api/client.py index 28c012971ed..5f9a0a1ab14 100644 --- a/ax/api/client.py +++ b/ax/api/client.py @@ -1350,7 +1350,7 @@ def _from_json_snapshot( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) - if "generation_strategy" in snapshot + if snapshot.get("generation_strategy") is not None else None ) diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index a7b93ee82a7..99c9ee3d110 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -698,7 +698,9 @@ def compute_baseline_value_from_sobol( n_repeats: Number of times to repeat the five Sobol trials. """ method = get_sobol_benchmark_method() - target_fidelity_and_task = {} if target_fidelity_and_task is None else {} + target_fidelity_and_task = ( + {} if target_fidelity_and_task is None else target_fidelity_and_task + ) # set up a dummy problem so we can use `benchmark_replication` # MOO problems are always higher-is-better because they use hypervolume diff --git a/ax/core/arm.py b/ax/core/arm.py index af2fc40b924..ffc54976b2e 100644 --- a/ax/core/arm.py +++ b/ax/core/arm.py @@ -90,7 +90,7 @@ def md5hash(parameters: Mapping[str, TParamValue]) -> str: new_parameters = {} for k, v in parameters.items(): new_parameters[k] = numpy_type_to_python_type(v) - parameters_str = json.dumps(parameters, sort_keys=True) + parameters_str = json.dumps(new_parameters, sort_keys=True) return hashlib.md5(parameters_str.encode("utf-8")).hexdigest() def clone(self, clear_name: bool = False) -> "Arm": diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 3662eb51cd2..0f2755be912 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -727,7 +727,7 @@ def bulk_configure_metrics_of_class( UserInputError( f"Metric class {metric_class} does not contain the requested " "attributes to update. Requested updates to attributes: " - f"{set(attributes_to_update.keys())} but metric class defines" + f"{set(attributes_to_update.keys())} but metric class defines " f"{metric_attributes}." ) ) diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index b4b17290662..2235e2c8e9a 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -343,6 +343,14 @@ def clone(self) -> GeneratorRun: if self._generator_state_after_gen is not None else None ) + generator_run._gen_metadata = ( + self._gen_metadata.copy() if self._gen_metadata is not None else None + ) + generator_run._candidate_metadata_by_arm_signature = ( + self._candidate_metadata_by_arm_signature.copy() + if self._candidate_metadata_by_arm_signature is not None + else None + ) return generator_run def add_arm(self, arm: Arm, weight: float = 1.0) -> None: diff --git a/ax/core/metric.py b/ax/core/metric.py index db006e36823..c5f2289c191 100644 --- a/ax/core/metric.py +++ b/ax/core/metric.py @@ -165,7 +165,7 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta: return timedelta(0) @classmethod - def is_reconverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool: + def is_recoverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool: """Checks whether the given MetricFetchE is recoverable for this metric class in ``orchestrator._fetch_and_process_trials_data_results``. 
""" diff --git a/ax/core/outcome_constraint.py b/ax/core/outcome_constraint.py index f7f88386171..1d22616f3f4 100644 --- a/ax/core/outcome_constraint.py +++ b/ax/core/outcome_constraint.py @@ -28,7 +28,7 @@ CONSTRAINT_THRESHOLD_WARNING_MESSAGE: str = ( "Constraint threshold on {name} appears invalid: {bound} bound on metric " - + "for which {is_better} values is better." + + "for which {is_better} values are better." ) UPPER_BOUND_THRESHOLD: dict[str, str] = {"bound": "Positive", "is_better": "lower"} LOWER_BOUND_THRESHOLD: dict[str, str] = {"bound": "Negative", "is_better": "higher"} @@ -120,7 +120,7 @@ def _validate_metric_constraint_op( if op == ComparisonOp.LEQ and not metric.lower_is_better: fmt_data = UPPER_BOUND_MISMATCH if fmt_data is not None: - fmt_data["name"] = metric.name + fmt_data = {**fmt_data, "name": metric.name} msg = CONSTRAINT_WARNING_MESSAGE.format(**fmt_data) logger.debug(msg) return False, msg @@ -156,7 +156,7 @@ def _validate_constraint(self) -> tuple[bool, str]: if self.bound > 0 and not self.metric.lower_is_better: fmt_data = LOWER_BOUND_THRESHOLD if fmt_data is not None: - fmt_data["name"] = self.metric.name + fmt_data = {**fmt_data, "name": self.metric.name} msg += CONSTRAINT_THRESHOLD_WARNING_MESSAGE.format(**fmt_data) logger.debug(msg) return False, msg diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 7591c525a35..e396c792677 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -423,13 +423,13 @@ def _validate_range_param( ParameterType.FLOAT, ): raise UserInputError( - f"RangeParameter {self.name}type must be int or float." + f"RangeParameter {self.name} type must be int or float." ) upper = float(upper) if lower >= upper: raise UserInputError( - f"Upper bound of {self.name} must be strictly larger than lower." + f"Upper bound of {self.name} must be strictly larger than lower. " f"Got: ({lower}, {upper})." ) width: float = upper - lower diff --git a/ax/core/parameter_constraint.py b/ax/core/parameter_constraint.py index ef1621ab36d..3eabdc6db64 100644 --- a/ax/core/parameter_constraint.py +++ b/ax/core/parameter_constraint.py @@ -129,7 +129,7 @@ def validate_constraint_parameters(parameters: Sequence[Parameter]) -> None: for parameter in parameters: if not isinstance(parameter, RangeParameter): raise ValueError( - "All parameters in a parameter constraint must be RangeParameters." + "All parameters in a parameter constraint must be RangeParameters. 
" f"Found {parameter}" ) diff --git a/ax/core/tests/test_outcome_constraint.py b/ax/core/tests/test_outcome_constraint.py index 58072e169e1..fcc6ed79cf5 100644 --- a/ax/core/tests/test_outcome_constraint.py +++ b/ax/core/tests/test_outcome_constraint.py @@ -63,14 +63,14 @@ def test_OutcomeConstraintFail(self) -> None: metric=self.minimize_metric, op=ComparisonOp.GEQ, bound=self.bound ) mock_warning.debug.assert_called_once_with( - CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH) + CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH, name="bar") ) with mock.patch(logger_name) as mock_warning: OutcomeConstraint( metric=self.maximize_metric, op=ComparisonOp.LEQ, bound=self.bound ) mock_warning.debug.assert_called_once_with( - CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH) + CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH, name="baz") ) def test_Sortable(self) -> None: @@ -144,14 +144,14 @@ def test_ObjectiveThresholdFail(self) -> None: metric=self.minimize_metric, op=ComparisonOp.GEQ, bound=self.bound ) mock_warning.debug.assert_called_once_with( - CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH) + CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH, name="bar") ) with mock.patch(logger_name) as mock_warning: ObjectiveThreshold( metric=self.maximize_metric, op=ComparisonOp.LEQ, bound=self.bound ) mock_warning.debug.assert_called_once_with( - CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH) + CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH, name="baz") ) def test_Relativize(self) -> None: diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py index 796a4d8e5f6..15c340fb72c 100644 --- a/ax/early_stopping/strategies/base.py +++ b/ax/early_stopping/strategies/base.py @@ -8,7 +8,7 @@ import logging from abc import ABC, abstractmethod -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from logging import Logger from typing import cast @@ -665,19 +665,3 @@ def _lookup_and_validate_data( full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling] map_data = Data(df=full_df) return map_data - - def get_training_data( - self, - experiment: Experiment, - map_data: Data, - max_training_size: int | None = None, - outcomes: Sequence[str] | None = None, - parameters: list[str] | None = None, - ) -> None: - # Deprecated in Ax 1.1.0, so should be removed in Ax 1.2.0+. - raise DeprecationWarning( - "`ModelBasedEarlyStoppingStrategy.get_training_data` is deprecated. " - "Subclasses should either extract the training data manually, " - "or rely on the fitted surrogates available in the current generation " - "node that is passed into `should_stop_trials_early`." - ) diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index ba5304a2ed0..f2721d47fc6 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -58,7 +58,9 @@ def _make_sobol_step( generator=Generators.SOBOL, num_trials=num_trials, # NOTE: ceil(-1 / 2) = 0, so this is safe to do when num trials is -1. 
- min_trials_observed=min_trials_observed or ceil(num_trials / 2), + min_trials_observed=( + ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed + ), enforce_num_trials=enforce_num_trials, max_parallelism=max_parallelism, generator_kwargs={"deduplicate": True, "seed": seed}, @@ -124,7 +126,9 @@ def _make_botorch_step( generator=generator, num_trials=num_trials, # NOTE: ceil(-1 / 2) = 0, so this is safe to do when num trials is -1. - min_trials_observed=min_trials_observed or ceil(num_trials / 2), + min_trials_observed=( + ceil(num_trials / 2) if min_trials_observed is None else min_trials_observed + ), enforce_num_trials=enforce_num_trials, max_parallelism=max_parallelism, generator_kwargs=generator_kwargs, @@ -432,13 +436,13 @@ def choose_generation_strategy_legacy( if not force_random_search and suggested_model is not None: if not enforce_sequential_optimization and ( - max_parallelism_override or max_parallelism_cap + max_parallelism_override is not None or max_parallelism_cap is not None ): logger.info( "If `enforce_sequential_optimization` is False, max parallelism is " "not enforced and other max parallelism settings will be ignored." ) - if max_parallelism_override and max_parallelism_cap: + if max_parallelism_override is not None and max_parallelism_cap is not None: raise ValueError( "If `max_parallelism_override` specified, cannot also apply " "`max_parallelism_cap`." diff --git a/ax/generation_strategy/external_generation_node.py b/ax/generation_strategy/external_generation_node.py index b5192a97b9d..043d1131cb2 100644 --- a/ax/generation_strategy/external_generation_node.py +++ b/ax/generation_strategy/external_generation_node.py @@ -196,7 +196,7 @@ def _gen( if pending_observations: for obs in pending_observations.values(): for o in obs: - if o not in pending_parameters: + if o.parameters not in pending_parameters: pending_parameters.append(o.parameters) generated_params: list[TParameterization] = [] for _ in range(n): diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index 2076605ec43..f44a5fdebcd 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -240,9 +240,9 @@ def gen_single_trial( # list of GeneratorRuns for the first (and only) trial. if len(grs_for_trials) != 1 or len(grs := grs_for_trials[0]) != 1: raise AxError( # Unexpected state of the GS; raise informatively. - "By calling into GenerationStrategy.gen_single_trial(), you are should" - " be expecting a single `Trial` with only one `GeneratorRun`. However," - "the underlying GenerationStrategy returned the following list " + "By calling into GenerationStrategy.gen_single_trial(), you should " + "be expecting a single `Trial` with only one `GeneratorRun`. However, " + "the underlying GenerationStrategy returned the following list" f" of `GeneratorRun`-s: {grs_for_trials}." 
) return grs[0] @@ -543,7 +543,7 @@ def _gen_with_multiple_nodes( ) node_to_gen_from = self.nodes_by_name[node_to_gen_from_name] if should_transition: - node_to_gen_from._previous_node_name = node_to_gen_from_name + node_to_gen_from._previous_node_name = self._curr.name # reset should skip as conditions may have changed, do not reset # until now so node properties can be as up to date as possible node_to_gen_from._should_skip = False diff --git a/ax/generation_strategy/transition_criterion.py b/ax/generation_strategy/transition_criterion.py index 7b4aceb6c46..363b6a96487 100644 --- a/ax/generation_strategy/transition_criterion.py +++ b/ax/generation_strategy/transition_criterion.py @@ -28,7 +28,7 @@ DATA_REQUIRED_MSG = ( "All trials for current node {node_name} have been generated, " "but not enough data has been observed to proceed to the next " - "Generation node. Try again when more is are available." + "Generation node. Try again when more data is available." ) diff --git a/ax/generators/torch/botorch_modular/utils.py b/ax/generators/torch/botorch_modular/utils.py index 02eaeff81d8..ed6b1135c5b 100644 --- a/ax/generators/torch/botorch_modular/utils.py +++ b/ax/generators/torch/botorch_modular/utils.py @@ -466,7 +466,7 @@ def construct_acquisition_and_optimizer_options( if len(botorch_acqf_classes_with_options) > 1: warnings.warn( message="botorch_acqf_options are being ignored, due to using " - "MultiAcquisition. Specify options for each acquistion function" + "MultiAcquisition. Specify options for each acquisition function " "via botorch_acqf_classes_with_options.", category=AxWarning, stacklevel=4, diff --git a/ax/generators/torch/tests/test_utils.py b/ax/generators/torch/tests/test_utils.py index c0a047dc0bc..11bb4f7d37e 100644 --- a/ax/generators/torch/tests/test_utils.py +++ b/ax/generators/torch/tests/test_utils.py @@ -465,7 +465,7 @@ def test_construct_acquisition_and_optimizer_options(self) -> None: self.assertEqual( str(warning.message), "botorch_acqf_options are being ignored, due to using " - "MultiAcquisition. Specify options for each acquistion function" + "MultiAcquisition. Specify options for each acquisition function " "via botorch_acqf_classes_with_options.", ) diff --git a/ax/global_stopping/strategies/improvement.py b/ax/global_stopping/strategies/improvement.py index 78529091155..777097f1781 100644 --- a/ax/global_stopping/strategies/improvement.py +++ b/ax/global_stopping/strategies/improvement.py @@ -130,8 +130,8 @@ def _should_stop_optimization( trial_to_check = max_completed_trial elif trial_to_check > max_completed_trial: raise ValueError( - "trial_to_check is larger than the total number of " - f"trials (={max_completed_trial})." + "trial_to_check is larger than the maximum completed " + f"trial index (={max_completed_trial})." ) # Only counting the trials up to trial_to_check. 
diff --git a/ax/global_stopping/tests/test_strategies.py b/ax/global_stopping/tests/test_strategies.py index 363bfe234d8..1b1d627f0b3 100644 --- a/ax/global_stopping/tests/test_strategies.py +++ b/ax/global_stopping/tests/test_strategies.py @@ -85,7 +85,7 @@ def test_base_cases(self) -> None: # Should raise ValueError if trying to check an invalid trial with self.assertRaisesRegex( ValueError, - r"trial_to_check is larger than the total number of trials \(=4\).", + r"trial_to_check is larger than the maximum completed trial index \(=4\).", ): stop, message = gss.should_stop_optimization( experiment=exp, trial_to_check=5 diff --git a/ax/metrics/branin_map.py b/ax/metrics/branin_map.py index afa9be9e782..a103b556a55 100644 --- a/ax/metrics/branin_map.py +++ b/ax/metrics/branin_map.py @@ -117,9 +117,8 @@ def fetch_trial_data( "sem": self.noise_sd if noisy else 0.0, "trial_index": trial.index, "mean": [ - item["mean"] + self.noise_sd * np.random.randn() - if noisy - else 0.0 + item["mean"] + + (self.noise_sd * np.random.randn() if noisy else 0.0) for item in res ], "metric_signature": self.signature, diff --git a/ax/metrics/noisy_function_map.py b/ax/metrics/noisy_function_map.py index e89e7eba9ed..1bcb37d164d 100644 --- a/ax/metrics/noisy_function_map.py +++ b/ax/metrics/noisy_function_map.py @@ -84,9 +84,8 @@ def fetch_trial_data( "sem": self.noise_sd if noisy else 0.0, "trial_index": trial.index, "mean": [ - item["mean"] + self.noise_sd * np.random.randn() - if noisy - else 0.0 + item["mean"] + + (self.noise_sd * np.random.randn() if noisy else 0.0) for item in res ], "metric_signature": self.signature, diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index 8d4804893b7..a35cfe8aa51 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -1227,7 +1227,7 @@ def _check_exit_status_and_report_results( if not self.options.wait_for_running_trials: return None return self.wait_for_completed_trials_and_report_results( - idle_callback, force_refit=True + idle_callback, force_refit=force_refit ) def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool: @@ -2093,9 +2093,9 @@ def _fetch_and_process_trials_data_results( if ( optimization_config is not None and metric_name in optimization_config.metrics.keys() - and not self.experiment.metrics[ - metric_name - ].is_reconverable_fetch_e(metric_fetch_e=metric_fetch_e) + and not self.experiment.metrics[metric_name].is_recoverable_fetch_e( + metric_fetch_e=metric_fetch_e + ) ): status = self._mark_err_trial_status( trial=self.experiment.trials[trial_index], diff --git a/ax/plot/pareto_utils.py b/ax/plot/pareto_utils.py index a11af3c67dc..d9dce954597 100644 --- a/ax/plot/pareto_utils.py +++ b/ax/plot/pareto_utils.py @@ -387,7 +387,7 @@ def compute_posterior_pareto_frontier( try: data = ( experiment.trials[trial_index].fetch_data() - if trial_index + if trial_index is not None else experiment.fetch_data() ) except Exception as e: diff --git a/ax/plot/scatter.py b/ax/plot/scatter.py index 27af4eb3c97..661be2354ae 100644 --- a/ax/plot/scatter.py +++ b/ax/plot/scatter.py @@ -520,12 +520,12 @@ def plot_multiple_metrics( }, ], xaxis={ - "title": metric_x + (" (%)" if rel else ""), + "title": metric_x + (" (%)" if rel_x else ""), "zeroline": True, "zerolinecolor": "red", }, yaxis={ - "title": metric_y + (" (%)" if rel else ""), + "title": metric_y + (" (%)" if rel_y else ""), "zeroline": True, "zerolinecolor": "red", }, diff --git a/ax/runners/map_replay.py 
b/ax/runners/map_replay.py index 1b192c3d715..59ada1248c5 100644 --- a/ax/runners/map_replay.py +++ b/ax/runners/map_replay.py @@ -41,7 +41,7 @@ def poll_trial_status( # depending on whether or not there is more data available, # mark it either RUNNING or COMPLETED. for t in trials: - if not t.run_metadata.get(STARTED_KEY, "False"): + if not t.run_metadata.get(STARTED_KEY, False): result[TrialStatus.CANDIDATE].add(t.index) elif not self.replay_metric.has_trial_data(t.index): result[TrialStatus.ABANDONED].add(t.index) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 592c316f037..cb49ed1e39c 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -877,7 +877,12 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: num_trials = node.num_trials except UserInputError: num_trials = -1 - parallelism_settings.append((num_trials, max_parallelism or num_trials)) + parallelism_settings.append( + ( + num_trials, + max_parallelism if max_parallelism is not None else num_trials, + ) + ) return parallelism_settings def get_optimization_trace( @@ -1044,7 +1049,7 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig: raise ValueError( "Could not obtain feature_importances for any metrics " - " as a model that can produce feature importances, such as a " + "as a model that can produce feature importances, such as a " "Gaussian Process, has not yet been trained in the course " "of this optimization." ) @@ -1684,7 +1689,7 @@ def _validate_early_stopping_strategy( if self._early_stopping_strategy is not None and not support_intermediate_data: raise ValueError( "Early stopping is only supported for experiments which allow " - " reporting intermediate trial data by setting passing " + "reporting intermediate trial data by passing " "`support_intermediate_data=True`." ) diff --git a/ax/service/managed_loop.py b/ax/service/managed_loop.py index 114eaa4a026..9225d3c673d 100644 --- a/ax/service/managed_loop.py +++ b/ax/service/managed_loop.py @@ -232,7 +232,7 @@ def full_run(self) -> OptimizationLoop: self.run_trial() except SearchSpaceExhausted as err: logger.info( - f"Stopped optimization as the search space is exhaused. Message " + f"Stopped optimization as the search space is exhausted. Message " f"from generation strategy: {err}." 
) return self diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 1670c83d783..3636adc65db 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -49,7 +49,6 @@ from ax.utils.common.logger import get_logger from ax.utils.preference.preference_utils import get_preference_adapter from botorch.utils.multi_objective.box_decompositions import DominatedPartitioning -from numpy import nan from numpy.typing import NDArray from pyre_extensions import assert_is_instance, none_throws @@ -646,7 +645,7 @@ def _is_all_noiseless(df: pd.DataFrame, metric_name: str) -> bool: name_mask = df["metric_name"] == metric_name df_metric_arms_sems = df[name_mask]["sem"] - return ((df_metric_arms_sems == 0) | df_metric_arms_sems == nan).all() + return ((df_metric_arms_sems == 0) | df_metric_arms_sems.isna()).all() def get_values_of_outcomes_single_or_scalarized_objective( diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index d08849d6f0d..4f35e10deda 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -354,7 +354,7 @@ def parameter_from_json( unexpected_keys = set(representation.keys()) - EXPECTED_KEYS_IN_PARAM_REPR if unexpected_keys: raise ValueError( - f"Unexpected keys {unexpected_keys} in parameter representation." + f"Unexpected keys {unexpected_keys} in parameter representation. " f"Exhaustive set of expected keys: {EXPECTED_KEYS_IN_PARAM_REPR}." ) name = representation["name"] @@ -752,7 +752,7 @@ def make_search_space( logger.debug(f"Created search space: {ss}.") if ss.is_hierarchical: logger.debug( - "Hieararchical structure of the search space: \n" + "Hierarchical structure of the search space: \n" f"{ss.hierarchical_structure_str(parameter_names_only=True)}" ) @@ -1117,7 +1117,7 @@ def _process_monomial(monomial_str: str) -> tuple[float, str]: multiplier = 1.0 else: raise ValueError( - "Monomial format does not match `multiplier*parameter_name`." + "Monomial format does not match `multiplier*parameter_name`. " f"Got `{monomial_str}`." ) return multiplier, parameter diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 4ddabe4c211..e5242ff8ac8 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -365,7 +365,7 @@ def get_standard_plots( logger.debug("Finished global sensitivity analysis.") except Exception as e: logger.debug( - f"Failed to compute signed global feature sensitivities: {e}" + f"Failed to compute signed global feature sensitivities: {e}. " "Trying to get unsigned feature sensitivities." ) try: @@ -1121,7 +1121,7 @@ def _update_fig_in_place(orchestrator: Orchestrator) -> None: new_fig = plot_fn(orchestrator) except RuntimeError as e: logging.warning( - f"Plotting function called via callback failed with error {e}." + f"Plotting function called via callback failed with error {e}. " "Skipping plot update." ) return diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..e31e63a5cb3 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -1089,10 +1089,10 @@ def generator_spec_from_json( kwargs = generator_spec_json.pop("model_kwargs", None) else: kwargs = generator_spec_json.pop("generator_kwargs", None) - for k in _DEPRECATED_GENERATOR_KWARGS: - # Remove deprecated model kwargs. - kwargs.pop(k, None) if kwargs is not None: + for k in _DEPRECATED_GENERATOR_KWARGS: + # Remove deprecated model kwargs. 
+ kwargs.pop(k, None) kwargs = _sanitize_surrogate_spec_input(object_json=kwargs) if "model_gen_kwargs" in generator_spec_json: gen_kwargs = generator_spec_json.pop("model_gen_kwargs", None) diff --git a/ax/storage/sqa_store/db.py b/ax/storage/sqa_store/db.py index 19d883bfd37..9da5d98e4e2 100644 --- a/ax/storage/sqa_store/db.py +++ b/ax/storage/sqa_store/db.py @@ -218,8 +218,8 @@ def create_all_tables(engine: Engine) -> None: """ if engine.dialect.name == "mysql" and engine.dialect.default_schema_name == "ax": raise ValueError( - "The open-source Ax table creation is likely not applicable in this case," - + "please contact the Adaptive Experimentation team if you need help." + "The open-source Ax table creation is likely not applicable in this case, " + "please contact the Adaptive Experimentation team if you need help." ) Base.metadata.create_all(engine) diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index b96877a228f..56fbb1ddb5f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -128,7 +128,7 @@ def validate_experiment_metadata( raise ValueError( f"An experiment already exists with the name {experiment.name}. " "If you need to override this existing experiment, first delete it " - "via `delete_experiment` in ax/ax/storage/sqa_store/delete.py, " + "via `delete_experiment` in ax/storage/sqa_store/delete.py, " "and then resave." ) @@ -391,14 +391,15 @@ def get_metric_type_and_properties( json blob. """ metric_class = type(metric) - metric_type = int(self.config.metric_registry.get(metric_class)) - if metric_type is None: + metric_type_or_none = self.config.metric_registry.get(metric_class) + if metric_type_or_none is None: raise SQAEncodeError( "Cannot encode metric to SQLAlchemy because metric's " f"subclass ({metric_class}) is missing from the registry. " "The metric registry currently contains the following: " f"{','.join(map(str, self.config.metric_registry.keys()))} " ) + metric_type = int(metric_type_or_none) properties = metric_class.serialize_init_args(obj=metric) return metric_type, object_to_json( @@ -646,13 +647,13 @@ def outcome_constraint_to_sqa( def scalarized_outcome_constraint_to_sqa( self, outcome_constraint: ScalarizedOutcomeConstraint ) -> SQAMetric: - """Convert Ax SCalarized OutcomeConstraint to SQLAlchemy.""" + """Convert Ax Scalarized OutcomeConstraint to SQLAlchemy.""" metrics, weights = outcome_constraint.metrics, outcome_constraint.weights if metrics is None or weights is None or len(metrics) != len(weights): raise SQAEncodeError( - "Metrics and weights in scalarized OutcomeConstraint \ - must be lists of equal length." + "Metrics and weights in scalarized OutcomeConstraint " + "must be lists of equal length." 
) metrics_by_name = self.get_children_metrics_by_name( diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index bb2f6e65b88..30c19732f30 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -757,7 +757,7 @@ def _query_historical_experiments_given_parameters( ) ) .filter(SQAExperiment.is_test == False) # noqa E712 `is` won't work for SQA - .filter(SQAExperiment.id is not None) + .filter(SQAExperiment.id.isnot(None)) # Experiments with some data .join(SQAData, SQAParameter.experiment_id == SQAData.experiment_id) ) diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index cd3f6ba2eb3..846150ddb8d 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -242,7 +242,9 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: objs=trials_to_reduce_state, encode_func=trial_to_reduced_state_sqa_encoder, decode_func=decoder.trial_from_sqa, - decode_args_list=[{"experiment": experiment} for _ in range(len(trials))], + decode_args_list=[ + {"experiment": experiment} for _ in range(len(trials_to_reduce_state)) + ], modify_sqa=add_experiment_id, batch_size=batch_size, ) @@ -251,7 +253,7 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: objs=[latest_trial], encode_func=encoder.trial_to_sqa, decode_func=decoder.trial_from_sqa, - decode_args_list=[{"experiment": experiment} for _ in range(len(trials))], + decode_args_list=[{"experiment": experiment}], modify_sqa=add_experiment_id, batch_size=batch_size, ) diff --git a/ax/storage/sqa_store/utils.py b/ax/storage/sqa_store/utils.py index 3fc233ff15e..5d3d89cbf7a 100644 --- a/ax/storage/sqa_store/utils.py +++ b/ax/storage/sqa_store/utils.py @@ -46,7 +46,9 @@ "_metric_fetching_errors", "_data_rows", } -SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate." +SKIP_ATTRS_ERROR_SUFFIX = ( + " Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate." +) def is_foreign_key_field(field: str) -> bool: @@ -74,7 +76,7 @@ def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None if len(path) > 15: # This shouldn't happen, but is a precaution against accidentally # introducing infinite loops - raise SQADecodeError(error_message_prefix + "Encountered path of length > 10.") + raise SQADecodeError(error_message_prefix + "Encountered path of length > 15.") if type(source) is not type(target): if not issubclass(type(target), type(source)): @@ -111,7 +113,7 @@ def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None source_json = getattr(source, attr) target_json = getattr(target, attr) if source_json != target_json: - SQADecodeError( + raise SQADecodeError( error_message_prefix + f"Json attribute {attr} not matching " f"between source: {source_json} and target: {target_json}." ) diff --git a/ax/storage/sqa_store/with_db_settings_base.py b/ax/storage/sqa_store/with_db_settings_base.py index 0defee1c634..54de2bd8486 100644 --- a/ax/storage/sqa_store/with_db_settings_base.py +++ b/ax/storage/sqa_store/with_db_settings_base.py @@ -105,8 +105,9 @@ def __init__( if db_settings and (not DBSettings or not isinstance(db_settings, DBSettings)): raise ValueError( "`db_settings` argument should be of type ax.storage.sqa_store." - f"(Got: {db_settings} of type {type(db_settings)}. " - "structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy " + "structs.DBSettings. " + f"(Got: {db_settings} of type {type(db_settings)}). 
" + "To use `DBSettings`, you will need SQLAlchemy " "installed in your environment (can be installed through pip)." ) self._db_settings = db_settings or self._get_default_db_settings() diff --git a/ax/utils/common/equality.py b/ax/utils/common/equality.py index a45e753795b..8a6094e77f5 100644 --- a/ax/utils/common/equality.py +++ b/ax/utils/common/equality.py @@ -104,7 +104,7 @@ def datetime_equals(dt1: datetime | None, dt2: datetime | None) -> bool: return True if not (dt1 and dt2): return False - return (dt1 - dt2).total_seconds() < 1.0 + return abs((dt1 - dt2).total_seconds()) < 1.0 def dataframe_equals(df1: pd.DataFrame, df2: pd.DataFrame) -> bool: @@ -130,8 +130,8 @@ def object_attribute_dicts_equal( are the same. - NOTE: Special-cases some Ax object attributes, like "_experiment" or - "_model", where full equality is hard to check. + NOTE: Special-cases some Ax object attributes, like "_experiment", + where full equality is hard to check. Args: one_dict: First object's attribute dict (``obj.__dict__``). @@ -210,22 +210,6 @@ def object_attribute_dicts_find_unequal_fields( equal = one_val is other_val is None or (one_val.db_id == other_val.db_id) elif field == "_db_id": equal = skip_db_id_check or one_val == other_val - elif field == "_model": - # TODO[T52643706]: replace with per-`Adapter` method like - # `equivalent_models`, to compare models more meaningfully. - if not hasattr(one_val, "model") or not hasattr(other_val, "model"): - equal = not hasattr(other_val, "model") and not hasattr( - other_val, "model" - ) - else: - # If adapters have a `model` attribute, the types of the - # values of those attributes should be equal if the model - # adapter is the same. - equal = ( - hasattr(one_val, "model") - and hasattr(other_val, "model") - and isinstance(one_val.model, type(other_val.model)) - ) # Do not check the inequality_str for ParameterConstraints, checking the bound # and coefficients dict is sufficient. elif field == "_inequality_str": diff --git a/ax/utils/common/testutils.py b/ax/utils/common/testutils.py index e666dc95a7d..b12e032d0ab 100644 --- a/ax/utils/common/testutils.py +++ b/ax/utils/common/testutils.py @@ -445,8 +445,8 @@ def assertDictsAlmostEqual( set_a = set(a.keys()) set_b = set(b.keys()) key_msg = ( - "Dict keys differ." - f"Keys that are in a but not b: {set_a - set_b}." + "Dict keys differ. " + f"Keys that are in a but not b: {set_a - set_b}. " f"Keys that are in b but not a: {set_b - set_a}." ) self.assertEqual(set_a, set_b, msg=key_msg) diff --git a/ax/utils/stats/model_fit_stats.py b/ax/utils/stats/model_fit_stats.py index e66d3aaaa85..d9af572e73e 100644 --- a/ax/utils/stats/model_fit_stats.py +++ b/ax/utils/stats/model_fit_stats.py @@ -348,6 +348,7 @@ def _fisher_exact_test_p( TOTAL_RAW_EFFECT: ModelFitMetricDirection.MAXIMIZE, CORRELATION_COEFFICIENT: ModelFitMetricDirection.MAXIMIZE, RANK_CORRELATION: ModelFitMetricDirection.MAXIMIZE, + KENDALL_TAU_RANK_CORRELATION: ModelFitMetricDirection.MAXIMIZE, FISHER_EXACT_TEST_P: ModelFitMetricDirection.MINIMIZE, LOG_LIKELIHOOD: ModelFitMetricDirection.MAXIMIZE, MSE: ModelFitMetricDirection.MINIMIZE, diff --git a/ax/utils/stats/statstools.py b/ax/utils/stats/statstools.py index a73451cdbea..ad79772849e 100644 --- a/ax/utils/stats/statstools.py +++ b/ax/utils/stats/statstools.py @@ -201,7 +201,7 @@ def marginal_effects( ) for cov in covariates: if len(df[cov].unique()) <= 1: - next + continue df_gb = df.groupby(cov) for name, group_df in df_gb: group_mean, group_var = inverse_variance_weight(