diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 1f8fa641034..59fdf0990a3 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -254,6 +254,44 @@ def status(self, status: ExperimentStatus | None) -> None: """ self._status = status + @staticmethod + def experiment_status_from_generator_runs( + generator_runs: list[GeneratorRun], + ) -> ExperimentStatus | None: + """Extract and validate suggested experiment status from generator runs. + + Collects the suggested_experiment_status directly from the GeneratorRun + objects, validates that all runs suggest the same status, and returns + that status. + + Args: + generator_runs: List of generator runs to extract statuses from. + + Returns: + The suggested experiment status that all generator runs agree on, + or None if no statuses were found or if there are conflicting statuses. + """ + suggested_statuses: set[ExperimentStatus] = set() + for gr in generator_runs: + if gr.suggested_experiment_status is not None: + suggested_statuses.add(gr.suggested_experiment_status) + + if len(suggested_statuses) > 1: + # TODO: Consider making this invalid state an actual error once + # related development is completed. + logger.warning( + "Multiple different suggested experiment statuses found: " + f"{suggested_statuses}. " + "All generator runs used in a single gen() call should suggest the " + "same experiment status. Skipping updating experiment status." + ) + return None + + if len(suggested_statuses) == 0: + return None + + return suggested_statuses.pop() + @property def search_space(self) -> SearchSpace: """The search space for this experiment. diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index 2235e2c8e9a..07bb2b0b3f2 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -18,6 +18,7 @@ import pandas as pd from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace from ax.core.types import ( @@ -100,6 +101,7 @@ def __init__( candidate_metadata_by_arm_signature: None | (dict[str, TCandidateMetadata]) = None, generation_node_name: str | None = None, + suggested_experiment_status: ExperimentStatus | None = None, ) -> None: """Inits GeneratorRun. @@ -142,6 +144,10 @@ def __init__( via a generation strategy (in which case this name should reflect the name of the generation node in a generation strategy) or a standalone generation node (in which case this name should be ``-1``). + suggested_experiment_status: Optional ``ExperimentStatus`` that indicates + what the experiment's status should be once this generator run is + added to a trial. This is propagated from the generation node's + suggested_experiment_status field and is advisory only. """ self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict() if weights is None: @@ -191,6 +197,7 @@ def __init__( ) self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature self._generation_node_name = generation_node_name + self._suggested_experiment_status = suggested_experiment_status @property def arms(self) -> list[Arm]: @@ -287,6 +294,11 @@ def candidate_metadata_by_arm_signature( """ return self._candidate_metadata_by_arm_signature + @property + def suggested_experiment_status(self) -> ExperimentStatus | None: + """Optional suggested experiment status for this generator run.""" + return self._suggested_experiment_status + @property def param_df(self) -> pd.DataFrame: """ @@ -327,6 +339,7 @@ def clone(self) -> GeneratorRun: generator_state_after_gen=self._generator_state_after_gen, candidate_metadata_by_arm_signature=cand_metadata, generation_node_name=self._generation_node_name, + suggested_experiment_status=self.suggested_experiment_status, ) generator_run._time_created = self._time_created generator_run._generator_key = self._generator_key diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index e93954368ef..42e4e8bfb39 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -18,6 +18,7 @@ from ax.core.data import Data, sort_by_trial_index_and_arm_name from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment_status import ExperimentStatus +from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective @@ -44,6 +45,9 @@ UnsupportedError, UserInputError, ) +from ax.generation_strategy.generation_node import GenerationNode +from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.generation_strategy.generator_spec import GeneratorSpec from ax.metrics.branin import BraninMetric from ax.metrics.hartmann6 import Hartmann6Metric from ax.metrics.noisy_function import NoisyFunctionMetric @@ -1881,6 +1885,76 @@ def test_experiment_status_property(self) -> None: self.experiment.status = ExperimentStatus.DRAFT self.assertEqual(self.experiment.status, ExperimentStatus.DRAFT) + def test_experiment_status_from_generator_runs(self) -> None: + """Test that experiment status is correctly extracted from generator runs.""" + sobol_generator_spec = GeneratorSpec( + generator_enum=Generators.SOBOL, + generator_kwargs={"silently_filter_kwargs": True}, + generator_gen_kwargs={}, + ) + + with self.subTest("gen returns GRs with correct suggested_experiment_status"): + for status in [ + ExperimentStatus.INITIALIZATION, + ExperimentStatus.OPTIMIZATION, + ]: + with self.subTest(status=status): + exp = get_branin_experiment() + node_with_status = GenerationNode( + name="test_node", + generator_specs=[sobol_generator_spec], + suggested_experiment_status=status, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + gs.experiment = exp + + grs = gs.gen(experiment=exp, num_trials=1) + flat_grs = [gr for trial_grs in grs for gr in trial_grs] + + extracted_status = Experiment.experiment_status_from_generator_runs( + flat_grs + ) + self.assertEqual(extracted_status, status) + + with self.subTest("conflicting statuses return None"): + gr1 = GeneratorRun( + arms=[Arm(name="0_0", parameters={"x1": 0.0, "x2": 0.0})], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gr2 = GeneratorRun( + arms=[Arm(name="0_1", parameters={"x1": 1.0, "x2": 1.0})], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + mixed_grs = [gr1, gr2] + + result = Experiment.experiment_status_from_generator_runs(mixed_grs) + self.assertIsNone(result) + + with self.subTest("multiple trials all carry experiment status"): + exp = get_branin_experiment() + node_with_status = GenerationNode( + name="multi_trial_node", + generator_specs=[sobol_generator_spec], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + gs.experiment = exp + + grs = gs.gen(experiment=exp, num_trials=3) + + self.assertEqual(len(grs), 3) + for gr_list in grs: + self.assertEqual(len(gr_list), 1) + self.assertEqual(gr_list[0]._generation_node_name, "multi_trial_node") + self.assertEqual( + gr_list[0].suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + extracted_status = Experiment.experiment_status_from_generator_runs( + [gr for trial_grs in grs for gr in trial_grs] + ) + self.assertEqual(extracted_status, ExperimentStatus.INITIALIZATION) + class ExperimentWithMapDataTest(TestCase): def setUp(self) -> None: diff --git a/ax/core/tests/test_generator_run.py b/ax/core/tests/test_generator_run.py index 8a372be2984..9360c93293f 100644 --- a/ax/core/tests/test_generator_run.py +++ b/ax/core/tests/test_generator_run.py @@ -7,6 +7,7 @@ # pyre-strict from ax.core.arm import Arm +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -48,6 +49,10 @@ def setUp(self) -> None: search_space=self.search_space, model_predictions=self.model_predictions, ) + self.run_with_suggested_status = GeneratorRun( + arms=self.arms, + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) def test_Init(self) -> None: self.assertEqual( @@ -184,3 +189,18 @@ def test_Sortable(self) -> None: weights=self.weights, ) self.assertTrue(generator_run1 < generator_run2) + + def test_SuggestedExperimentStatus(self) -> None: + self.assertEqual( + self.run_with_suggested_status.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + def test_SuggestedExperimentStatusDefaultNone(self) -> None: + self.assertIsNone(self.unweighted_run.suggested_experiment_status) + + def test_ClonePreservesSuggestedExperimentStatus(self) -> None: + cloned = self.run_with_suggested_status.clone() + self.assertEqual( + cloned.suggested_experiment_status, ExperimentStatus.INITIALIZATION + ) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 70b64d641d8..335572861f2 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -517,6 +517,7 @@ def gen( ) gr._generation_node_name = self.name + gr._suggested_experiment_status = self.suggested_experiment_status # TODO: When we start using `trial_type` more commonly, give it a dedicated # field on the `GeneratorRun` (or start creating trials from GS directly). if self._trial_type is not None: diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 266873cf5ad..94a1ec7e5d2 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -58,6 +58,10 @@ def setUp(self) -> None: generator_specs=[self.sobol_generator_spec], suggested_experiment_status=ExperimentStatus.INITIALIZATION, ) + self.generation_node_without_exp_status = GenerationNode( + name="test", + generator_specs=[self.sobol_generator_spec], + ) self.branin_experiment = get_branin_experiment(with_completed_trial=True) self.branin_data = self.branin_experiment.lookup_data() self.node_short = GenerationNode( @@ -209,6 +213,31 @@ def test_gen(self) -> None: fixed_features=None, ) + def test_suggested_experiment_status_propagation(self) -> None: + """Test that suggested_experiment_status propagates from node to GR.""" + with self.subTest("with_suggested_experiment_status"): + gr = self.sobol_generation_node.gen( + experiment=self.branin_experiment, + data=self.branin_experiment.lookup_data(), + n=1, + pending_observations={"branin": []}, + ) + self.assertIsNotNone(gr) + self.assertEqual( + gr.suggested_experiment_status, + ExperimentStatus.INITIALIZATION, + ) + + with self.subTest("without_suggested_experiment_status"): + gr_without = self.generation_node_without_exp_status.gen( + experiment=self.branin_experiment, + data=self.branin_experiment.lookup_data(), + n=1, + pending_observations={"branin": []}, + ) + self.assertIsNotNone(gr_without) + self.assertIsNone(gr_without.suggested_experiment_status) + @mock_botorch_optimize def test_gen_with_trial_type(self) -> None: mbm_short = GenerationNode( diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index a35cfe8aa51..d9ae586fc8a 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -540,6 +540,11 @@ def generate_candidates( if len(new_trials) > 0: new_generator_runs = [gr for t in new_trials for gr in t.generator_runs] + suggested_status = Experiment.experiment_status_from_generator_runs( + new_generator_runs + ) + if suggested_status is not None: + self.experiment.status = suggested_status self._save_or_update_trials_and_generation_strategy_if_possible( experiment=self.experiment, trials=new_trials + self.experiment.trials_by_status[TrialStatus.STALE], diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 67f982ca1a1..fd858a30d0f 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -24,6 +24,7 @@ from ax.core.batch_trial import BatchTrial from ax.core.data import Data, MAP_KEY from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment @@ -49,6 +50,7 @@ GenerationStep, GenerationStrategy, ) +from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import MaxGenerationParallelism from ax.metrics.branin import BraninMetric from ax.metrics.branin_map import BraninTimestampMapMetric @@ -2667,6 +2669,47 @@ def test_generate_candidates_works_for_iteration(self) -> None: len(candidate_trial.arms), none_throws(orchestrator.options.batch_size) ) + def test_generate_candidates_updates_experiment_status(self) -> None: + init_test_engine_and_session_factory(force_init=True) + node_with_status = GenerationNode( + name="test_node", + generator_specs=[ + GeneratorSpec( + generator_enum=Generators.SOBOL, + model_kwargs={}, + ) + ], + suggested_experiment_status=ExperimentStatus.INITIALIZATION, + ) + gs = GenerationStrategy(nodes=[node_with_status]) + + # Create orchestrator with this generation strategy + self.branin_experiment.runner = InfinitePollRunner() + orchestrator = Orchestrator( + experiment=self.branin_experiment, + generation_strategy=gs, + options=OrchestratorOptions( + init_seconds_between_polls=0, + batch_size=1, + trial_type=TrialType.BATCH_TRIAL, + **self.orchestrator_options_kwargs, + ), + db_settings=self.db_settings, + ) + + # Verify the experiment status is not currently ExperimentStatus.INITIALIZATION + self.assertNotEqual( + orchestrator.experiment.status, ExperimentStatus.INITIALIZATION + ) + + # Execute: generate candidates + orchestrator.generate_candidates(num_trials=1) + + # Assert: verify experiment status was updated + self.assertEqual( + orchestrator.experiment.status, ExperimentStatus.INITIALIZATION + ) + def test_generate_candidates_does_not_generate_if_missing_data(self) -> None: # GIVEN a orchestrator that can't fetch data self.branin_experiment.optimization_config = OptimizationConfig( diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index ed17c193c63..87b6b44c045 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -367,6 +367,7 @@ def generator_run_to_dict(generator_run: GeneratorRun) -> dict[str, Any]: "generator_state_after_gen": gr._generator_state_after_gen, "candidate_metadata_by_arm_signature": cand_metadata, "generation_node_name": gr._generation_node_name, + "suggested_experiment_status": gr.suggested_experiment_status, } diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 612f440f716..63b83add043 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -813,6 +813,7 @@ def generator_run_from_sqa( class_decoder_registry=self.config.json_class_decoder_registry, ), generation_node_name=generator_run_sqa.generation_node_name, + suggested_experiment_status=generator_run_sqa.suggested_experiment_status, ) # Remove deprecated kwargs from generator kwargs & adapter kwargs. if generator_run._generator_kwargs is not None: diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index b96877a228f..87bf58cb88f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -897,6 +897,7 @@ def generator_run_to_sqa( class_encoder_registry=self.config.json_class_encoder_registry, ), generation_node_name=generator_run._generation_node_name, + suggested_experiment_status=generator_run.suggested_experiment_status, ) return gr_sqa diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index cd3f6ba2eb3..722158c40b0 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -46,6 +46,24 @@ logger: Logger = get_logger(__name__) +def _assert_experiment_saved(experiment: Experiment) -> int: + """Assert that an experiment has been saved to the database. + + Args: + experiment: The experiment to check. + + Returns: + The experiment's database ID. + + Raises: + UserInputError: If the experiment has not been saved (db_id is None). + """ + exp_id = experiment.db_id + if exp_id is None: + raise UserInputError("Experiment must be saved before being updated.") + return exp_id + + def save_experiment( experiment: Experiment, config: SQAConfig | None = None, @@ -458,14 +476,11 @@ def update_runner_on_experiment( ) -> None: runner_sqa_class = encoder.config.class_to_sqa_class[Runner] - exp_id = experiment.db_id - if exp_id is None: - raise ValueError("Experiment must be saved before being updated.") + exp_id: int = _assert_experiment_saved(experiment) with session_scope() as session: session.query(runner_sqa_class).filter_by(experiment_id=exp_id).delete() - # pyre-fixme[53]: Captured variable `exp_id` is not annotated. # pyre-fixme[3]: Return type must be annotated. def add_experiment_id(sqa: SQARunner): sqa.experiment_id = exp_id @@ -486,9 +501,7 @@ def update_outcome_constraint_on_experiment( ) -> None: oc_sqa_class = encoder.config.class_to_sqa_class[Metric] - exp_id: int | None = experiment.db_id - if exp_id is None: - raise UserInputError("Experiment must be saved before being updated.") + exp_id: int = _assert_experiment_saved(experiment) oc_id = outcome_constraint.db_id if oc_id is not None: with session_scope() as session: @@ -519,9 +532,7 @@ def update_properties_on_experiment( config = SQAConfig() if config is None else config exp_sqa_class = config.class_to_sqa_class[Experiment] - exp_id = experiment_with_updated_properties.db_id - if exp_id is None: - raise ValueError("Experiment must be saved before being updated.") + exp_id = _assert_experiment_saved(experiment_with_updated_properties) with session_scope() as session: session.query(exp_sqa_class).filter_by(id=exp_id).update( @@ -531,6 +542,33 @@ def update_properties_on_experiment( ) +def update_experiment_status( + experiment: Experiment, + config: SQAConfig | None = None, +) -> None: + """Update experiment status in the database. + + This function provides an efficient way to update only the experiment's status + field without re-saving the entire experiment. Use this when you need to persist + status changes immediately after calling status transition methods + (e.g., mark_initialization(), mark_optimization()). + + Note: save_experiment() already handles status updates, so this function is + optional. Use it when you need status-only updates for efficiency. + """ + config = SQAConfig() if config is None else config + exp_sqa_class = config.class_to_sqa_class[Experiment] + + exp_id = _assert_experiment_saved(experiment) + + with session_scope() as session: + session.query(exp_sqa_class).filter_by(id=exp_id).update( + { + "status": experiment.status, + } + ) + + def update_properties_on_trial( trial_with_updated_properties: BaseTrial, config: SQAConfig | None = None, diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index dff59cb9e7a..1176634b9e2 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -215,6 +215,9 @@ class SQAGeneratorRun(Base): JSONEncodedTextDict ) generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) + suggested_experiment_status: Column[ExperimentStatus | None] = Column( + IntEnum(ExperimentStatus), nullable=True + ) # relationships # Use selectin loading for collections to prevent idle timeout errors diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 9d72ea425dc..30a915c3a07 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -53,6 +53,7 @@ ObjectNotFoundError, TrialMutationError, UnsupportedError, + UserInputError, ) from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy @@ -2539,6 +2540,35 @@ def test_generator_run_gen_metadata(self) -> None: ) self.assertEqual(decoded_gr.gen_metadata, gen_metadata) + def test_generator_run_suggested_experiment_status(self) -> None: + # Test round-trip with a status set. + gr = GeneratorRun( + arms=[], + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, + ) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertEqual( + generator_run_sqa.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertEqual( + decoded_gr.suggested_experiment_status, + ExperimentStatus.OPTIMIZATION, + ) + + def test_generator_run_suggested_experiment_status_none(self) -> None: + # Test round-trip with None (default). + gr = GeneratorRun(arms=[]) + generator_run_sqa = self.encoder.generator_run_to_sqa(gr) + self.assertIsNone(generator_run_sqa.suggested_experiment_status) + decoded_gr = self.decoder.generator_run_from_sqa( + generator_run_sqa, False, False + ) + self.assertIsNone(decoded_gr.suggested_experiment_status) + def test_update_generation_strategy_incrementally(self) -> None: experiment = get_branin_experiment() generation_strategy = choose_generation_strategy( @@ -2659,7 +2689,7 @@ def test_update_generation_strategy_steps(self) -> None: def test_update_runner(self) -> None: experiment = get_branin_experiment() - with self.assertRaisesRegex(ValueError, ".* must be saved before"): + with self.assertRaisesRegex(UserInputError, ".* must be saved before"): update_runner_on_experiment( experiment=experiment, # pyre-fixme[6]: For 2nd param expected `Runner` but got `None`. diff --git a/ax/storage/sqa_store/with_db_settings_base.py b/ax/storage/sqa_store/with_db_settings_base.py index 0defee1c634..3dc3c24e799 100644 --- a/ax/storage/sqa_store/with_db_settings_base.py +++ b/ax/storage/sqa_store/with_db_settings_base.py @@ -65,6 +65,7 @@ _save_or_update_trials, _update_generation_strategy, save_analysis_card, + update_experiment_status, update_properties_on_experiment, update_runner_on_experiment, ) @@ -325,6 +326,11 @@ def _save_or_update_trials_and_generation_strategy_if_possible( new_generator_runs=new_generator_runs, reduce_state_generator_runs=reduce_state_generator_runs, ) + if experiment.status is not None and self.db_settings_set: + update_experiment_status( + experiment=experiment, + config=self.db_settings.encoder.config, + ) return # No retries needed, covered in `self._save_or_update_trials_in_db_if_possible` diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 62b36197304..9c17c956603 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -31,6 +31,7 @@ from ax.core.data import Data from ax.core.evaluations_to_data import raw_evaluations_to_data from ax.core.experiment import Experiment +from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.map_metric import MapMetric from ax.core.metric import Metric @@ -2383,6 +2384,7 @@ def get_generator_run() -> GeneratorRun: candidate_metadata_by_arm_signature={ a.signature: {"md_key": f"md_val_{a.signature}"} for a in arms }, + suggested_experiment_status=ExperimentStatus.OPTIMIZATION, )