From 98732490461a86d130ca682316c4276fc175ea7a Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Fri, 13 Feb 2026 22:28:22 -0800 Subject: [PATCH 1/3] Add suggested_experiment_status column to GeneratorRun (#4886) Summary: ## Summary Add `suggested_experiment_status` column to `GeneratorRun`. Some benefits: 1. We don't need to modify the GS.gen() or Orchestrator methods to pass along a suggested status via tuple, instead it's baked into the GeneratorRuns that are already being passed along 2. The suggested status are more clearly stored in the database for historical tracking Prior to this approach I tried changing `GS.gen()` to return a tuple including the `suggested_experiment_status` but that over-complicated callsites. ## AOSC DIFF D92476170 Reviewed By: lena-kashtelyan Differential Revision: D88091530 --- ax/core/generator_run.py | 13 +++++++++ ax/core/tests/test_generator_run.py | 20 ++++++++++++++ ax/storage/json_store/encoders.py | 1 + ax/storage/sqa_store/decoder.py | 1 + ax/storage/sqa_store/encoder.py | 1 + ax/storage/sqa_store/sqa_classes.py | 3 ++ ax/storage/sqa_store/tests/test_sqa_store.py | 29 ++++++++++++++++++++ ax/utils/testing/core_stubs.py | 2 ++ 8 files changed, 70 insertions(+) diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index 2235e2c8e9a..07bb2b0b3f2 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -18,6 +18,7 @@ import pandas as pd from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace from ax.core.types import ( @@ -100,6 +101,7 @@ def __init__( candidate_metadata_by_arm_signature: None | (dict[str, TCandidateMetadata]) = None, generation_node_name: str | None = None, + suggested_experiment_status: ExperimentStatus | None = None, ) -> None: """Inits GeneratorRun. @@ -142,6 +144,10 @@ def __init__( via a generation strategy (in which case this name should reflect the name of the generation node in a generation strategy) or a standalone generation node (in which case this name should be ``-1``). + suggested_experiment_status: Optional ``ExperimentStatus`` that indicates + what the experiment's status should be once this generator run is + added to a trial. This is propagated from the generation node's + suggested_experiment_status field and is advisory only. """ self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict() if weights is None: @@ -191,6 +197,7 @@ def __init__( ) self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature self._generation_node_name = generation_node_name + self._suggested_experiment_status = suggested_experiment_status @property def arms(self) -> list[Arm]: @@ -287,6 +294,11 @@ def candidate_metadata_by_arm_signature( """ return self._candidate_metadata_by_arm_signature + @property + def suggested_experiment_status(self) -> ExperimentStatus | None: + """Optional suggested experiment status for this generator run.""" + return self._suggested_experiment_status + @property def param_df(self) -> pd.DataFrame: """ @@ -327,6 +339,7 @@ def clone(self) -> GeneratorRun: generator_state_after_gen=self._generator_state_after_gen, candidate_metadata_by_arm_signature=cand_metadata, generation_node_name=self._generation_node_name, + suggested_experiment_status=self.suggested_experiment_status, ) generator_run._time_created = self._time_created generator_run._generator_key = self._generator_key diff --git a/ax/core/tests/test_generator_run.py b/ax/core/tests/test_generator_run.py index 8a372be2984..9360c93293f 100644 --- a/ax/core/tests/test_generator_run.py +++ b/ax/core/tests/test_generator_run.py @@ -7,6 +7,7 @@ # pyre-strict from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -48,6 +49,10 @@ def setUp(self) -> None: search_space=self.search_space, model_predictions=self.model_predictions, ) + self.run_with_suggested_status = GeneratorRun( + arms=self.arms, + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) def test_Init(self) -> None: self.assertEqual( @@ -184,3 +189,18 @@ def test_Sortable(self) -> None: weights=self.weights, ) self.assertTrue(generator_run1 < generator_run2) + + def test_SuggestedExperimentStatus(self) -> None: + self.assertEqual( + self.run_with_suggested_status.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + def test_SuggestedExperimentStatusDefaultNone(self) -> None: + self.assertIsNone(self.unweighted_run.suggested_experiment_status) + + def test_ClonePreservesSuggestedExperimentStatus(self) -> None: + cloned = self.run_with_suggested_status.clone() + self.assertEqual( + cloned.suggested_experiment_status, ExperimentStatus.INITIALIZATION + ) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index ed17c193c63..87b6b44c045 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -367,6 +367,7 @@ def generator_run_to_dict(generator_run: GeneratorRun) -> dict[str, Any]: "generator_state_after_gen": gr._generator_state_after_gen, "candidate_metadata_by_arm_signature": cand_metadata, "generation_node_name": gr._generation_node_name, + "suggested_experiment_status": gr.suggested_experiment_status, } diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 612f440f716..63b83add043 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -813,6 +813,7 @@ def generator_run_from_sqa( class_decoder_registry=self.config.json_class_decoder_registry, ), generation_node_name=generator_run_sqa.generation_node_name, + suggested_experiment_status=generator_run_sqa.suggested_experiment_status, ) # Remove deprecated kwargs from generator kwargs & adapter kwargs. if generator_run._generator_kwargs is not None: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index b96877a228f..87bf58cb88f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -897,6 +897,7 @@ def generator_run_to_sqa( class_encoder_registry=self.config.json_class_encoder_registry, ), generation_node_name=generator_run._generation_node_name, + suggested_experiment_status=generator_run.suggested_experiment_status, ) return gr_sqa diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index dff59cb9e7a..1176634b9e2 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -215,6 +215,9 @@ class SQAGeneratorRun(Base): JSONEncodedTextDict ) generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + suggested_experiment_status: Column[ExperimentStatus | None] = Column( + IntEnum(ExperimentStatus), nullable=True + ) # relationships # Use selectin loading for collections to prevent idle timeout errors diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 9d72ea425dc..2bc73a7ed88 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -2539,6 +2539,35 @@ def test_generator_run_gen_metadata(self) -> None: ) self.assertEqual(decoded_gr.gen_metadata, gen_metadata) + def test_generator_run_suggested_experiment_status(self) -> None: + # Test round-trip with a status set. + gr = GeneratorRun( + arms=[], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertEqual( + generator_run_sqa.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertEqual( + decoded_gr.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + + def test_generator_run_suggested_experiment_status_none(self) -> None: + # Test round-trip with None (default). + gr = GeneratorRun(arms=[]) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertIsNone(generator_run_sqa.suggested_experiment_status) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertIsNone(decoded_gr.suggested_experiment_status) + def test_update_generation_strategy_incrementally(self) -> None: experiment = get_branin_experiment() generation_strategy = choose_generation_strategy( diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 62b36197304..9c17c956603 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -31,6 +31,7 @@ from ax.core.data import Data from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric @@ -2383,6 +2384,7 @@ def get_generator_run() -> GeneratorRun: candidate_metadata_by_arm_signature={ a.signature: {"md_key": f"md_val_{a.signature}"} for a in arms }, + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, ) From 4ef1456c4f7dfaff5e06953ec693a0e475890340 Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Fri, 13 Feb 2026 22:28:22 -0800 Subject: [PATCH 2/3] Propagate suggested_experiment_status from GenerationNode to GeneratorRun (#4885) Summary: In the previous diff (D88091530) we added `suggested_experiment_status` the column to GeneratorRun, now we populate it during creation from GenerationNode. Reviewed By: lena-kashtelyan Differential Revision: D92555215 --- ax/generation_strategy/generation_node.py | 1 + .../tests/test_generation_node.py | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 70b64d641d8..335572861f2 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -517,6 +517,7 @@ def gen( ) gr._generation_node_name = self.name + gr._suggested_experiment_status = self.suggested_experiment_status # TODO: When we start using `trial_type` more commonly, give it a dedicated # field on the `GeneratorRun` (or start creating trials from GS directly). if self._trial_type is not None: diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 266873cf5ad..94a1ec7e5d2 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -58,6 +58,10 @@ def setUp(self) -> None: generator_specs=[self.sobol_generator_spec], suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) + self.generation_node_without_exp_status = GenerationNode( + name="test", + generator_specs=[self.sobol_generator_spec], + ) self.branin_experiment = get_branin_experiment(with_completed_trial=True) self.branin_data = self.branin_experiment.lookup_data() self.node_short = GenerationNode( @@ -209,6 +213,31 @@ def test_gen(self) -> None: fixed_features=None, ) + def test_suggested_experiment_status_propagation(self) -> None: + """Test that suggested_experiment_status propagates from node to GR.""" + with self.subTest("with_suggested_experiment_status"): + gr = self.sobol_generation_node.gen( + experiment=self.branin_experiment, + data=self.branin_experiment.lookup_data(), + n=1, + pending_observations={"branin": []}, + ) + self.assertIsNotNone(gr) + self.assertEqual( + gr.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + with self.subTest("without_suggested_experiment_status"): + gr_without = self.generation_node_without_exp_status.gen( + experiment=self.branin_experiment, + data=self.branin_experiment.lookup_data(), + n=1, + pending_observations={"branin": []}, + ) + self.assertIsNotNone(gr_without) + self.assertIsNone(gr_without.suggested_experiment_status) + @mock_botorch_optimize def test_gen_with_trial_type(self) -> None: mbm_short = GenerationNode( From 00c300fd5eb88c429598adc655edcee397d4a26f Mon Sep 17 00:00:00 2001 From: Cristian Lara Date: Fri, 13 Feb 2026 22:28:22 -0800 Subject: [PATCH 3/3] Method to consolidate Experiment.status from generator runs (#4900) Summary: Add a new static method `experiment_status_from_generator_runs()` to `GenerationStrategy` that extracts and validates a suggested ExperimentStatus from a list of GeneratorRun objects. It collects all unique suggested_experiment_status values from the runs and: - Returns None with a warning if there are conflicting statuses across runs - Returns None with an info log if no statuses are found - Returns the single agreed-upon status otherwise Reviewed By: lena-kashtelyan Differential Revision: D92985915 --- ax/core/experiment.py | 38 ++++++++++++++++ ax/core/tests/test_experiment.py | 74 ++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 1f8fa641034..59fdf0990a3 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -254,6 +254,44 @@ def status(self, status: ExperimentStatus | None) -> None: """ self._status = status + @staticmethod + def experiment_status_from_generator_runs( + generator_runs: list[GeneratorRun], + ) -> ExperimentStatus | None: + """Extract and validate suggested experiment status from generator runs. + + Collects the suggested_experiment_status directly from the GeneratorRun + objects, validates that all runs suggest the same status, and returns + that status. + + Args: + generator_runs: List of generator runs to extract statuses from. + + Returns: + The suggested experiment status that all generator runs agree on, + or None if no statuses were found or if there are conflicting statuses. + """ + suggested_statuses: set[ExperimentStatus] = set() + for gr in generator_runs: + if gr.suggested_experiment_status is not None: + suggested_statuses.add(gr.suggested_experiment_status) + + if len(suggested_statuses) > 1: + # TODO: Consider making this invalid state an actual error once + # related development is completed. + logger.warning( + "Multiple different suggested experiment statuses found: " + f"{suggested_statuses}. " + "All generator runs used in a single gen() call should suggest the " + "same experiment status. Skipping updating experiment status." + ) + return None + + if len(suggested_statuses) == 0: + return None + + return suggested_statuses.pop() + @property def search_space(self) -> SearchSpace: """The search space for this experiment. diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index e93954368ef..42e4e8bfb39 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -18,6 +18,7 @@ from ax.core.data import Data, sort_by_trial_index_and_arm_name from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment_status import ExperimentStatus +from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective @@ -44,6 +45,9 @@ UnsupportedError, UserInputError, ) +from ax.generation_strategy.generation_node import GenerationNode +from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.generation_strategy.generator_spec import GeneratorSpec from ax.metrics.branin import BraninMetric from ax.metrics.hartmann6 import Hartmann6Metric from ax.metrics.noisy_function import NoisyFunctionMetric @@ -1881,6 +1885,76 @@ def test_experiment_status_property(self) -> None: self.experiment.status = ExperimentStatus.DRAFT self.assertEqual(self.experiment.status, ExperimentStatus.DRAFT) + def test_experiment_status_from_generator_runs(self) -> None: + """Test that experiment status is correctly extracted from generator runs.""" + sobol_generator_spec = GeneratorSpec( + generator_enum=Generators.SOBOL, + generator_kwargs={"silently_filter_kwargs": True}, + generator_gen_kwargs={}, + ) + + with self.subTest("gen returns GRs with correct suggested_experiment_status"): + for status in [ + ExperimentStatus.INITIALIZATION, + ExperimentStatus.OPTIMIZATION, + ]: + with self.subTest(status=status): + exp = get_branin_experiment() + node_with_status = GenerationNode( + name="test_node", + generator_specs=[sobol_generator_spec], + suggested_experiment_status=status, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + gs.experiment = exp + + grs = gs.gen(experiment=exp, num_trials=1) + flat_grs = [gr for trial_grs in grs for gr in trial_grs] + + extracted_status = Experiment.experiment_status_from_generator_runs( + flat_grs + ) + self.assertEqual(extracted_status, status) + + with self.subTest("conflicting statuses return None"): + gr1 = GeneratorRun( + arms=[Arm(name="0_0", parameters={"x1": 0.0, "x2": 0.0})], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gr2 = GeneratorRun( + arms=[Arm(name="0_1", parameters={"x1": 1.0, "x2": 1.0})], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + mixed_grs = [gr1, gr2] + + result = Experiment.experiment_status_from_generator_runs(mixed_grs) + self.assertIsNone(result) + + with self.subTest("multiple trials all carry experiment status"): + exp = get_branin_experiment() + node_with_status = GenerationNode( + name="multi_trial_node", + generator_specs=[sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + gs.experiment = exp + + grs = gs.gen(experiment=exp, num_trials=3) + + self.assertEqual(len(grs), 3) + for gr_list in grs: + self.assertEqual(len(gr_list), 1) + self.assertEqual(gr_list[0]._generation_node_name, "multi_trial_node") + self.assertEqual( + gr_list[0].suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + extracted_status = Experiment.experiment_status_from_generator_runs( + [gr for trial_grs in grs for gr in trial_grs] + ) + self.assertEqual(extracted_status, ExperimentStatus.INITIALIZATION) + class ExperimentWithMapDataTest(TestCase): def setUp(self) -> None: