diff --git a/ax/core/arm.py b/ax/core/arm.py index af2fc40b924..ffc54976b2e 100644 --- a/ax/core/arm.py +++ b/ax/core/arm.py @@ -90,7 +90,7 @@ def md5hash(parameters: Mapping[str, TParamValue]) -> str: new_parameters = {} for k, v in parameters.items(): new_parameters[k] = numpy_type_to_python_type(v) - parameters_str = json.dumps(parameters, sort_keys=True) + parameters_str = json.dumps(new_parameters, sort_keys=True) return hashlib.md5(parameters_str.encode("utf-8")).hexdigest() def clone(self, clear_name: bool = False) -> "Arm": diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 3662eb51cd2..0f2755be912 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -727,7 +727,7 @@ def bulk_configure_metrics_of_class( UserInputError( f"Metric class {metric_class} does not contain the requested " "attributes to update. Requested updates to attributes: " - f"{set(attributes_to_update.keys())} but metric class defines" + f"{set(attributes_to_update.keys())} but metric class defines " f"{metric_attributes}." ) ) diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index b4b17290662..2235e2c8e9a 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -343,6 +343,14 @@ def clone(self) -> GeneratorRun: if self._generator_state_after_gen is not None else None ) + generator_run._gen_metadata = ( + self._gen_metadata.copy() if self._gen_metadata is not None else None + ) + generator_run._candidate_metadata_by_arm_signature = ( + self._candidate_metadata_by_arm_signature.copy() + if self._candidate_metadata_by_arm_signature is not None + else None + ) return generator_run def add_arm(self, arm: Arm, weight: float = 1.0) -> None: diff --git a/ax/core/metric.py b/ax/core/metric.py index db006e36823..c5f2289c191 100644 --- a/ax/core/metric.py +++ b/ax/core/metric.py @@ -165,7 +165,7 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta: return timedelta(0) @classmethod - def is_reconverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool: + def is_recoverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool: """Checks whether the given MetricFetchE is recoverable for this metric class in ``orchestrator._fetch_and_process_trials_data_results``. """ diff --git a/ax/core/outcome_constraint.py b/ax/core/outcome_constraint.py index f7f88386171..1d22616f3f4 100644 --- a/ax/core/outcome_constraint.py +++ b/ax/core/outcome_constraint.py @@ -28,7 +28,7 @@ CONSTRAINT_THRESHOLD_WARNING_MESSAGE: str = ( "Constraint threshold on {name} appears invalid: {bound} bound on metric " - + "for which {is_better} values is better." + + "for which {is_better} values are better." ) UPPER_BOUND_THRESHOLD: dict[str, str] = {"bound": "Positive", "is_better": "lower"} LOWER_BOUND_THRESHOLD: dict[str, str] = {"bound": "Negative", "is_better": "higher"} @@ -120,7 +120,7 @@ def _validate_metric_constraint_op( if op == ComparisonOp.LEQ and not metric.lower_is_better: fmt_data = UPPER_BOUND_MISMATCH if fmt_data is not None: - fmt_data["name"] = metric.name + fmt_data = {**fmt_data, "name": metric.name} msg = CONSTRAINT_WARNING_MESSAGE.format(**fmt_data) logger.debug(msg) return False, msg @@ -156,7 +156,7 @@ def _validate_constraint(self) -> tuple[bool, str]: if self.bound > 0 and not self.metric.lower_is_better: fmt_data = LOWER_BOUND_THRESHOLD if fmt_data is not None: - fmt_data["name"] = self.metric.name + fmt_data = {**fmt_data, "name": self.metric.name} msg += CONSTRAINT_THRESHOLD_WARNING_MESSAGE.format(**fmt_data) logger.debug(msg) return False, msg diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 7591c525a35..e396c792677 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -423,13 +423,13 @@ def _validate_range_param( ParameterType.FLOAT, ): raise UserInputError( - f"RangeParameter {self.name}type must be int or float." + f"RangeParameter {self.name} type must be int or float." ) upper = float(upper) if lower >= upper: raise UserInputError( - f"Upper bound of {self.name} must be strictly larger than lower." + f"Upper bound of {self.name} must be strictly larger than lower. " f"Got: ({lower}, {upper})." ) width: float = upper - lower diff --git a/ax/core/parameter_constraint.py b/ax/core/parameter_constraint.py index ef1621ab36d..3eabdc6db64 100644 --- a/ax/core/parameter_constraint.py +++ b/ax/core/parameter_constraint.py @@ -129,7 +129,7 @@ def validate_constraint_parameters(parameters: Sequence[Parameter]) -> None: for parameter in parameters: if not isinstance(parameter, RangeParameter): raise ValueError( - "All parameters in a parameter constraint must be RangeParameters." + "All parameters in a parameter constraint must be RangeParameters. " f"Found {parameter}" ) diff --git a/ax/core/tests/test_outcome_constraint.py b/ax/core/tests/test_outcome_constraint.py index 58072e169e1..fcc6ed79cf5 100644 --- a/ax/core/tests/test_outcome_constraint.py +++ b/ax/core/tests/test_outcome_constraint.py @@ -63,14 +63,14 @@ def test_OutcomeConstraintFail(self) -> None: metric=self.minimize_metric, op=ComparisonOp.GEQ, bound=self.bound ) mock_warning.debug.assert_called_once_with( - CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH) + CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH, name="bar") ) with mock.patch(logger_name) as mock_warning: OutcomeConstraint( metric=self.maximize_metric, op=ComparisonOp.LEQ, bound=self.bound ) mock_warning.debug.assert_called_once_with( - CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH) + CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH, name="baz") ) def test_Sortable(self) -> None: @@ -144,14 +144,14 @@ def test_ObjectiveThresholdFail(self) -> None: metric=self.minimize_metric, op=ComparisonOp.GEQ, bound=self.bound ) mock_warning.debug.assert_called_once_with( - CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH) + CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH, name="bar") ) with mock.patch(logger_name) as mock_warning: ObjectiveThreshold( metric=self.maximize_metric, op=ComparisonOp.LEQ, bound=self.bound ) mock_warning.debug.assert_called_once_with( - CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH) + CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH, name="baz") ) def test_Relativize(self) -> None: diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index 8d4804893b7..a35cfe8aa51 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -1227,7 +1227,7 @@ def _check_exit_status_and_report_results( if not self.options.wait_for_running_trials: return None return self.wait_for_completed_trials_and_report_results( - idle_callback, force_refit=True + idle_callback, force_refit=force_refit ) def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool: @@ -2093,9 +2093,9 @@ def _fetch_and_process_trials_data_results( if ( optimization_config is not None and metric_name in optimization_config.metrics.keys() - and not self.experiment.metrics[ - metric_name - ].is_reconverable_fetch_e(metric_fetch_e=metric_fetch_e) + and not self.experiment.metrics[metric_name].is_recoverable_fetch_e( + metric_fetch_e=metric_fetch_e + ) ): status = self._mark_err_trial_status( trial=self.experiment.trials[trial_index], diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..e31e63a5cb3 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -1089,10 +1089,10 @@ def generator_spec_from_json( kwargs = generator_spec_json.pop("model_kwargs", None) else: kwargs = generator_spec_json.pop("generator_kwargs", None) - for k in _DEPRECATED_GENERATOR_KWARGS: - # Remove deprecated model kwargs. - kwargs.pop(k, None) if kwargs is not None: + for k in _DEPRECATED_GENERATOR_KWARGS: + # Remove deprecated model kwargs. + kwargs.pop(k, None) kwargs = _sanitize_surrogate_spec_input(object_json=kwargs) if "model_gen_kwargs" in generator_spec_json: gen_kwargs = generator_spec_json.pop("model_gen_kwargs", None) diff --git a/ax/storage/sqa_store/db.py b/ax/storage/sqa_store/db.py index 19d883bfd37..9da5d98e4e2 100644 --- a/ax/storage/sqa_store/db.py +++ b/ax/storage/sqa_store/db.py @@ -218,8 +218,8 @@ def create_all_tables(engine: Engine) -> None: """ if engine.dialect.name == "mysql" and engine.dialect.default_schema_name == "ax": raise ValueError( - "The open-source Ax table creation is likely not applicable in this case," - + "please contact the Adaptive Experimentation team if you need help." + "The open-source Ax table creation is likely not applicable in this case, " + "please contact the Adaptive Experimentation team if you need help." ) Base.metadata.create_all(engine) diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index b96877a228f..56fbb1ddb5f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -128,7 +128,7 @@ def validate_experiment_metadata( raise ValueError( f"An experiment already exists with the name {experiment.name}. " "If you need to override this existing experiment, first delete it " - "via `delete_experiment` in ax/ax/storage/sqa_store/delete.py, " + "via `delete_experiment` in ax/storage/sqa_store/delete.py, " "and then resave." ) @@ -391,14 +391,15 @@ def get_metric_type_and_properties( json blob. """ metric_class = type(metric) - metric_type = int(self.config.metric_registry.get(metric_class)) - if metric_type is None: + metric_type_or_none = self.config.metric_registry.get(metric_class) + if metric_type_or_none is None: raise SQAEncodeError( "Cannot encode metric to SQLAlchemy because metric's " f"subclass ({metric_class}) is missing from the registry. " "The metric registry currently contains the following: " f"{','.join(map(str, self.config.metric_registry.keys()))} " ) + metric_type = int(metric_type_or_none) properties = metric_class.serialize_init_args(obj=metric) return metric_type, object_to_json( @@ -646,13 +647,13 @@ def outcome_constraint_to_sqa( def scalarized_outcome_constraint_to_sqa( self, outcome_constraint: ScalarizedOutcomeConstraint ) -> SQAMetric: - """Convert Ax SCalarized OutcomeConstraint to SQLAlchemy.""" + """Convert Ax Scalarized OutcomeConstraint to SQLAlchemy.""" metrics, weights = outcome_constraint.metrics, outcome_constraint.weights if metrics is None or weights is None or len(metrics) != len(weights): raise SQAEncodeError( - "Metrics and weights in scalarized OutcomeConstraint \ - must be lists of equal length." + "Metrics and weights in scalarized OutcomeConstraint " + "must be lists of equal length." ) metrics_by_name = self.get_children_metrics_by_name( diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index bb2f6e65b88..30c19732f30 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -757,7 +757,7 @@ def _query_historical_experiments_given_parameters( ) ) .filter(SQAExperiment.is_test == False) # noqa E712 `is` won't work for SQA - .filter(SQAExperiment.id is not None) + .filter(SQAExperiment.id.isnot(None)) # Experiments with some data .join(SQAData, SQAParameter.experiment_id == SQAData.experiment_id) ) diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index cd3f6ba2eb3..846150ddb8d 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -242,7 +242,9 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: objs=trials_to_reduce_state, encode_func=trial_to_reduced_state_sqa_encoder, decode_func=decoder.trial_from_sqa, - decode_args_list=[{"experiment": experiment} for _ in range(len(trials))], + decode_args_list=[ + {"experiment": experiment} for _ in range(len(trials_to_reduce_state)) + ], modify_sqa=add_experiment_id, batch_size=batch_size, ) @@ -251,7 +253,7 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: objs=[latest_trial], encode_func=encoder.trial_to_sqa, decode_func=decoder.trial_from_sqa, - decode_args_list=[{"experiment": experiment} for _ in range(len(trials))], + decode_args_list=[{"experiment": experiment}], modify_sqa=add_experiment_id, batch_size=batch_size, ) diff --git a/ax/storage/sqa_store/utils.py b/ax/storage/sqa_store/utils.py index 3fc233ff15e..5d3d89cbf7a 100644 --- a/ax/storage/sqa_store/utils.py +++ b/ax/storage/sqa_store/utils.py @@ -46,7 +46,9 @@ "_metric_fetching_errors", "_data_rows", } -SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate." +SKIP_ATTRS_ERROR_SUFFIX = ( + " Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate." +) def is_foreign_key_field(field: str) -> bool: @@ -74,7 +76,7 @@ def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None if len(path) > 15: # This shouldn't happen, but is a precaution against accidentally # introducing infinite loops - raise SQADecodeError(error_message_prefix + "Encountered path of length > 10.") + raise SQADecodeError(error_message_prefix + "Encountered path of length > 15.") if type(source) is not type(target): if not issubclass(type(target), type(source)): @@ -111,7 +113,7 @@ def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None source_json = getattr(source, attr) target_json = getattr(target, attr) if source_json != target_json: - SQADecodeError( + raise SQADecodeError( error_message_prefix + f"Json attribute {attr} not matching " f"between source: {source_json} and target: {target_json}." ) diff --git a/ax/storage/sqa_store/with_db_settings_base.py b/ax/storage/sqa_store/with_db_settings_base.py index 0defee1c634..54de2bd8486 100644 --- a/ax/storage/sqa_store/with_db_settings_base.py +++ b/ax/storage/sqa_store/with_db_settings_base.py @@ -105,8 +105,9 @@ def __init__( if db_settings and (not DBSettings or not isinstance(db_settings, DBSettings)): raise ValueError( "`db_settings` argument should be of type ax.storage.sqa_store." - f"(Got: {db_settings} of type {type(db_settings)}. " - "structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy " + "structs.DBSettings. " + f"(Got: {db_settings} of type {type(db_settings)}). " + "To use `DBSettings`, you will need SQLAlchemy " "installed in your environment (can be installed through pip)." ) self._db_settings = db_settings or self._get_default_db_settings()