diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index 2235e2c8e9a..07bb2b0b3f2 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -18,6 +18,7 @@ import pandas as pd from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace from ax.core.types import ( @@ -100,6 +101,7 @@ def __init__( candidate_metadata_by_arm_signature: None | (dict[str, TCandidateMetadata]) = None, generation_node_name: str | None = None, + suggested_experiment_status: ExperimentStatus | None = None, ) -> None: """Inits GeneratorRun. @@ -142,6 +144,10 @@ def __init__( via a generation strategy (in which case this name should reflect the name of the generation node in a generation strategy) or a standalone generation node (in which case this name should be ``-1``). + suggested_experiment_status: Optional ``ExperimentStatus`` that indicates + what the experiment's status should be once this generator run is + added to a trial. This is propagated from the generation node's + suggested_experiment_status field and is advisory only. """ self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict() if weights is None: @@ -191,6 +197,7 @@ def __init__( ) self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature self._generation_node_name = generation_node_name + self._suggested_experiment_status = suggested_experiment_status @property def arms(self) -> list[Arm]: @@ -287,6 +294,11 @@ def candidate_metadata_by_arm_signature( """ return self._candidate_metadata_by_arm_signature + @property + def suggested_experiment_status(self) -> ExperimentStatus | None: + """Optional suggested experiment status for this generator run.""" + return self._suggested_experiment_status + @property def param_df(self) -> pd.DataFrame: """ @@ -327,6 +339,7 @@ def clone(self) -> GeneratorRun: generator_state_after_gen=self._generator_state_after_gen, candidate_metadata_by_arm_signature=cand_metadata, generation_node_name=self._generation_node_name, + suggested_experiment_status=self.suggested_experiment_status, ) generator_run._time_created = self._time_created generator_run._generator_key = self._generator_key diff --git a/ax/core/tests/test_generator_run.py b/ax/core/tests/test_generator_run.py index 8a372be2984..9360c93293f 100644 --- a/ax/core/tests/test_generator_run.py +++ b/ax/core/tests/test_generator_run.py @@ -7,6 +7,7 @@ # pyre-strict from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -48,6 +49,10 @@ def setUp(self) -> None: search_space=self.search_space, model_predictions=self.model_predictions, ) + self.run_with_suggested_status = GeneratorRun( + arms=self.arms, + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) def test_Init(self) -> None: self.assertEqual( @@ -184,3 +189,18 @@ def test_Sortable(self) -> None: weights=self.weights, ) self.assertTrue(generator_run1 < generator_run2) + + def test_SuggestedExperimentStatus(self) -> None: + self.assertEqual( + self.run_with_suggested_status.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + def test_SuggestedExperimentStatusDefaultNone(self) -> None: + self.assertIsNone(self.unweighted_run.suggested_experiment_status) + + def test_ClonePreservesSuggestedExperimentStatus(self) -> None: + cloned = self.run_with_suggested_status.clone() + self.assertEqual( + cloned.suggested_experiment_status, ExperimentStatus.INITIALIZATION + ) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index ed17c193c63..87b6b44c045 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -367,6 +367,7 @@ def generator_run_to_dict(generator_run: GeneratorRun) -> dict[str, Any]: "generator_state_after_gen": gr._generator_state_after_gen, "candidate_metadata_by_arm_signature": cand_metadata, "generation_node_name": gr._generation_node_name, + "suggested_experiment_status": gr.suggested_experiment_status, } diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 612f440f716..63b83add043 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -813,6 +813,7 @@ def generator_run_from_sqa( class_decoder_registry=self.config.json_class_decoder_registry, ), generation_node_name=generator_run_sqa.generation_node_name, + suggested_experiment_status=generator_run_sqa.suggested_experiment_status, ) # Remove deprecated kwargs from generator kwargs & adapter kwargs. if generator_run._generator_kwargs is not None: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index b96877a228f..87bf58cb88f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -897,6 +897,7 @@ def generator_run_to_sqa( class_encoder_registry=self.config.json_class_encoder_registry, ), generation_node_name=generator_run._generation_node_name, + suggested_experiment_status=generator_run.suggested_experiment_status, ) return gr_sqa diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index dff59cb9e7a..1176634b9e2 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -215,6 +215,9 @@ class SQAGeneratorRun(Base): JSONEncodedTextDict ) generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + suggested_experiment_status: Column[ExperimentStatus | None] = Column( + IntEnum(ExperimentStatus), nullable=True + ) # relationships # Use selectin loading for collections to prevent idle timeout errors diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 9d72ea425dc..2bc73a7ed88 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -2539,6 +2539,35 @@ def test_generator_run_gen_metadata(self) -> None: ) self.assertEqual(decoded_gr.gen_metadata, gen_metadata) + def test_generator_run_suggested_experiment_status(self) -> None: + # Test round-trip with a status set. + gr = GeneratorRun( + arms=[], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertEqual( + generator_run_sqa.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertEqual( + decoded_gr.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + + def test_generator_run_suggested_experiment_status_none(self) -> None: + # Test round-trip with None (default). + gr = GeneratorRun(arms=[]) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertIsNone(generator_run_sqa.suggested_experiment_status) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertIsNone(decoded_gr.suggested_experiment_status) + def test_update_generation_strategy_incrementally(self) -> None: experiment = get_branin_experiment() generation_strategy = choose_generation_strategy( diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 62b36197304..9c17c956603 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -31,6 +31,7 @@ from ax.core.data import Data from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric @@ -2383,6 +2384,7 @@ def get_generator_run() -> GeneratorRun: candidate_metadata_by_arm_signature={ a.signature: {"md_key": f"md_val_{a.signature}"} for a in arms }, + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, )