Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/core/arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def md5hash(parameters: Mapping[str, TParamValue]) -> str:
new_parameters = {}
for k, v in parameters.items():
new_parameters[k] = numpy_type_to_python_type(v)
parameters_str = json.dumps(parameters, sort_keys=True)
parameters_str = json.dumps(new_parameters, sort_keys=True)
return hashlib.md5(parameters_str.encode("utf-8")).hexdigest()

def clone(self, clear_name: bool = False) -> "Arm":
Expand Down
2 changes: 1 addition & 1 deletion ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def bulk_configure_metrics_of_class(
UserInputError(
f"Metric class {metric_class} does not contain the requested "
"attributes to update. Requested updates to attributes: "
f"{set(attributes_to_update.keys())} but metric class defines"
f"{set(attributes_to_update.keys())} but metric class defines "
f"{metric_attributes}."
)
)
Expand Down
8 changes: 8 additions & 0 deletions ax/core/generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,14 @@ def clone(self) -> GeneratorRun:
if self._generator_state_after_gen is not None
else None
)
generator_run._gen_metadata = (
self._gen_metadata.copy() if self._gen_metadata is not None else None
)
generator_run._candidate_metadata_by_arm_signature = (
self._candidate_metadata_by_arm_signature.copy()
if self._candidate_metadata_by_arm_signature is not None
else None
)
return generator_run

def add_arm(self, arm: Arm, weight: float = 1.0) -> None:
Expand Down
2 changes: 1 addition & 1 deletion ax/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta:
return timedelta(0)

@classmethod
def is_reconverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
def is_recoverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
"""Checks whether the given MetricFetchE is recoverable for this metric class
in ``orchestrator._fetch_and_process_trials_data_results``.
"""
Expand Down
6 changes: 3 additions & 3 deletions ax/core/outcome_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

CONSTRAINT_THRESHOLD_WARNING_MESSAGE: str = (
"Constraint threshold on {name} appears invalid: {bound} bound on metric "
+ "for which {is_better} values is better."
+ "for which {is_better} values are better."
)
UPPER_BOUND_THRESHOLD: dict[str, str] = {"bound": "Positive", "is_better": "lower"}
LOWER_BOUND_THRESHOLD: dict[str, str] = {"bound": "Negative", "is_better": "higher"}
Expand Down Expand Up @@ -120,7 +120,7 @@ def _validate_metric_constraint_op(
if op == ComparisonOp.LEQ and not metric.lower_is_better:
fmt_data = UPPER_BOUND_MISMATCH
if fmt_data is not None:
fmt_data["name"] = metric.name
fmt_data = {**fmt_data, "name": metric.name}
msg = CONSTRAINT_WARNING_MESSAGE.format(**fmt_data)
logger.debug(msg)
return False, msg
Expand Down Expand Up @@ -156,7 +156,7 @@ def _validate_constraint(self) -> tuple[bool, str]:
if self.bound > 0 and not self.metric.lower_is_better:
fmt_data = LOWER_BOUND_THRESHOLD
if fmt_data is not None:
fmt_data["name"] = self.metric.name
fmt_data = {**fmt_data, "name": self.metric.name}
msg += CONSTRAINT_THRESHOLD_WARNING_MESSAGE.format(**fmt_data)
logger.debug(msg)
return False, msg
Expand Down
4 changes: 2 additions & 2 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,13 @@ def _validate_range_param(
ParameterType.FLOAT,
):
raise UserInputError(
f"RangeParameter {self.name}type must be int or float."
f"RangeParameter {self.name} type must be int or float."
)

upper = float(upper)
if lower >= upper:
raise UserInputError(
f"Upper bound of {self.name} must be strictly larger than lower."
f"Upper bound of {self.name} must be strictly larger than lower. "
f"Got: ({lower}, {upper})."
)
width: float = upper - lower
Expand Down
2 changes: 1 addition & 1 deletion ax/core/parameter_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def validate_constraint_parameters(parameters: Sequence[Parameter]) -> None:
for parameter in parameters:
if not isinstance(parameter, RangeParameter):
raise ValueError(
"All parameters in a parameter constraint must be RangeParameters."
"All parameters in a parameter constraint must be RangeParameters. "
f"Found {parameter}"
)

Expand Down
8 changes: 4 additions & 4 deletions ax/core/tests/test_outcome_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ def test_OutcomeConstraintFail(self) -> None:
metric=self.minimize_metric, op=ComparisonOp.GEQ, bound=self.bound
)
mock_warning.debug.assert_called_once_with(
CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH)
CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH, name="bar")
)
with mock.patch(logger_name) as mock_warning:
OutcomeConstraint(
metric=self.maximize_metric, op=ComparisonOp.LEQ, bound=self.bound
)
mock_warning.debug.assert_called_once_with(
CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH)
CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH, name="baz")
)

def test_Sortable(self) -> None:
Expand Down Expand Up @@ -144,14 +144,14 @@ def test_ObjectiveThresholdFail(self) -> None:
metric=self.minimize_metric, op=ComparisonOp.GEQ, bound=self.bound
)
mock_warning.debug.assert_called_once_with(
CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH)
CONSTRAINT_WARNING_MESSAGE.format(**LOWER_BOUND_MISMATCH, name="bar")
)
with mock.patch(logger_name) as mock_warning:
ObjectiveThreshold(
metric=self.maximize_metric, op=ComparisonOp.LEQ, bound=self.bound
)
mock_warning.debug.assert_called_once_with(
CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH)
CONSTRAINT_WARNING_MESSAGE.format(**UPPER_BOUND_MISMATCH, name="baz")
)

def test_Relativize(self) -> None:
Expand Down
8 changes: 4 additions & 4 deletions ax/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def _check_exit_status_and_report_results(
if not self.options.wait_for_running_trials:
return None
return self.wait_for_completed_trials_and_report_results(
idle_callback, force_refit=True
idle_callback, force_refit=force_refit
)

def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool:
Expand Down Expand Up @@ -2093,9 +2093,9 @@ def _fetch_and_process_trials_data_results(
if (
optimization_config is not None
and metric_name in optimization_config.metrics.keys()
and not self.experiment.metrics[
metric_name
].is_reconverable_fetch_e(metric_fetch_e=metric_fetch_e)
and not self.experiment.metrics[metric_name].is_recoverable_fetch_e(
metric_fetch_e=metric_fetch_e
)
):
status = self._mark_err_trial_status(
trial=self.experiment.trials[trial_index],
Expand Down
6 changes: 3 additions & 3 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,10 +1089,10 @@ def generator_spec_from_json(
kwargs = generator_spec_json.pop("model_kwargs", None)
else:
kwargs = generator_spec_json.pop("generator_kwargs", None)
for k in _DEPRECATED_GENERATOR_KWARGS:
# Remove deprecated model kwargs.
kwargs.pop(k, None)
if kwargs is not None:
for k in _DEPRECATED_GENERATOR_KWARGS:
# Remove deprecated model kwargs.
kwargs.pop(k, None)
kwargs = _sanitize_surrogate_spec_input(object_json=kwargs)
if "model_gen_kwargs" in generator_spec_json:
gen_kwargs = generator_spec_json.pop("model_gen_kwargs", None)
Expand Down
4 changes: 2 additions & 2 deletions ax/storage/sqa_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def create_all_tables(engine: Engine) -> None:
"""
if engine.dialect.name == "mysql" and engine.dialect.default_schema_name == "ax":
raise ValueError(
"The open-source Ax table creation is likely not applicable in this case,"
+ "please contact the Adaptive Experimentation team if you need help."
"The open-source Ax table creation is likely not applicable in this case, "
"please contact the Adaptive Experimentation team if you need help."
)
Base.metadata.create_all(engine)

Expand Down
13 changes: 7 additions & 6 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def validate_experiment_metadata(
raise ValueError(
f"An experiment already exists with the name {experiment.name}. "
"If you need to override this existing experiment, first delete it "
"via `delete_experiment` in ax/ax/storage/sqa_store/delete.py, "
"via `delete_experiment` in ax/storage/sqa_store/delete.py, "
"and then resave."
)

Expand Down Expand Up @@ -391,14 +391,15 @@ def get_metric_type_and_properties(
json blob.
"""
metric_class = type(metric)
metric_type = int(self.config.metric_registry.get(metric_class))
if metric_type is None:
metric_type_or_none = self.config.metric_registry.get(metric_class)
if metric_type_or_none is None:
raise SQAEncodeError(
"Cannot encode metric to SQLAlchemy because metric's "
f"subclass ({metric_class}) is missing from the registry. "
"The metric registry currently contains the following: "
f"{','.join(map(str, self.config.metric_registry.keys()))} "
)
metric_type = int(metric_type_or_none)

properties = metric_class.serialize_init_args(obj=metric)
return metric_type, object_to_json(
Expand Down Expand Up @@ -646,13 +647,13 @@ def outcome_constraint_to_sqa(
def scalarized_outcome_constraint_to_sqa(
self, outcome_constraint: ScalarizedOutcomeConstraint
) -> SQAMetric:
"""Convert Ax SCalarized OutcomeConstraint to SQLAlchemy."""
"""Convert Ax Scalarized OutcomeConstraint to SQLAlchemy."""
metrics, weights = outcome_constraint.metrics, outcome_constraint.weights

if metrics is None or weights is None or len(metrics) != len(weights):
raise SQAEncodeError(
"Metrics and weights in scalarized OutcomeConstraint \
must be lists of equal length."
"Metrics and weights in scalarized OutcomeConstraint "
"must be lists of equal length."
)

metrics_by_name = self.get_children_metrics_by_name(
Expand Down
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def _query_historical_experiments_given_parameters(
)
)
.filter(SQAExperiment.is_test == False) # noqa E712 `is` won't work for SQA
.filter(SQAExperiment.id is not None)
.filter(SQAExperiment.id.isnot(None))
# Experiments with some data
.join(SQAData, SQAParameter.experiment_id == SQAData.experiment_id)
)
Expand Down
6 changes: 4 additions & 2 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
objs=trials_to_reduce_state,
encode_func=trial_to_reduced_state_sqa_encoder,
decode_func=decoder.trial_from_sqa,
decode_args_list=[{"experiment": experiment} for _ in range(len(trials))],
decode_args_list=[
{"experiment": experiment} for _ in range(len(trials_to_reduce_state))
],
modify_sqa=add_experiment_id,
batch_size=batch_size,
)
Expand All @@ -251,7 +253,7 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
objs=[latest_trial],
encode_func=encoder.trial_to_sqa,
decode_func=decoder.trial_from_sqa,
decode_args_list=[{"experiment": experiment} for _ in range(len(trials))],
decode_args_list=[{"experiment": experiment}],
modify_sqa=add_experiment_id,
batch_size=batch_size,
)
Expand Down
8 changes: 5 additions & 3 deletions ax/storage/sqa_store/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
"_metric_fetching_errors",
"_data_rows",
}
SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."
SKIP_ATTRS_ERROR_SUFFIX = (
" Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."
)


def is_foreign_key_field(field: str) -> bool:
Expand Down Expand Up @@ -74,7 +76,7 @@ def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None
if len(path) > 15:
# This shouldn't happen, but is a precaution against accidentally
# introducing infinite loops
raise SQADecodeError(error_message_prefix + "Encountered path of length > 10.")
raise SQADecodeError(error_message_prefix + "Encountered path of length > 15.")

if type(source) is not type(target):
if not issubclass(type(target), type(source)):
Expand Down Expand Up @@ -111,7 +113,7 @@ def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None
source_json = getattr(source, attr)
target_json = getattr(target, attr)
if source_json != target_json:
SQADecodeError(
raise SQADecodeError(
error_message_prefix + f"Json attribute {attr} not matching "
f"between source: {source_json} and target: {target_json}."
)
Expand Down
5 changes: 3 additions & 2 deletions ax/storage/sqa_store/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def __init__(
if db_settings and (not DBSettings or not isinstance(db_settings, DBSettings)):
raise ValueError(
"`db_settings` argument should be of type ax.storage.sqa_store."
f"(Got: {db_settings} of type {type(db_settings)}. "
"structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy "
"structs.DBSettings. "
f"(Got: {db_settings} of type {type(db_settings)}). "
"To use `DBSettings`, you will need SQLAlchemy "
"installed in your environment (can be installed through pip)."
)
self._db_settings = db_settings or self._get_default_db_settings()
Expand Down