diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 1f8fa641034..59fdf0990a3 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -254,6 +254,44 @@ def status(self, status: ExperimentStatus | None) -> None: """ self._status = status + @staticmethod + def experiment_status_from_generator_runs( + generator_runs: list[GeneratorRun], + ) -> ExperimentStatus | None: + """Extract and validate suggested experiment status from generator runs. + + Collects the suggested_experiment_status directly from the GeneratorRun + objects, validates that all runs suggest the same status, and returns + that status. + + Args: + generator_runs: List of generator runs to extract statuses from. + + Returns: + The suggested experiment status that all generator runs agree on, + or None if no statuses were found or if there are conflicting statuses. + """ + suggested_statuses: set[ExperimentStatus] = set() + for gr in generator_runs: + if gr.suggested_experiment_status is not None: + suggested_statuses.add(gr.suggested_experiment_status) + + if len(suggested_statuses) > 1: + # TODO: Consider making this invalid state an actual error once + # related development is completed. + logger.warning( + "Multiple different suggested experiment statuses found: " + f"{suggested_statuses}. " + "All generator runs used in a single gen() call should suggest the " + "same experiment status. Skipping updating experiment status." + ) + return None + + if len(suggested_statuses) == 0: + return None + + return suggested_statuses.pop() + @property def search_space(self) -> SearchSpace: """The search space for this experiment. diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index 2235e2c8e9a..07bb2b0b3f2 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -18,6 +18,7 @@ import pandas as pd from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace from ax.core.types import ( @@ -100,6 +101,7 @@ def __init__( candidate_metadata_by_arm_signature: None | (dict[str, TCandidateMetadata]) = None, generation_node_name: str | None = None, + suggested_experiment_status: ExperimentStatus | None = None, ) -> None: """Inits GeneratorRun. @@ -142,6 +144,10 @@ def __init__( via a generation strategy (in which case this name should reflect the name of the generation node in a generation strategy) or a standalone generation node (in which case this name should be ``-1``). + suggested_experiment_status: Optional ``ExperimentStatus`` that indicates + what the experiment's status should be once this generator run is + added to a trial. This is propagated from the generation node's + suggested_experiment_status field and is advisory only. """ self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict() if weights is None: @@ -191,6 +197,7 @@ def __init__( ) self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature self._generation_node_name = generation_node_name + self._suggested_experiment_status = suggested_experiment_status @property def arms(self) -> list[Arm]: @@ -287,6 +294,11 @@ def candidate_metadata_by_arm_signature( """ return self._candidate_metadata_by_arm_signature + @property + def suggested_experiment_status(self) -> ExperimentStatus | None: + """Optional suggested experiment status for this generator run.""" + return self._suggested_experiment_status + @property def param_df(self) -> pd.DataFrame: """ @@ -327,6 +339,7 @@ def clone(self) -> GeneratorRun: generator_state_after_gen=self._generator_state_after_gen, candidate_metadata_by_arm_signature=cand_metadata, generation_node_name=self._generation_node_name, + suggested_experiment_status=self.suggested_experiment_status, ) generator_run._time_created = self._time_created generator_run._generator_key = self._generator_key diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index e93954368ef..42e4e8bfb39 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -18,6 +18,7 @@ from ax.core.data import Data, sort_by_trial_index_and_arm_name from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment_status import ExperimentStatus +from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective @@ -44,6 +45,9 @@ UnsupportedError, UserInputError, ) +from ax.generation_strategy.generation_node import GenerationNode +from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.generation_strategy.generator_spec import GeneratorSpec from ax.metrics.branin import BraninMetric from ax.metrics.hartmann6 import Hartmann6Metric from ax.metrics.noisy_function import NoisyFunctionMetric @@ -1881,6 +1885,76 @@ def test_experiment_status_property(self) -> None: self.experiment.status = ExperimentStatus.DRAFT self.assertEqual(self.experiment.status, ExperimentStatus.DRAFT) + def test_experiment_status_from_generator_runs(self) -> None: + """Test that experiment status is correctly extracted from generator runs.""" + sobol_generator_spec = GeneratorSpec( + generator_enum=Generators.SOBOL, + generator_kwargs={"silently_filter_kwargs": True}, + generator_gen_kwargs={}, + ) + + with self.subTest("gen returns GRs with correct suggested_experiment_status"): + for status in [ + ExperimentStatus.INITIALIZATION, + ExperimentStatus.OPTIMIZATION, + ]: + with self.subTest(status=status): + exp = get_branin_experiment() + node_with_status = GenerationNode( + name="test_node", + generator_specs=[sobol_generator_spec], + suggested_experiment_status=status, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + gs.experiment = exp + + grs = gs.gen(experiment=exp, num_trials=1) + flat_grs = [gr for trial_grs in grs for gr in trial_grs] + + extracted_status = Experiment.experiment_status_from_generator_runs( + flat_grs + ) + self.assertEqual(extracted_status, status) + + with self.subTest("conflicting statuses return None"): + gr1 = GeneratorRun( + arms=[Arm(name="0_0", parameters={"x1": 0.0, "x2": 0.0})], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gr2 = GeneratorRun( + arms=[Arm(name="0_1", parameters={"x1": 1.0, "x2": 1.0})], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + mixed_grs = [gr1, gr2] + + result = Experiment.experiment_status_from_generator_runs(mixed_grs) + self.assertIsNone(result) + + with self.subTest("multiple trials all carry experiment status"): + exp = get_branin_experiment() + node_with_status = GenerationNode( + name="multi_trial_node", + generator_specs=[sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + gs.experiment = exp + + grs = gs.gen(experiment=exp, num_trials=3) + + self.assertEqual(len(grs), 3) + for gr_list in grs: + self.assertEqual(len(gr_list), 1) + self.assertEqual(gr_list[0]._generation_node_name, "multi_trial_node") + self.assertEqual( + gr_list[0].suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + extracted_status = Experiment.experiment_status_from_generator_runs( + [gr for trial_grs in grs for gr in trial_grs] + ) + self.assertEqual(extracted_status, ExperimentStatus.INITIALIZATION) + class ExperimentWithMapDataTest(TestCase): def setUp(self) -> None: diff --git a/ax/core/tests/test_generator_run.py b/ax/core/tests/test_generator_run.py index 8a372be2984..9360c93293f 100644 --- a/ax/core/tests/test_generator_run.py +++ b/ax/core/tests/test_generator_run.py @@ -7,6 +7,7 @@ # pyre-strict from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -48,6 +49,10 @@ def setUp(self) -> None: search_space=self.search_space, model_predictions=self.model_predictions, ) + self.run_with_suggested_status = GeneratorRun( + arms=self.arms, + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) def test_Init(self) -> None: self.assertEqual( @@ -184,3 +189,18 @@ def test_Sortable(self) -> None: weights=self.weights, ) self.assertTrue(generator_run1 < generator_run2) + + def test_SuggestedExperimentStatus(self) -> None: + self.assertEqual( + self.run_with_suggested_status.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + def test_SuggestedExperimentStatusDefaultNone(self) -> None: + self.assertIsNone(self.unweighted_run.suggested_experiment_status) + + def test_ClonePreservesSuggestedExperimentStatus(self) -> None: + cloned = self.run_with_suggested_status.clone() + self.assertEqual( + cloned.suggested_experiment_status, ExperimentStatus.INITIALIZATION + ) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 70b64d641d8..335572861f2 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -517,6 +517,7 @@ def gen( ) gr._generation_node_name = self.name + gr._suggested_experiment_status = self.suggested_experiment_status # TODO: When we start using `trial_type` more commonly, give it a dedicated # field on the `GeneratorRun` (or start creating trials from GS directly). if self._trial_type is not None: diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 266873cf5ad..94a1ec7e5d2 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -58,6 +58,10 @@ def setUp(self) -> None: generator_specs=[self.sobol_generator_spec], suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) + self.generation_node_without_exp_status = GenerationNode( + name="test", + generator_specs=[self.sobol_generator_spec], + ) self.branin_experiment = get_branin_experiment(with_completed_trial=True) self.branin_data = self.branin_experiment.lookup_data() self.node_short = GenerationNode( @@ -209,6 +213,31 @@ def test_gen(self) -> None: fixed_features=None, ) + def test_suggested_experiment_status_propagation(self) -> None: + """Test that suggested_experiment_status propagates from node to GR.""" + with self.subTest("with_suggested_experiment_status"): + gr = self.sobol_generation_node.gen( + experiment=self.branin_experiment, + data=self.branin_experiment.lookup_data(), + n=1, + pending_observations={"branin": []}, + ) + self.assertIsNotNone(gr) + self.assertEqual( + gr.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + with self.subTest("without_suggested_experiment_status"): + gr_without = self.generation_node_without_exp_status.gen( + experiment=self.branin_experiment, + data=self.branin_experiment.lookup_data(), + n=1, + pending_observations={"branin": []}, + ) + self.assertIsNotNone(gr_without) + self.assertIsNone(gr_without.suggested_experiment_status) + @mock_botorch_optimize def test_gen_with_trial_type(self) -> None: mbm_short = GenerationNode( diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index ed17c193c63..87b6b44c045 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -367,6 +367,7 @@ def generator_run_to_dict(generator_run: GeneratorRun) -> dict[str, Any]: "generator_state_after_gen": gr._generator_state_after_gen, "candidate_metadata_by_arm_signature": cand_metadata, "generation_node_name": gr._generation_node_name, + "suggested_experiment_status": gr.suggested_experiment_status, } diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 612f440f716..63b83add043 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -813,6 +813,7 @@ def generator_run_from_sqa( class_decoder_registry=self.config.json_class_decoder_registry, ), generation_node_name=generator_run_sqa.generation_node_name, + suggested_experiment_status=generator_run_sqa.suggested_experiment_status, ) # Remove deprecated kwargs from generator kwargs & adapter kwargs. if generator_run._generator_kwargs is not None: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index b96877a228f..87bf58cb88f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -897,6 +897,7 @@ def generator_run_to_sqa( class_encoder_registry=self.config.json_class_encoder_registry, ), generation_node_name=generator_run._generation_node_name, + suggested_experiment_status=generator_run.suggested_experiment_status, ) return gr_sqa diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index dff59cb9e7a..1176634b9e2 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -215,6 +215,9 @@ class SQAGeneratorRun(Base): JSONEncodedTextDict ) generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + suggested_experiment_status: Column[ExperimentStatus | None] = Column( + IntEnum(ExperimentStatus), nullable=True + ) # relationships # Use selectin loading for collections to prevent idle timeout errors diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 9d72ea425dc..2bc73a7ed88 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -2539,6 +2539,35 @@ def test_generator_run_gen_metadata(self) -> None: ) self.assertEqual(decoded_gr.gen_metadata, gen_metadata) + def test_generator_run_suggested_experiment_status(self) -> None: + # Test round-trip with a status set. + gr = GeneratorRun( + arms=[], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertEqual( + generator_run_sqa.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertEqual( + decoded_gr.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + + def test_generator_run_suggested_experiment_status_none(self) -> None: + # Test round-trip with None (default). + gr = GeneratorRun(arms=[]) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertIsNone(generator_run_sqa.suggested_experiment_status) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertIsNone(decoded_gr.suggested_experiment_status) + def test_update_generation_strategy_incrementally(self) -> None: experiment = get_branin_experiment() generation_strategy = choose_generation_strategy( diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 62b36197304..9c17c956603 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -31,6 +31,7 @@ from ax.core.data import Data from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric @@ -2383,6 +2384,7 @@ def get_generator_run() -> GeneratorRun: candidate_metadata_by_arm_signature={ a.signature: {"md_key": f"md_val_{a.signature}"} for a in arms }, + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, )