Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions ax/core/generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pandas as pd
from ax.core.arm import Arm
from ax.core.experiment_status import ExperimentStatus
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import (
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
candidate_metadata_by_arm_signature: None
| (dict[str, TCandidateMetadata]) = None,
generation_node_name: str | None = None,
suggested_experiment_status: ExperimentStatus | None = None,
) -> None:
"""Inits GeneratorRun.

Expand Down Expand Up @@ -142,6 +144,10 @@ def __init__(
via a generation strategy (in which case this name should reflect the
name of the generation node in a generation strategy) or a standalone
generation node (in which case this name should be ``-1``).
suggested_experiment_status: Optional ``ExperimentStatus`` that indicates
what the experiment's status should be once this generator run is
added to a trial. This is propagated from the generation node's
suggested_experiment_status field and is advisory only.
"""
self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict()
if weights is None:
Expand Down Expand Up @@ -191,6 +197,7 @@ def __init__(
)
self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature
self._generation_node_name = generation_node_name
self._suggested_experiment_status = suggested_experiment_status

@property
def arms(self) -> list[Arm]:
Expand Down Expand Up @@ -287,6 +294,11 @@ def candidate_metadata_by_arm_signature(
"""
return self._candidate_metadata_by_arm_signature

@property
def suggested_experiment_status(self) -> ExperimentStatus | None:
"""Optional suggested experiment status for this generator run."""
return self._suggested_experiment_status

@property
def param_df(self) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -327,6 +339,7 @@ def clone(self) -> GeneratorRun:
generator_state_after_gen=self._generator_state_after_gen,
candidate_metadata_by_arm_signature=cand_metadata,
generation_node_name=self._generation_node_name,
suggested_experiment_status=self.suggested_experiment_status,
)
generator_run._time_created = self._time_created
generator_run._generator_key = self._generator_key
Expand Down
20 changes: 20 additions & 0 deletions ax/core/tests/test_generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

from ax.core.arm import Arm
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -48,6 +49,10 @@ def setUp(self) -> None:
search_space=self.search_space,
model_predictions=self.model_predictions,
)
self.run_with_suggested_status = GeneratorRun(
arms=self.arms,
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
)

def test_Init(self) -> None:
self.assertEqual(
Expand Down Expand Up @@ -184,3 +189,18 @@ def test_Sortable(self) -> None:
weights=self.weights,
)
self.assertTrue(generator_run1 < generator_run2)

def test_SuggestedExperimentStatus(self) -> None:
self.assertEqual(
self.run_with_suggested_status.suggested_experiment_status,
ExperimentStatus.INITIALIZATION,
)

def test_SuggestedExperimentStatusDefaultNone(self) -> None:
self.assertIsNone(self.unweighted_run.suggested_experiment_status)

def test_ClonePreservesSuggestedExperimentStatus(self) -> None:
cloned = self.run_with_suggested_status.clone()
self.assertEqual(
cloned.suggested_experiment_status, ExperimentStatus.INITIALIZATION
)
1 change: 1 addition & 0 deletions ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def gen(
)

gr._generation_node_name = self.name
gr._suggested_experiment_status = self.suggested_experiment_status
# TODO: When we start using `trial_type` more commonly, give it a dedicated
# field on the `GeneratorRun` (or start creating trials from GS directly).
if self._trial_type is not None:
Expand Down
29 changes: 29 additions & 0 deletions ax/generation_strategy/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def setUp(self) -> None:
generator_specs=[self.sobol_generator_spec],
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
)
self.generation_node_without_exp_status = GenerationNode(
name="test",
generator_specs=[self.sobol_generator_spec],
)
self.branin_experiment = get_branin_experiment(with_completed_trial=True)
self.branin_data = self.branin_experiment.lookup_data()
self.node_short = GenerationNode(
Expand Down Expand Up @@ -209,6 +213,31 @@ def test_gen(self) -> None:
fixed_features=None,
)

def test_suggested_experiment_status_propagation(self) -> None:
"""Test that suggested_experiment_status propagates from node to GR."""
with self.subTest("with_suggested_experiment_status"):
gr = self.sobol_generation_node.gen(
experiment=self.branin_experiment,
data=self.branin_experiment.lookup_data(),
n=1,
pending_observations={"branin": []},
)
self.assertIsNotNone(gr)
self.assertEqual(
gr.suggested_experiment_status,
ExperimentStatus.INITIALIZATION,
)

with self.subTest("without_suggested_experiment_status"):
gr_without = self.generation_node_without_exp_status.gen(
experiment=self.branin_experiment,
data=self.branin_experiment.lookup_data(),
n=1,
pending_observations={"branin": []},
)
self.assertIsNotNone(gr_without)
self.assertIsNone(gr_without.suggested_experiment_status)

@mock_botorch_optimize
def test_gen_with_trial_type(self) -> None:
mbm_short = GenerationNode(
Expand Down
1 change: 1 addition & 0 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def generator_run_to_dict(generator_run: GeneratorRun) -> dict[str, Any]:
"generator_state_after_gen": gr._generator_state_after_gen,
"candidate_metadata_by_arm_signature": cand_metadata,
"generation_node_name": gr._generation_node_name,
"suggested_experiment_status": gr.suggested_experiment_status,
}


Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ def generator_run_from_sqa(
class_decoder_registry=self.config.json_class_decoder_registry,
),
generation_node_name=generator_run_sqa.generation_node_name,
suggested_experiment_status=generator_run_sqa.suggested_experiment_status,
)
# Remove deprecated kwargs from generator kwargs & adapter kwargs.
if generator_run._generator_kwargs is not None:
Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ def generator_run_to_sqa(
class_encoder_registry=self.config.json_class_encoder_registry,
),
generation_node_name=generator_run._generation_node_name,
suggested_experiment_status=generator_run.suggested_experiment_status,
)
return gr_sqa

Expand Down
3 changes: 3 additions & 0 deletions ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ class SQAGeneratorRun(Base):
JSONEncodedTextDict
)
generation_node_name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
suggested_experiment_status: Column[ExperimentStatus | None] = Column(
IntEnum(ExperimentStatus), nullable=True
)

# relationships
# Use selectin loading for collections to prevent idle timeout errors
Expand Down
29 changes: 29 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2539,6 +2539,35 @@ def test_generator_run_gen_metadata(self) -> None:
)
self.assertEqual(decoded_gr.gen_metadata, gen_metadata)

def test_generator_run_suggested_experiment_status(self) -> None:
# Test round-trip with a status set.
gr = GeneratorRun(
arms=[],
suggested_experiment_status=ExperimentStatus.OPTIMIZATION,
)
generator_run_sqa = self.encoder.generator_run_to_sqa(gr)
self.assertEqual(
generator_run_sqa.suggested_experiment_status,
ExperimentStatus.OPTIMIZATION,
)
decoded_gr = self.decoder.generator_run_from_sqa(
generator_run_sqa, False, False
)
self.assertEqual(
decoded_gr.suggested_experiment_status,
ExperimentStatus.OPTIMIZATION,
)

def test_generator_run_suggested_experiment_status_none(self) -> None:
# Test round-trip with None (default).
gr = GeneratorRun(arms=[])
generator_run_sqa = self.encoder.generator_run_to_sqa(gr)
self.assertIsNone(generator_run_sqa.suggested_experiment_status)
decoded_gr = self.decoder.generator_run_from_sqa(
generator_run_sqa, False, False
)
self.assertIsNone(decoded_gr.suggested_experiment_status)

def test_update_generation_strategy_incrementally(self) -> None:
experiment = get_branin_experiment()
generation_strategy = choose_generation_strategy(
Expand Down
2 changes: 2 additions & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ax.core.data import Data
from ax.core.evaluations_to_data import raw_evaluations_to_data
from ax.core.experiment import Experiment
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric
Expand Down Expand Up @@ -2383,6 +2384,7 @@ def get_generator_run() -> GeneratorRun:
candidate_metadata_by_arm_signature={
a.signature: {"md_key": f"md_val_{a.signature}"} for a in arms
},
suggested_experiment_status=ExperimentStatus.OPTIMIZATION,
)


Expand Down