From 64576960873c777317bdf54207c8a97937aa3bcd Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Thu, 5 Feb 2026 05:55:38 -0800 Subject: [PATCH 1/4] improve optimization complete logic (#4828) Summary: This change updates the completion state logic to assume that if a node can transition, and that transition is to itself, then the optimization is complete. This works because should_transition_to_next_node only considers transition blocking criteria (i.e. not max parallelism) when deciding whether to transition. And if a node points to itself, we can assume that signifies the end of the optimization (steps are initialized this way earlier in this stack). This also allows for the GS to be re-called into, and the transition criterion to change, thus putting it back into a non-complete state. An alternative I considered is to check if all transition edges are completed, and at least one points to self. This would look something like the below snippet. It would be much more expensive to evaluate, and is guarding against a malformed strategy. 
Edges are already known to be created in order of importance, and self-transition edges should be considered ending edges when their importance is considered. ``` @property def optimization_complete(self) -> bool: if len(self._curr.transition_criteria) == 0: return False # Check ALL transition edges, not just the first matching one for next_node, all_tc in self._curr.transition_edges.items(): transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet] if not transition_blocking: continue all_met = all( tc.is_met(experiment=self.experiment, curr_node=self._curr) for tc in transition_blocking ) if all_met: # An edge's criteria are met - check where it points if next_node != self._curr.name: return False # Can transition to different node, not complete # All met edges (if any) point to self # Check if we actually have any met criteria pointing to self can_transition, next_node = self._curr.should_transition_to_next_node( raise_data_required_error=False ) return can_transition and next_node == self._curr.name ``` The third alternative is to institute a "completion node", which I think could be viable in the future if we have more complex generation strategies than we currently support, and the self generation logic is too cumbersome. For now though, I think this is a pretty nice simplification that also should have some compute wins. 
Going from O (number of nodes * number of TC per node), to O(number of tc on current node) Differential Revision: D91549954 --- ax/generation_strategy/generation_node.py | 16 ----------- ax/generation_strategy/generation_strategy.py | 28 +++++++++++++------ .../tests/test_generation_strategy.py | 22 +++++++++++++++ 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index a3efebc8cd7..b9f0a4c55eb 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -41,7 +41,6 @@ ) from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( - AutoTransitionAfterGen, MaxGenerationParallelism, MinTrials, TransitionCriterion, @@ -249,21 +248,6 @@ def experiment(self) -> Experiment: """Returns the experiment associated with this GenerationStrategy""" return self.generation_strategy.experiment - @property - def is_completed(self) -> bool: - """Returns True if this GenerationNode is complete and should transition to - the next node. 
- """ - # TODO: @mgarrard make this logic more robust and general - # We won't mark a node completed if it has an AutoTransitionAfterGen criterion - # as this is typically used in cyclic generation strategies - should_transition, _ = self.should_transition_to_next_node( - raise_data_required_error=False - ) - return should_transition and not any( - isinstance(tc, AutoTransitionAfterGen) for tc in self.transition_criteria - ) - @property def previous_node(self) -> GenerationNode | None: """Returns the previous ``GenerationNode``, if any.""" diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index 059d929080b..07942e05e26 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -180,8 +180,20 @@ def last_generator_run(self) -> GeneratorRun | None: @property def optimization_complete(self) -> bool: - """Checks whether all nodes are completed in the generation strategy.""" - return all(node.is_completed for node in self._nodes) + """Checks whether optimization is complete. + + A strategy is complete when the current node's transition criteria + are met and point back to itself (self-transition). + + Nodes with no transition_criteria are infinite by design and never complete. + """ + if len(self._curr.transition_criteria) == 0: + return False + + can_transition, next_node = self._curr.should_transition_to_next_node( + raise_data_required_error=False + ) + return can_transition and next_node == self._curr.name def gen_single_trial( self, @@ -612,13 +624,13 @@ def _maybe_transition_to_next_node( self, raise_data_required_error: bool = True, ) -> bool: - """Moves this generation strategy to next node if the current node is completed, - and it is not the last node in this generation strategy. This method is safe to - use both when generating candidates or simply checking how many generator runs - (to be made into trials) can currently be produced. 
+ """Moves this generation strategy to next node if the current node's + transition criteria are met. This method is safe to use both when generating + candidates or simply checking how many generator runs (to be made into trials) + can currently be produced. - NOTE: this method raises ``GenerationStrategyCompleted`` error if the current - generation node is complete, but it is also the last in generation strategy. + NOTE: this method raises ``GenerationStrategyCompleted`` error if the + optimization is complete Args: raise_data_required_error: Whether to raise ``DataRequiredError`` in the diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index a44c1a9af05..cb954845614 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -2000,6 +2000,28 @@ def test_gs_with_input_constructor(self) -> None: self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_3") self.assertEqual(len(trial.generator_runs[1].arms), 8) + def test_optimization_complete_single_node_no_criteria(self) -> None: + """Test that a single node with no transition_criteria never completes.""" + exp = get_branin_experiment() + gs = GenerationStrategy( + nodes=[ + GenerationNode( + name="infinite sobol", + generator_specs=[self.sobol_generator_spec], + transition_criteria=[], # No criteria = infinite by design + ), + ] + ) + gs.experiment = exp + + # Generate many trials - never completes + for _ in range(3): + self.assertFalse(gs.optimization_complete) + gr = gs.gen_single_trial(experiment=exp) + exp.new_trial(generator_run=gr).mark_running(no_runner_required=True) + + self.assertFalse(gs.optimization_complete) + # ------------- Testing helpers (put tests above this line) ------------- def _run_GS_for_N_rounds( From 7ade708c446a37126167d62cc4b669c4316f2a52 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Thu, 5 Feb 2026 05:55:38 -0800 
Subject: [PATCH 2/4] Basic transition logic simplification (#4829) Summary: Since transition_to is now required on transition criterion, we can remove checks/asserts related to `None` checks. This is a basic no-op simplification. Further restructuring is separated into a different diff for ease of review. Reviewed By: bletham Differential Revision: D91398877 --- ax/generation_strategy/generation_strategy.py | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index 07942e05e26..c6cdd81ea74 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -433,24 +433,9 @@ def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None: # Validate transition edges: # - All `transition_to` targets must exist in this GS # - All TCs on one edge must have the same `continue_trial_generation` setting - # All but `MaxGenerationParallelism` TCs must have a `transition_to` set for node in nodes: for next_node, tcs in node.transition_edges.items(): - if next_node is None: - # TODO[drfreund]: Handle the case of the last generation step not - # having any transition criteria. - # TODO[mgarrard]: Remove MaxGenerationParallelism check when - # we update TransitionCriterion always define `transition_to` - # NOTE: This is done in D86066476 - for tc in tcs: - if "MaxGenerationParallelism" not in tc.criterion_class: - raise GenerationStrategyMisconfiguredException( - error_info="Only MaxGenerationParallelism transition" - " criterion can have a null `transition_to` argument," - f" but {tc.criterion_class} does not define " - f"`transition_to` on {node.name}." 
- ) - elif next_node not in node_names: + if next_node not in node_names: raise GenerationStrategyMisconfiguredException( error_info=f"`transition_to` argument " f"{next_node} does not correspond to any node in" @@ -612,7 +597,6 @@ def _should_continue_gen_for_trial(self) -> bool: # if we will transition nodes, check if the transition criterion which define # the transition from this node to the next node indicate that we should # continue generating in the same trial, otherwise end the generation. - assert next_node is not None return all( tc.continue_trial_generation for tc in self._curr.transition_edges[next_node] @@ -648,12 +632,5 @@ def _maybe_transition_to_next_node( f"Generation strategy {self} generated all the trials as " "specified in its nodes." ) - if next_node is None: - # If the last node did not specify which node to transition to, - # move to the next node in the list. - current_node_index = self._nodes.index(self._curr) - next_node = self._nodes[current_node_index + 1].name - for node in self._nodes: - if node.name == next_node: - self._curr = node + self._curr = self.nodes_by_name[next_node] return move_to_next_node From 5d2578180154a2e5bf0feb355c5629e2e6c6f027 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Thu, 5 Feb 2026 05:55:38 -0800 Subject: [PATCH 3/4] Add caching to common methods (#4830) Summary: This method is called many, many times during generation and its computational cost adds up over time. By caching it we can achieve significant improvements in computation time, especially in high trial count regimes. 
Reviewed By: mpolson64 Differential Revision: D91552553 --- ax/generation_strategy/generation_node.py | 25 +++++++++++++++---- ax/generation_strategy/generation_strategy.py | 4 +++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index b9f0a4c55eb..d0ddd5284fe 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -190,6 +190,10 @@ def __init__( self.fallback_specs = ( fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK ) + # Cache for trials_from_node property to avoid recomputation + # on every access. Invalidated when trial count changes. + self._trials_from_node_cache: set[int] | None = None + self._cached_trial_count: int = -1 @property def name(self) -> str: @@ -724,17 +728,28 @@ def _pick_fitted_adapter_to_gen_from(self) -> GeneratorSpec: def trials_from_node(self) -> set[int]: """Returns a set containing the indices of trials generated by this node. + Results are cached and invalidated when the experiment's trial count changes. + Returns: Set[int]: A set containing all the indices of trials generated by this node. 
""" + current_trial_count = len(self.experiment.trials) + if ( + self._trials_from_node_cache is not None + and self._cached_trial_count == current_trial_count + ): + return self._trials_from_node_cache + + # (re)-build cache trials_from_node = set() - for _idx, trial in self.experiment.trials.items(): + for trial in self.experiment.trials.values(): for gr in trial.generator_runs: - if ( - gr._generation_node_name is not None - and gr._generation_node_name == self.name - ): + if gr._generation_node_name == self.name: trials_from_node.add(trial.index) + break + + self._trials_from_node_cache = trials_from_node + self._cached_trial_count = current_trial_count return trials_from_node @property diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index c6cdd81ea74..b680e31e001 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -361,6 +361,10 @@ def _unset_non_persistent_state_fields(self) -> None: n._step_index = None if len(n.generator_specs) > 1: n._generator_spec_to_gen_from = None + # Reset cache fields that are used for performance optimization only + # and should not affect equality comparisons. + n._trials_from_node_cache = None + n._cached_trial_count = -1 # TODO: Deprecate `steps` argument fully in Q1'26. def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None: From 211421b1321acb0f3d19b1a1bac1aed208620fa3 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Thu, 5 Feb 2026 05:55:38 -0800 Subject: [PATCH 4/4] Split TransitionCriterion into TransitionCriterion and GenerationCriterion (#4854) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: **TLDR:** This diff splits TransitionCriterion into (1) TransitionCriterion and (2) GenerationBlockingCriterion. I think this makes sense to do because it *greatly* increases the conceptual clarity of the transition criterion. 
Some ways it does this include: 1. Removal of confusing dual-purpose flags — block_transition_if_unmet and block_gen_if_met. Now transition criteria are inferred to block transition if unmet and generation criteria are inferred to raise informative errors if the criterion is met. 2. Each criterion contains fewer flags, and the flags are more directly intuitive. 3. With upcoming removal of special logic for online, we will need to add more generation blocking criteria (i.e. do we have an opt config); it is better to make this change before adding more criteria that will need to be migrated. 4. It will allow the logic for transition and generation to be smoother — this diff keeps things as close to existing logic as possible to minimize diff review overhead, but in subsequent diffs we can save fit time if we know we can’t generate from this node + can’t transition. It will also allow for some further clarification on generation/transition blocking logic that I think is contributing to the confusion of the file. 5. I like that creating a new generation blocking criterion with a specific error to raise is easy and painless. **Cons of this change:** - It’s a large change, sorry about that. - There is some duplication between TrialBased transition criterion and generation criterion. I explored using a Mixin here, but I find mixins tend to add unnecessary inheritance structures to reason about. **Most important files for review, in order of importance** 1. transition_criterion.py 2. generation_node.py 3. decoder.py 4. encoders.py 5. registry.py 6. generation_strategy_dispatch.py 7. generation_nodes.py 8. 
generation_strategy.py The remaining files are mainly trivial updates to tests **Note about backwards compatibility:** * This diff will directly decode legacy MaxGenerationParallelism as a generation blocking criterion called MaxGenerationParallelism * Historically, there are some instances of MinTrials that have block_gen_if_met=True; this usually comes from enforce_num_trials=True. Now we call this MaxTrialsAwaitingData, and MinTrials is decoded as that. I am open to other, better names for this new criterion. **Other notes/potential improvements:** - We could split transition criterion, generation criterion, and utils into their own files. I kinda like them together, and if we do want to do this split I’d like to do it in a follow-up to try to minimize an already very large blast radius. Differential Revision: D92201085 --- ax/api/utils/generation_strategy_dispatch.py | 18 +- .../test_generation_strategy_dispatch.py | 19 +- ax/generation_strategy/generation_node.py | 162 +++-- ax/generation_strategy/generation_strategy.py | 6 +- .../tests/test_dispatch_utils.py | 53 +- .../tests/test_generation_node.py | 9 +- .../tests/test_generation_strategy.py | 42 +- .../tests/test_transition_criterion.py | 70 ++- .../transition_criterion.py | 587 ++++++++++-------- ax/orchestration/tests/test_orchestrator.py | 26 +- ax/service/ax_client.py | 8 +- ax/service/tests/test_ax_client.py | 29 +- ax/storage/json_store/decoder.py | 131 +++- ax/storage/json_store/encoders.py | 15 +- ax/storage/json_store/registry.py | 6 +- .../json_store/tests/test_json_store.py | 124 ++++ ax/storage/sqa_store/tests/test_sqa_store.py | 92 ++- ax/utils/testing/core_stubs.py | 34 +- ax/utils/testing/modeling_stubs.py | 15 +- .../external_generation_node.ipynb | 12 +- 20 files changed, 954 insertions(+), 504 deletions(-) diff --git a/ax/api/utils/generation_strategy_dispatch.py b/ax/api/utils/generation_strategy_dispatch.py index f3f60d2fa35..91c1f9576a2 100644 --- 
a/ax/api/utils/generation_strategy_dispatch.py +++ b/ax/api/utils/generation_strategy_dispatch.py @@ -21,7 +21,7 @@ GenerationStrategy, ) from ax.generation_strategy.generator_spec import GeneratorSpec -from ax.generation_strategy.transition_criterion import MinTrials +from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP @@ -70,21 +70,28 @@ def _get_sobol_node( MinTrials( # This represents the initialization budget. threshold=initialization_budget, transition_to="MBM", - block_gen_if_met=(not allow_exceeding_initialization_budget), - block_transition_if_unmet=True, use_all_trials_in_exp=use_existing_trials_for_initialization, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), MinTrials( # This represents minimum observed trials requirement. threshold=min_observed_initialization_trials, transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=True, only_in_statuses=[TrialStatus.COMPLETED], count_only_trials_with_data=True, ), ] + # If we want to enforce the initialization budget, add a generation blocking + # criterion that prevents exceeding the budget. 
+ generation_blocking_criteria = None + if not allow_exceeding_initialization_budget: + generation_blocking_criteria = [ + MaxTrialsAwaitingData( + threshold=initialization_budget, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + use_all_trials_in_exp=use_existing_trials_for_initialization, + ) + ] return GenerationNode( name="Sobol", generator_specs=[ @@ -94,6 +101,7 @@ def _get_sobol_node( ) ], transition_criteria=transition_criteria, + generation_blocking_criteria=generation_blocking_criteria, should_deduplicate=True, ) diff --git a/ax/api/utils/tests/test_generation_strategy_dispatch.py b/ax/api/utils/tests/test_generation_strategy_dispatch.py index 958eccc0a46..9e844acb660 100644 --- a/ax/api/utils/tests/test_generation_strategy_dispatch.py +++ b/ax/api/utils/tests/test_generation_strategy_dispatch.py @@ -19,7 +19,7 @@ from ax.exceptions.core import UserInputError from ax.generation_strategy.center_generation_node import CenterGenerationNode from ax.generation_strategy.dispatch_utils import get_derelativize_config -from ax.generation_strategy.transition_criterion import MinTrials +from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -99,16 +99,12 @@ def test_choose_gs_fast_with_options(self) -> None: MinTrials( threshold=2, transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=False, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), MinTrials( threshold=4, transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=True, only_in_statuses=[TrialStatus.COMPLETED], count_only_trials_with_data=True, @@ -390,7 +386,18 @@ def test_abandoned_and_failed_trials_excluded_from_initialization_budget( first_tc.not_in_statuses, 
[TrialStatus.FAILED, TrialStatus.ABANDONED] ) self.assertEqual(first_tc.threshold, 5) - self.assertTrue(first_tc.block_gen_if_met) + # Verify MaxTrialsAwaitingData is in generation_blocking_criteria + blocking_criteria = [ + bc + for bc in sobol_node._generation_blocking_criteria + if isinstance(bc, MaxTrialsAwaitingData) + ] + self.assertEqual(len(blocking_criteria), 1) + self.assertEqual(blocking_criteria[0].threshold, 5) + self.assertEqual( + blocking_criteria[0].not_in_statuses, + [TrialStatus.FAILED, TrialStatus.ABANDONED], + ) # Test the actual behavior: Generate 5 trials, mark 3 as ABANDONED, # verify that Sobol can still generate more trials diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index d0ddd5284fe..aa3c4237ef4 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -41,10 +41,12 @@ ) from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( + GenerationBlockingCriterion, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, TransitionCriterion, - TrialBasedCriterion, + TrialCountBlockingCriterion, ) from ax.utils.common.base import SortableBase from ax.utils.common.constants import Keys @@ -130,6 +132,7 @@ class GenerationNode(SerializationMixin, SortableBase): _generator_spec_to_gen_from: GeneratorSpec | None = None # TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping? 
_transition_criteria: Sequence[TransitionCriterion] + _generation_blocking_criteria: Sequence[GenerationBlockingCriterion] _input_constructors: TInputConstructorsByPurpose _previous_node_name: str | None = None _trial_type: str | None = None @@ -149,6 +152,8 @@ def __init__( name: str, generator_specs: list[GeneratorSpec], transition_criteria: Sequence[TransitionCriterion] | None = None, + generation_blocking_criteria: Sequence[GenerationBlockingCriterion] + | None = None, best_model_selector: BestModelSelector | None = None, should_deduplicate: bool = False, input_constructors: TInputConstructorsByPurpose | None = None, @@ -183,6 +188,7 @@ def __init__( self.best_model_selector = best_model_selector self.should_deduplicate = should_deduplicate self._transition_criteria = transition_criteria or [] + self._generation_blocking_criteria = generation_blocking_criteria or [] self._input_constructors = input_constructors or {} self._previous_node_name = previous_node_name self._trial_type = trial_type @@ -240,6 +246,17 @@ def transition_criteria(self) -> Sequence[TransitionCriterion]: """ return [] if self._transition_criteria is None else self._transition_criteria + @property + def generation_blocking_criteria(self) -> Sequence[GenerationBlockingCriterion]: + """Returns the sequence of GenerationBlockingCriteria that will be used to + block generation from this node without triggering a transition. + """ + return ( + [] + if self._generation_blocking_criteria is None + else self._generation_blocking_criteria + ) + @property def input_constructors(self) -> TInputConstructorsByPurpose: """Returns the input constructors that will be used to determine any dynamic @@ -295,23 +312,22 @@ def generator_name(self) -> str: def num_trials(self) -> int: """Returns the number of trials this node should generate. - Extracts the threshold from the first `MinTrials` transition criterion - that has `block_transition_if_unmet=True`. 
This represents the minimum - number of trials that must be generated before transitioning. + Extracts the threshold from the first `MinTrials` transition criterion. + This represents the minimum number of trials that must be generated + before transitioning. Returns: The number of trials (threshold value). Raises: - UserInputError: If no `MinTrials` transition criterion with - `block_transition_if_unmet=True` is found. + UserInputError: If no `MinTrials` transition criterion is found. """ for tc in self.transition_criteria: - if isinstance(tc, MinTrials) and tc.block_transition_if_unmet: + if isinstance(tc, MinTrials): return tc.threshold raise UserInputError( "`num_trials` property is only supported when a `MinTrials` " - "transition criterion with `block_transition_if_unmet=True` is present." + "transition criterion is present." ) @property @@ -354,6 +370,10 @@ def __repr__(self) -> str: str_rep += ( f", transition_criteria={str(self._brief_transition_criteria_repr())}" ) + str_rep += ( + ", generation_blocking_criteria=" + f"{str(self._brief_generation_blocking_criteria_repr())}" + ) return f"{str_rep})" def _fit( @@ -807,76 +827,66 @@ def should_transition_to_next_node( and the name of the node to gen from (either the current or next node) """ # if no transition criteria are defined, this node can generate unlimited trials - if len(self.transition_criteria) == 0: + if ( + len(self.transition_criteria) == 0 + and len(self.generation_blocking_criteria) == 0 + ): return False, self.name # For each "transition edge" (set of all transition criteria that lead from # current node (e.g. "node A") to another specific node ("e.g. "node B") # in the node DAG: - # I. Check if all of the transition criteria along that edge are met; if so, + # Check if all of the transition criteria along that edge are met; if so, # transition to the next node defined by that edge. - # II. 
If we did not transition along this edge, but the edge has some - # "generation blocking" transition criteria (ex `MaxGenerationParallelism`) - # that are met, raise the error associated with that criterion. for next_node, all_tc in self.transition_edges.items(): - # I. Check if there are any TCs that block transition and whether all - # of them are met. If all of them are met, then we should transition. - transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet] - all_transition_blocking_met_should_transition = transition_blocking and all( + # Check if all TCs are met. If all of them are met, then we should + # transition. + all_tc_met_should_transition = all_tc and all( tc.is_met( experiment=self.experiment, curr_node=self, ) - for tc in transition_blocking + for tc in all_tc ) - if all_transition_blocking_met_should_transition: + if all_tc_met_should_transition: return True, next_node - # II. Raise any necessary generation errors: for any met criterion, - # call its `block_continued_generation_error` method if not all - # transition-blocking criteria are met. The method might not raise an - # error, depending on its implementation on given criterion, so the error - # from the first met one that does block continued generation, will raise. - if raise_data_required_error: - generation_blocking = [tc for tc in all_tc if tc.block_gen_if_met] - for tc in generation_blocking: - if tc.is_met(self.experiment, curr_node=self): - tc.block_continued_generation_error( - node_name=self.name, - experiment=self.experiment, - trials_from_node=self.trials_from_node, - ) - # TODO[@mgarrard, @drfreund] Try replacing `block_gen_if_met` with - # a self-transition and rework this error block. + # Only check generation blocking criteria if we're NOT transitioning. + # This ensures transition takes priority over blocking. 
+ if raise_data_required_error: + for criterion in self.generation_blocking_criteria: + if criterion.is_met(self.experiment, curr_node=self): + criterion.block_continued_generation_error( + node_name=self.name, + experiment=self.experiment, + trials_from_node=self.trials_from_node, + ) return False, self.name def new_trial_limit(self, raise_generation_errors: bool = False) -> int: - """How many trials can this generation strategy can currently produce - ``GeneratorRun``-s for (with potentially multiple generator runs produced for - each intended trial). + """How many trials this node can currently produce GeneratorRun-s for. - NOTE: Only considers transition criteria that inherit from - ``TrialBasedCriterion``. + NOTE: Considers TrialCountBlockingCriterion subclasses for limiting + generation. Returns: The number of generator runs that can currently be produced, with -1 meaning unlimited generator runs. """ - # TODO: @mgarrard Should we consider returning `None` if there is no limit? - trial_based_gen_blocking_criteria = [ - criterion - for criterion in self.transition_criteria - if criterion.block_gen_if_met and isinstance(criterion, TrialBasedCriterion) - ] + # TODO: @mgarrard further improve and clarify this method # Cache trials_from_node to avoid repeated computation. 
trials_from_node = self.trials_from_node - gen_blocking_criterion_delta_from_threshold = [ - criterion.num_till_threshold( - experiment=self.experiment, trials_from_node=trials_from_node - ) - for criterion in trial_based_gen_blocking_criteria - ] + + # Compute limits from generation_blocking_criteria + gen_blocking_criterion_delta_from_threshold = [] + for criterion in self.generation_blocking_criteria: + if isinstance(criterion, TrialCountBlockingCriterion): + gen_blocking_criterion_delta_from_threshold.append( + criterion.num_till_threshold( + experiment=self.experiment, trials_from_node=trials_from_node + ) + ) # Raise any necessary generation errors: for any met criterion, # call its `block_continued_generation_error` method The method might not @@ -884,13 +894,10 @@ def new_trial_limit(self, raise_generation_errors: bool = False) -> int: # error from the first met one that does block continued generation, will be # raised. if raise_generation_errors: - for criterion in trial_based_gen_blocking_criteria: + for criterion in self.generation_blocking_criteria: # TODO[mgarrard]: Raise a group of all the errors, from each gen- # blocking transition criterion. - if criterion.is_met( - self.experiment, - curr_node=self, - ): + if criterion.is_met(self.experiment, curr_node=self): criterion.block_continued_generation_error( node_name=self.name, experiment=self.experiment, @@ -907,7 +914,7 @@ def _brief_transition_criteria_repr(self) -> str: Returns: str: A string representation of the transition criteria for this node. """ - if self.transition_criteria is None: + if not self._transition_criteria: return "None" tc_list = ", ".join( [ @@ -917,6 +924,25 @@ def _brief_transition_criteria_repr(self) -> str: ) return f"[{tc_list}]" + def _brief_generation_blocking_criteria_repr(self) -> str: + """Returns a brief string representation of the + generation blocking criteria for this node. + + Returns: + str: A string representation of the generation blocking criteria. 
+ """ + if self._generation_blocking_criteria is None: + return "None" + bc_list = ", ".join( + [ + f"{bc.__class__.__name__}(threshold={bc.threshold})" + if isinstance(bc, TrialCountBlockingCriterion) + else f"{bc.__class__.__name__}()" + for bc in self.generation_blocking_criteria + ] + ) + return f"[{bc_list}]" + def apply_input_constructors( self, experiment: Experiment, @@ -1085,6 +1111,7 @@ def __new__( # is set in `GenerationStrategy` constructor, because only then is the order # of the generation steps actually known. transition_criteria: list[TransitionCriterion] = [] + generation_blocking_criteria: list[GenerationBlockingCriterion] = [] # Placeholder - will be overwritten in _validate_and_set_step_sequence in GS placeholder_transition_to = f"GenerationStep_{str(index)}" @@ -1094,11 +1121,18 @@ def __new__( threshold=num_trials, transition_to=placeholder_transition_to, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], - block_gen_if_met=enforce_num_trials, - block_transition_if_unmet=True, use_all_trials_in_exp=use_all_trials_in_exp, ) ) + # If enforce_num_trials is True, add a blocking criterion + if enforce_num_trials: + generation_blocking_criteria.append( + MaxTrialsAwaitingData( + threshold=num_trials, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + use_all_trials_in_exp=use_all_trials_in_exp, + ) + ) if min_trials_observed > 0: transition_criteria.append( @@ -1109,19 +1143,14 @@ def __new__( TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED, ], - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=use_all_trials_in_exp, ) ) if max_parallelism is not None: - transition_criteria.append( + generation_blocking_criteria.append( MaxGenerationParallelism( threshold=max_parallelism, - transition_to=placeholder_transition_to, only_in_statuses=[TrialStatus.RUNNING], - block_gen_if_met=True, - block_transition_if_unmet=False, ) ) @@ -1136,6 +1165,7 @@ def __new__( generator_specs=[generator_spec], 
should_deduplicate=should_deduplicate, transition_criteria=transition_criteria, + generation_blocking_criteria=generation_blocking_criteria, ) # Store step index on the node for naming in GenerationStrategy. diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index b680e31e001..dcef5b08160 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -398,11 +398,7 @@ def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None: ) ) for tc in step.transition_criteria: - if tc.criterion_class == "MaxGenerationParallelism": - # MaxGenerationParallelism transitions to self (current step) - tc._transition_to = step.name - else: - tc._transition_to = next_step_name + tc._transition_to = next_step_name self._curr = steps[0] def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None: diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index c382dd9fe3a..cd70d42a2ba 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -24,6 +24,7 @@ from ax.generation_strategy.generation_node import GenerationNode from ax.generation_strategy.transition_criterion import ( MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, ) from ax.generators.random.sobol import SobolGenerator @@ -615,18 +616,20 @@ def test_enforce_sequential_optimization(self) -> None: ).threshold, 5, ) - # Check that enforce_num_trials is True by verifying the MinTrials - # criterion has block_gen_if_met=True - node0_min_trials = assert_is_instance( - sobol_gpei._nodes[0].transition_criteria[0], MinTrials - ) - self.assertTrue(node0_min_trials.block_gen_if_met) + # Check that enforce_num_trials is True by verifying MaxTrialsAwaitingData + # exists in generation_blocking_criteria + node0_blocking_criteria = [ + bc + for bc in 
sobol_gpei._nodes[0].generation_blocking_criteria + if isinstance(bc, MaxTrialsAwaitingData) + ] + self.assertTrue(len(node0_blocking_criteria) > 0) # Check that max_parallelism is set by verifying MaxGenerationParallelism - # criterion exists on node 1 + # criterion exists in generation_blocking_criteria node1_max_parallelism = [ - tc - for tc in sobol_gpei._nodes[1].transition_criteria - if isinstance(tc, MaxGenerationParallelism) + bc + for bc in sobol_gpei._nodes[1].generation_blocking_criteria + if isinstance(bc, MaxGenerationParallelism) ] self.assertTrue(len(node1_max_parallelism) > 0) with self.subTest("False"): @@ -640,18 +643,20 @@ def test_enforce_sequential_optimization(self) -> None: ).threshold, 5, ) - # Check that enforce_num_trials is False by verifying the MinTrials - # criterion has block_gen_if_met=False - node0_min_trials = assert_is_instance( - sobol_gpei._nodes[0].transition_criteria[0], MinTrials - ) - self.assertFalse(node0_min_trials.block_gen_if_met) + # Check that enforce_num_trials is False by verifying no + # MaxTrialsAwaitingData exists in generation_blocking_criteria + node0_blocking_criteria = [ + bc + for bc in sobol_gpei._nodes[0].generation_blocking_criteria + if isinstance(bc, MaxTrialsAwaitingData) + ] + self.assertEqual(len(node0_blocking_criteria), 0) # Check that max_parallelism is None by verifying no - # MaxGenerationParallelism criterion exists on node 1 + # MaxGenerationParallelism criterion exists in generation_blocking_criteria node1_max_parallelism = [ - tc - for tc in sobol_gpei._nodes[1].transition_criteria - if isinstance(tc, MaxGenerationParallelism) + bc + for bc in sobol_gpei._nodes[1].generation_blocking_criteria + if isinstance(bc, MaxGenerationParallelism) ] self.assertEqual(len(node1_max_parallelism), 0) with self.subTest("False and max_parallelism_override"): @@ -818,10 +823,10 @@ def test_fixed_num_initialization_trials(self) -> None: ) def _get_max_parallelism(self, node: GenerationNode) -> int | None: - 
"""Helper to extract max_parallelism from transition criteria.""" - for tc in node.transition_criteria: - if isinstance(tc, MaxGenerationParallelism): - return tc.threshold + """Helper to extract max_parallelism from generation_blocking_criteria.""" + for bc in node.generation_blocking_criteria: + if isinstance(bc, MaxGenerationParallelism): + return bc.threshold return None def test_max_parallelism_adjustments(self) -> None: diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 3cae54f0b27..6e10f15ed37 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -335,7 +335,8 @@ def test_node_string_representation(self) -> None: "GenerationNode(name='test', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MinTrials(transition_to='next_node')])", + "transition_criteria=[MinTrials(transition_to='next_node')], " + "generation_blocking_criteria=[])", ) def test_single_fixed_features(self) -> None: @@ -445,8 +446,6 @@ def test_init(self) -> None: threshold=5, transition_to="GenerationStep_-1", # overwritten during GS init not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], - block_gen_if_met=True, - block_transition_if_unmet=True, use_all_trials_in_exp=False, ), ], @@ -471,8 +470,6 @@ def test_init(self) -> None: threshold=5, transition_to="GenerationStep_-1", # overwritten during GS init not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=True, ), MinTrials( @@ -482,8 +479,6 @@ def test_init(self) -> None: TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED, ], - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=True, ), ], diff --git a/ax/generation_strategy/tests/test_generation_strategy.py 
b/ax/generation_strategy/tests/test_generation_strategy.py index cb954845614..cb7f3e15862 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -215,7 +215,6 @@ def setUp(self) -> None: MinTrials( threshold=5, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) @@ -226,7 +225,6 @@ def setUp(self) -> None: # this self-pointing isn't representative of real-world, but is # useful for testing attributes likes repr etc transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) @@ -235,7 +233,6 @@ def setUp(self) -> None: MinTrials( threshold=1, transition_to="mbm", - block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], ) ] @@ -266,14 +263,12 @@ def setUp(self) -> None: self.mbm_to_sobol2_with_running_trial = MinTrials( threshold=1, transition_to="sobol_2", - block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], use_all_trials_in_exp=True, ) self.mbm_to_sobol2_with_completed_trial = MinTrials( threshold=1, transition_to="sobol_2", - block_transition_if_unmet=True, only_in_statuses=[TrialStatus.COMPLETED], use_all_trials_in_exp=True, ) @@ -336,13 +331,11 @@ def setUp(self) -> None: MinTrials( threshold=2, transition_to="sobol_4", - block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], use_all_trials_in_exp=True, ), AutoTransitionAfterGen( transition_to="mbm", - block_transition_if_unmet=True, continue_trial_generation=False, ), ], @@ -439,11 +432,13 @@ def test_string_representation(self) -> None: "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " "transition_criteria=" - "[MinTrials(transition_to='GenerationStep_1_BoTorch')]), " + "[MinTrials(transition_to='GenerationStep_1_BoTorch')], " + 
"generation_blocking_criteria=[MaxTrialsAwaitingData(threshold=5)]), " "GenerationNode(name='GenerationStep_1_BoTorch', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[])])" + "transition_criteria=None, " + "generation_blocking_criteria=[])])" ), ) gs2 = GenerationStrategy( @@ -456,7 +451,8 @@ def test_string_representation(self) -> None: "nodes=[GenerationNode(name='GenerationStep_0_Sobol', " "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " - "transition_criteria=[])])" + "transition_criteria=None, " + "generation_blocking_criteria=[])])" ), ) @@ -480,7 +476,8 @@ def test_string_representation(self) -> None: "name='test', " "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " - "transition_criteria=[])])", + "transition_criteria=None, " + "generation_blocking_criteria=[])])", ) def test_equality(self) -> None: @@ -1220,30 +1217,29 @@ def test_gen_with_fixed_features( # ---------- Tests for GenerationStrategies composed of GenerationNodes -------- def test_gs_setup_with_nodes(self) -> None: """Test GS initialization and validation with nodes""" - node_1_criterion = [ + node_1_transition_criteria = [ MinTrials( threshold=4, - block_gen_if_met=False, transition_to="node_2", only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), MinTrials( - only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], threshold=2, + only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], transition_to="node_2", ), + ] + node_1_blocking_criteria = [ MaxGenerationParallelism( threshold=1, only_in_statuses=[TrialStatus.RUNNING], - block_gen_if_met=True, - block_transition_if_unmet=False, - transition_to="node_1", ), ] node_1 = GenerationNode( name="node_1", - transition_criteria=node_1_criterion, + transition_criteria=node_1_transition_criteria, + 
generation_blocking_criteria=node_1_blocking_criteria, generator_specs=[self.sobol_generator_spec], ) node_3 = GenerationNode( @@ -1311,7 +1307,6 @@ def test_gs_setup_with_nodes(self) -> None: transition_criteria=[ MinTrials( threshold=4, - block_gen_if_met=False, transition_to="node_2", only_in_statuses=None, not_in_statuses=[ @@ -1383,7 +1378,6 @@ def test_gs_with_suggested_n_is_zero(self) -> None: transition_criteria=[ AutoTransitionAfterGen( transition_to="sobol_2", - block_transition_if_unmet=True, continue_trial_generation=False, ), ], @@ -1440,7 +1434,6 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None: transition_criteria=[ AutoTransitionAfterGen( transition_to="sobol_1", - block_transition_if_unmet=True, continue_trial_generation=False, ), ], @@ -1599,15 +1592,11 @@ def test_gs_with_nodes_and_blocking_criteria(self) -> None: transition_criteria=[ MinTrials( threshold=3, - block_gen_if_met=True, - block_transition_if_unmet=True, transition_to="MBM_node", ), MinTrials( threshold=2, only_in_statuses=[TrialStatus.COMPLETED], - block_gen_if_met=False, - block_transition_if_unmet=True, transition_to="MBM_node", ), ], @@ -1752,7 +1741,6 @@ def test_node_gs_with_auto_transitions(self) -> None: transition_criteria=[ AutoTransitionAfterGen( transition_to="mbm", - block_transition_if_unmet=True, continue_trial_generation=False, ) ], @@ -1816,7 +1804,6 @@ def test_gs_with_fixed_features_constructor(self) -> None: MinTrials( threshold=1, transition_to="sobol_2", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) @@ -1942,7 +1929,6 @@ def test_gs_with_input_constructor(self) -> None: MinTrials( threshold=1, transition_to="sobol_2", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) diff --git a/ax/generation_strategy/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py index 
c34850ca4c0..1d63d30e170 100644 --- a/ax/generation_strategy/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -11,7 +11,8 @@ from ax.adapter.registry import Generators from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.trial_status import TrialStatus -from ax.exceptions.core import UserInputError +from ax.exceptions.core import DataRequiredError, UserInputError +from ax.exceptions.generation_strategy import MaxParallelismReachedException from ax.generation_strategy.generation_strategy import ( GenerationNode, GenerationStep, @@ -23,6 +24,7 @@ AuxiliaryExperimentCheck, IsSingleObjective, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, ) from ax.utils.common.logger import get_logger @@ -167,7 +169,6 @@ def test_default_step_criterion_setup(self) -> None: step_0_expected_transition_criteria = [ MinTrials( threshold=3, - block_gen_if_met=True, transition_to="GenerationStep_1_BoTorch", only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], @@ -176,7 +177,6 @@ def test_default_step_criterion_setup(self) -> None: step_1_expected_transition_criteria = [ MinTrials( threshold=4, - block_gen_if_met=False, transition_to="GenerationStep_2_BoTorch", only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], @@ -186,12 +186,11 @@ def test_default_step_criterion_setup(self) -> None: threshold=2, transition_to="GenerationStep_2_BoTorch", ), + ] + step_1_expected_blocking_criteria = [ MaxGenerationParallelism( threshold=1, only_in_statuses=[TrialStatus.RUNNING], - block_gen_if_met=True, - block_transition_if_unmet=False, - transition_to="GenerationStep_1_BoTorch", ), ] step_2_expected_transition_criteria = [] @@ -201,6 +200,9 @@ def test_default_step_criterion_setup(self) -> None: self.assertEqual( gs._nodes[1].transition_criteria, step_1_expected_transition_criteria ) + self.assertEqual( + 
gs._nodes[1].generation_blocking_criteria, step_1_expected_blocking_criteria + ) self.assertEqual( gs._nodes[2].transition_criteria, step_2_expected_transition_criteria ) @@ -391,12 +393,9 @@ def test_trials_from_node_empty(self) -> None: max_criterion_with_status = MinTrials( threshold=2, transition_to="next_node", - block_gen_if_met=True, only_in_statuses=[TrialStatus.COMPLETED], ) - max_criterion = MinTrials( - threshold=2, transition_to="next_node", block_gen_if_met=True - ) + max_criterion = MinTrials(threshold=2, transition_to="next_node") self.assertFalse( max_criterion.is_met(experiment=experiment, curr_node=gs._nodes[0]) ) @@ -428,8 +427,6 @@ def test_repr(self) -> None: min_trials_criterion = MinTrials( threshold=5, transition_to="GenerationStep_1", - block_gen_if_met=True, - block_transition_if_unmet=False, only_in_statuses=[TrialStatus.COMPLETED], not_in_statuses=[TrialStatus.FAILED], ) @@ -439,8 +436,6 @@ def test_repr(self) -> None: + "'transition_to': 'GenerationStep_1', " + "'only_in_statuses': [.COMPLETED], " + "'not_in_statuses': [.FAILED], " - + "'block_transition_if_unmet': False, " - + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " + "'continue_trial_generation': False, " + "'count_only_trials_with_data': False})", @@ -449,8 +444,6 @@ def test_repr(self) -> None: threshold=0, transition_to="GenerationStep_2", only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], - block_gen_if_met=True, - block_transition_if_unmet=False, not_in_statuses=[TrialStatus.FAILED], ) self.assertEqual( @@ -460,8 +453,6 @@ def test_repr(self) -> None: + "'only_in_statuses': " + "[.COMPLETED, .EARLY_STOPPED], " + "'not_in_statuses': [.FAILED], " - + "'block_transition_if_unmet': False, " - + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " + "'continue_trial_generation': False, " + "'count_only_trials_with_data': False})", @@ -469,27 +460,54 @@ def test_repr(self) -> None: max_parallelism = MaxGenerationParallelism( 
only_in_statuses=[TrialStatus.EARLY_STOPPED], threshold=3, - transition_to="GenerationStep_2", - block_gen_if_met=True, - block_transition_if_unmet=False, not_in_statuses=[TrialStatus.FAILED], ) self.assertEqual( str(max_parallelism), "MaxGenerationParallelism({'threshold': 3, " - + "'transition_to': 'GenerationStep_2', " + "'only_in_statuses': " + "[.EARLY_STOPPED], " + "'not_in_statuses': [.FAILED], " - + "'block_transition_if_unmet': False, " - + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " - + "'continue_trial_generation': False})", + + "'count_only_trials_with_data': False})", ) auto_transition = AutoTransitionAfterGen(transition_to="GenerationStep_2") self.assertEqual( str(auto_transition), "AutoTransitionAfterGen({'transition_to': 'GenerationStep_2', " - + "'block_transition_if_unmet': True, " + "'continue_trial_generation': True})", ) + + +class TestGenerationBlockingCriterion(TestCase): + """Tests for new GenerationBlockingCriterion classes.""" + + def setUp(self) -> None: + super().setUp() + self.experiment = get_branin_experiment() + + def test_max_trials_awaiting_data(self) -> None: + with self.subTest("default_not_in_statuses"): + criterion = MaxTrialsAwaitingData(threshold=10) + self.assertEqual( + criterion.not_in_statuses, + [TrialStatus.FAILED, TrialStatus.ABANDONED], + ) + + with self.subTest("block_continued_generation_error"): + criterion = MaxTrialsAwaitingData(threshold=3) + with self.assertRaises(DataRequiredError): + criterion.block_continued_generation_error( + node_name="test", experiment=self.experiment, trials_from_node=set() + ) + + def test_max_generation_parallelism_block_error(self) -> None: + criterion = MaxGenerationParallelism( + threshold=2, only_in_statuses=[TrialStatus.RUNNING] + ) + with self.assertRaises(MaxParallelismReachedException): + criterion.block_continued_generation_error( + node_name="test", + experiment=self.experiment, + trials_from_node={0, 1, 2}, + ) diff --git 
a/ax/generation_strategy/transition_criterion.py b/ax/generation_strategy/transition_criterion.py index 7b4aceb6c46..41706325dab 100644 --- a/ax/generation_strategy/transition_criterion.py +++ b/ax/generation_strategy/transition_criterion.py @@ -32,6 +32,325 @@ ) +# ============================================================================ +# Trial Counting Utility Functions +# ============================================================================ + + +def get_trials_by_status( + experiment: Experiment, statuses: list[TrialStatus] +) -> set[int]: + """Get trial indices from the experiment with the specified statuses. + + Args: + experiment: The experiment to query. + statuses: The trial statuses to filter on. + + Returns: + Set of trial indices with the specified statuses. + """ + trials_with_statuses = set() + for status in statuses: + trials_with_statuses = trials_with_statuses.union( + experiment.trial_indices_by_status[status] + ) + return trials_with_statuses + + +def filter_trials_by_status( + experiment: Experiment, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, +) -> set[int]: + """Filter trial indices by status inclusion/exclusion. + + Args: + experiment: The experiment to query. + only_in_statuses: If provided, only include trials with these statuses. + not_in_statuses: If provided, exclude trials with these statuses. + + Returns: + Set of trial indices matching the filter criteria. 
+ """ + trials_to_check = set(experiment.trials.keys()) + if only_in_statuses is not None: + trials_to_check = get_trials_by_status( + experiment=experiment, statuses=only_in_statuses + ) + if not_in_statuses is not None: + trials_to_check -= get_trials_by_status( + experiment=experiment, statuses=not_in_statuses + ) + return trials_to_check + + +def count_trials_toward_threshold( + experiment: Experiment, + trials_from_node: set[int], + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool = False, + count_only_trials_with_data: bool = False, +) -> int: + """Count trials contributing toward a threshold. + + Args: + experiment: The experiment to query. + trials_from_node: Set of trial indices generated by the current node. + only_in_statuses: If provided, only count trials with these statuses. + not_in_statuses: If provided, exclude trials with these statuses. + use_all_trials_in_exp: If True, count all trials in the experiment. + Otherwise, only count trials from the current node. + count_only_trials_with_data: If True, only count trials with data. + + Returns: + The number of trials contributing to the threshold. 
+ """ + all_trials_to_check = filter_trials_by_status( + experiment=experiment, + only_in_statuses=only_in_statuses, + not_in_statuses=not_in_statuses, + ) + if count_only_trials_with_data: + data_trial_indices = get_trial_indices_with_required_metrics( + experiment=experiment, + df=experiment.lookup_data().df, + require_data_for_all_metrics=False, + ) + all_trials_to_check = all_trials_to_check.intersection(data_trial_indices) + + if use_all_trials_in_exp: + return len(all_trials_to_check) + + return len(trials_from_node.intersection(all_trials_to_check)) + + +# ============================================================================ +# GenerationBlockingCriterion - for blocking generation without transitioning +# ============================================================================ + + +class GenerationBlockingCriterion(SortableBase): + """A criterion that blocks generation from a GenerationNode without triggering + a transition to another node. + """ + + @abstractmethod + def is_met( + self, + experiment: Experiment, + curr_node: GenerationNode, + ) -> bool: + """Returns True if this criterion's condition is met.""" + pass + + @abstractmethod + def block_continued_generation_error( + self, + node_name: str, + experiment: Experiment, + trials_from_node: set[int], + ) -> None: + """Raises an appropriate error when generation is blocked.""" + pass + + @property + def criterion_class(self) -> str: + """Name of the class of this GenerationBlockingCriterion.""" + return self.__class__.__name__ + + def __repr__(self) -> str: + return f"{self.criterion_class}({serialize_init_args(obj=self)})" + + @property + def _unique_id(self) -> str: + """Unique id for this GenerationBlockingCriterion.""" + return str(self) + + +class TrialCountBlockingCriterion(GenerationBlockingCriterion): + """Abstract base class for blocking criteria based on trial count thresholds. + + This class provides shared logic for blocking criteria that count trials toward + a threshold. 
Subclasses only need to implement `block_continued_generation_error()` + to define the specific error raised when generation is blocked. + + Args: + threshold: The maximum number of trials allowed before blocking generation. + only_in_statuses: A list of trial statuses to filter on when checking the + criterion threshold. + not_in_statuses: A list of trial statuses to exclude when checking the + criterion threshold. + use_all_trials_in_exp: A flag to use all trials in the experiment, instead of + only those generated by the current GenerationNode. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. + """ + + def __init__( + self, + threshold: int, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool = False, + count_only_trials_with_data: bool = False, + ) -> None: + self.threshold = threshold + self.only_in_statuses = only_in_statuses + self.not_in_statuses = not_in_statuses + self.use_all_trials_in_exp = use_all_trials_in_exp + self.count_only_trials_with_data = count_only_trials_with_data + + def num_contributing_to_threshold( + self, experiment: Experiment, trials_from_node: set[int] + ) -> int: + """Returns the number of trials contributing to the threshold.""" + return count_trials_toward_threshold( + experiment=experiment, + trials_from_node=trials_from_node, + only_in_statuses=self.only_in_statuses, + not_in_statuses=self.not_in_statuses, + use_all_trials_in_exp=self.use_all_trials_in_exp, + count_only_trials_with_data=self.count_only_trials_with_data, + ) + + def num_till_threshold( + self, experiment: Experiment, trials_from_node: set[int] + ) -> int: + """Returns the number of trials available before hitting the threshold.""" + return self.threshold - self.num_contributing_to_threshold( + experiment=experiment, trials_from_node=trials_from_node + ) + + def is_met( + self, + experiment: 
Experiment, + curr_node: GenerationNode, + ) -> bool: + """Returns True if the trial count threshold has been reached.""" + return ( + self.num_contributing_to_threshold( + experiment=experiment, trials_from_node=curr_node.trials_from_node + ) + >= self.threshold + ) + + @abstractmethod + def block_continued_generation_error( + self, + node_name: str, + experiment: Experiment, + trials_from_node: set[int], + ) -> None: + """Raises an appropriate error when generation is blocked. + + Subclasses must implement this to define the specific error behavior. + """ + pass + + +class MaxGenerationParallelism(TrialCountBlockingCriterion): + """A GenerationBlockingCriterion that blocks generation after a maximum number + of trials have been generated for the current GenerationNode and are currently + running. + + Args: + threshold: The maximum number of trials allowed in the specified statuses. + only_in_statuses: A list of trial statuses to filter on when checking the + criterion threshold. + not_in_statuses: A list of trial statuses to exclude when checking the + criterion threshold. + use_all_trials_in_exp: A flag to use all trials in the experiment, instead of + only those generated by the current GenerationNode. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. 
+ """ + + def __init__( + self, + threshold: int, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool = False, + count_only_trials_with_data: bool = False, + ) -> None: + super().__init__( + threshold=threshold, + only_in_statuses=only_in_statuses, + not_in_statuses=not_in_statuses, + use_all_trials_in_exp=use_all_trials_in_exp, + count_only_trials_with_data=count_only_trials_with_data, + ) + + def block_continued_generation_error( + self, + node_name: str, + experiment: Experiment, + trials_from_node: set[int], + ) -> None: + """Raises MaxParallelismReachedException.""" + raise MaxParallelismReachedException( + node_name=node_name, + num_running=self.num_contributing_to_threshold( + experiment=experiment, trials_from_node=trials_from_node + ), + ) + + +class MaxTrialsAwaitingData(TrialCountBlockingCriterion): + """A GenerationBlockingCriterion that blocks generation after a maximum number + of trials have been generated, waiting for data before allowing more generation. + + This criterion blocks generation from the associated GenerationNode when the + threshold is met, but does NOT trigger a transition to another node. Use this + when you want to enforce that a node generates at most a certain number of + trials before requiring data. + + Args: + threshold: The maximum number of trials allowed before blocking generation. + only_in_statuses: A list of trial statuses to filter on when checking the + criterion threshold. + not_in_statuses: A list of trial statuses to exclude when checking the + criterion threshold. Defaults to [FAILED, ABANDONED]. + use_all_trials_in_exp: A flag to use all trials in the experiment, instead of + only those generated by the current GenerationNode. Defaults to False. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False.
+ """ + + def __init__( + self, + threshold: int, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool = False, + count_only_trials_with_data: bool = False, + ) -> None: + if not_in_statuses is None: + not_in_statuses = [TrialStatus.FAILED, TrialStatus.ABANDONED] + super().__init__( + threshold=threshold, + only_in_statuses=only_in_statuses, + not_in_statuses=not_in_statuses, + use_all_trials_in_exp=use_all_trials_in_exp, + count_only_trials_with_data=count_only_trials_with_data, + ) + + def block_continued_generation_error( + self, + node_name: str, + experiment: Experiment, + trials_from_node: set[int], + ) -> None: + """Raises DataRequiredError when the trial threshold is reached.""" + raise DataRequiredError(DATA_REQUIRED_MSG.format(node_name=node_name)) + + +# ============================================================================ +# TransitionCriterion - for node transitions +# ============================================================================ + + class TransitionCriterion(SortableBase): """ Simple class to describe a condition which must be met for this GenerationNode to @@ -40,16 +359,6 @@ class TransitionCriterion(SortableBase): Args: transition_to: The name of the GenerationNode the GenerationStrategy should transition to when this criterion is met. - block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criterion - remain unmet. Ex: ``MinTrials`` has not been met yet, but - MinTrials has been reached. If this flag is set to true on MinTrials then - we will raise an error, otherwise we will continue to generate trials - until ``MinTrials`` is met (thus overriding MinTrials). - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. 
Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. continue_trial_generation: A flag to indicate that all generation for a given trial is not completed, and thus even after transition, the next node will continue to generate arms for the same trial. Example usage: in @@ -62,13 +371,9 @@ class TransitionCriterion(SortableBase): def __init__( self, transition_to: str, - block_transition_if_unmet: bool | None = True, - block_gen_if_met: bool | None = False, continue_trial_generation: bool | None = False, ) -> None: self._transition_to = transition_to - self.block_transition_if_unmet = block_transition_if_unmet - self.block_gen_if_met = block_gen_if_met self.continue_trial_generation = continue_trial_generation @property @@ -87,16 +392,6 @@ def is_met( """If the criterion of this TransitionCriterion is met, returns True.""" pass - @abstractmethod - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - """Error to be raised if the `block_gen_if_met` flag is set to True.""" - pass - @property def criterion_class(self) -> str: """Name of the class of this TransitionCriterion.""" @@ -117,10 +412,6 @@ class AutoTransitionAfterGen(TransitionCriterion): Args: transition_to: The name of the GenerationNode the GenerationStrategy should transition to next. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: This criterion defaults to - setting this to True to ensure we validate a GeneratorRun is generated by - the current GenerationNode. continue_trial_generation: A flag to indicate that all generation for a given trial is not completed, and thus even after transition, the next node will continue to generate arms for the same trial. 
Example usage: in @@ -131,12 +422,10 @@ class AutoTransitionAfterGen(TransitionCriterion): def __init__( self, transition_to: str, - block_transition_if_unmet: bool | None = True, continue_trial_generation: bool | None = True, ) -> None: super().__init__( transition_to=transition_to, - block_transition_if_unmet=block_transition_if_unmet, continue_trial_generation=continue_trial_generation, ) @@ -161,15 +450,6 @@ def is_met( else False ) - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - """Error to be raised if the `block_gen_if_met` flag is set to True.""" - pass - class IsSingleObjective(TransitionCriterion): """A class to initiate transition based on whether the experiment is optimizing @@ -178,10 +458,6 @@ class IsSingleObjective(TransitionCriterion): Args: transition_to: The name of the GenerationNode the GenerationStrategy should transition to next. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: This criterion defaults to - setting this to True to ensure we validate a GeneratorRun is generated by - the current GenerationNode. continue_trial_generation: A flag to indicate that all generation for a given trial is not completed, and thus even after transition, the next node will continue to generate arms for the same trial. 
Example usage: in @@ -192,12 +468,10 @@ class IsSingleObjective(TransitionCriterion): def __init__( self, transition_to: str, - block_transition_if_unmet: bool | None = True, continue_trial_generation: bool | None = False, ) -> None: super().__init__( transition_to=transition_to, - block_transition_if_unmet=block_transition_if_unmet, continue_trial_generation=continue_trial_generation, ) @@ -216,14 +490,6 @@ def is_met( else True ) - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - pass - class TrialBasedCriterion(TransitionCriterion): """Common class for transition criterion that are based on trial information. @@ -231,16 +497,6 @@ class TrialBasedCriterion(TransitionCriterion): Args: threshold: The threshold as an integer for this criterion. Ex: If we want to generate at most 3 trials, then the threshold is 3. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. - block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criterion - remain unmet. Ex: ``MinTrials`` has not been met yet, but - MinTrials has been reached. If this flag is set to true on MinTrials then - we will raise an error, otherwise we will continue to generate trials - until ``MinTrials`` is met (thus overriding MinTrials). only_in_statuses: A list of trial statuses to filter on when checking the criterion threshold. 
not_in_statuses: A list of trial statuses to exclude when checking the @@ -263,8 +519,6 @@ def __init__( self, threshold: int, transition_to: str, - block_transition_if_unmet: bool | None = True, - block_gen_if_met: bool | None = False, only_in_statuses: list[TrialStatus] | None = None, not_in_statuses: list[TrialStatus] | None = None, use_all_trials_in_exp: bool | None = False, @@ -278,72 +532,35 @@ def __init__( self.count_only_trials_with_data = count_only_trials_with_data super().__init__( transition_to=transition_to, - block_transition_if_unmet=block_transition_if_unmet, - block_gen_if_met=block_gen_if_met, continue_trial_generation=continue_trial_generation, ) def experiment_trials_by_status( self, experiment: Experiment, statuses: list[TrialStatus] ) -> set[int]: - """Get the trial indices from the entire experiment with the desired - statuses. - - Args: - experiment: The experiment associated with this GenerationStrategy. - statuses: The trial statuses to filter on. - Returns: - The trial indices in the experiment with the desired statuses. - """ - exp_trials_with_statuses = set() - for status in statuses: - exp_trials_with_statuses = exp_trials_with_statuses.union( - experiment.trial_indices_by_status[status] - ) - return exp_trials_with_statuses + """Get the trial indices from the experiment with the desired statuses.""" + return get_trials_by_status(experiment=experiment, statuses=statuses) def all_trials_to_check(self, experiment: Experiment) -> set[int]: - """All the trials to check from the entire experiment that meet - all the provided status filters. - - Args: - experiment: The experiment associated with this GenerationStrategy. 
- """ - trials_to_check = set(experiment.trials.keys()) - if self.only_in_statuses is not None: - trials_to_check = self.experiment_trials_by_status( - experiment=experiment, statuses=self.only_in_statuses - ) - # exclude the trials to those not in the specified statuses - if self.not_in_statuses is not None: - trials_to_check -= self.experiment_trials_by_status( - experiment=experiment, statuses=self.not_in_statuses - ) - return trials_to_check + """All the trials to check that meet the provided status filters.""" + return filter_trials_by_status( + experiment=experiment, + only_in_statuses=self.only_in_statuses, + not_in_statuses=self.not_in_statuses, + ) def num_contributing_to_threshold( self, experiment: Experiment, trials_from_node: set[int] ) -> int: - """Returns the number of trials contributing to the threshold. - - Args: - experiment: The experiment associated with this GenerationStrategy. - trials_from_node: The set of trials generated by this GenerationNode. - """ - all_trials_to_check = self.all_trials_to_check(experiment=experiment) - if self.count_only_trials_with_data: - data_trial_indices = get_trial_indices_with_required_metrics( - experiment=experiment, - df=experiment.lookup_data().df, - require_data_for_all_metrics=False, - ) - all_trials_to_check = all_trials_to_check.intersection(data_trial_indices) - # Some criteria may rely on experiment level data, instead of only trials - # generated from the node associated with the criterion. 
- if self.use_all_trials_in_exp: - return len(all_trials_to_check) - - return len(trials_from_node.intersection(all_trials_to_check)) + """Returns the number of trials contributing to the threshold.""" + return count_trials_toward_threshold( + experiment=experiment, + trials_from_node=trials_from_node, + only_in_statuses=self.only_in_statuses, + not_in_statuses=self.not_in_statuses, + use_all_trials_in_exp=bool(self.use_all_trials_in_exp), + count_only_trials_with_data=self.count_only_trials_with_data, + ) def num_till_threshold( self, experiment: Experiment, trials_from_node: set[int] @@ -382,93 +599,10 @@ def is_met( ) -class MaxGenerationParallelism(TrialBasedCriterion): - """Specific TransitionCriterion implementation which defines the maximum number - of trials that can simultaneously be in the designated trial statuses. The - default behavior is to block generation from the associated GenerationNode if the - threshold is met. This is configured via the `block_gen_if_met` flag being set to - True. This criterion defaults to not blocking transition to another node via the - `block_transition_if_unmet` flag being set to False. - - Args: - threshold: The threshold as an integer for this criterion. Ex: If we want to - generate at most 3 trials, then the threshold is 3. - only_in_statuses: A list of trial statuses to filter on when checking the - criterion threshold. - not_in_statuses: A list of trial statuses to exclude when checking the - criterion threshold. - transition_to: The name of the GenerationNode the GenerationStrategy should - transition to when this criterion is met, if it exists. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. 
- block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criterion - remain unmet. Ex: ``MinTrials`` has not been met yet, but - MinTrials has been reached. If this flag is set to true on MinTrials then - we will raise an error, otherwise we will continue to generate trials - until ``MinTrials`` is met (thus overriding MinTrials). - use_all_trials_in_exp: A flag to use all trials in the experiment, instead of - only those generated by the current GenerationNode. - continue_trial_generation: A flag to indicate that all generation for a given - trial is not completed, and thus even after transition, the next node will - continue to generate arms for the same trial. Example usage: in - ``BatchTrial``s we may enable generation of arms within a batch from - different ``GenerationNodes`` by setting this flag to True. Defaults to - False for MaxGenerationParallelism since this criterion isn't currently - used for node -> node or trial -> trial transition. - count_only_trials_with_data: If set to True, only trials with data will be - counted towards the ``threshold``. Defaults to False. 
- """ - - def __init__( - self, - threshold: int, - transition_to: str, - only_in_statuses: list[TrialStatus] | None = None, - not_in_statuses: list[TrialStatus] | None = None, - block_transition_if_unmet: bool | None = False, - block_gen_if_met: bool | None = True, - use_all_trials_in_exp: bool | None = False, - continue_trial_generation: bool | None = False, - ) -> None: - super().__init__( - threshold=threshold, - only_in_statuses=only_in_statuses, - not_in_statuses=not_in_statuses, - transition_to=transition_to, - block_gen_if_met=block_gen_if_met, - block_transition_if_unmet=block_transition_if_unmet, - use_all_trials_in_exp=use_all_trials_in_exp, - continue_trial_generation=continue_trial_generation, - ) - - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - """Raises the appropriate error (should only be called when the - ``GenerationNode`` is blocked from continued generation). For this - class, the exception is ``MaxParallelismReachedException``. - """ - assert self.block_gen_if_met # Sanity check. - raise MaxParallelismReachedException( - node_name=node_name, - num_running=self.num_contributing_to_threshold( - experiment=experiment, trials_from_node=trials_from_node - ), - ) - - class MinTrials(TrialBasedCriterion): """ Simple class to enforce a minimum threshold for the number of trials with the - designated statuses being generated by a specific GenerationNode. The default - behavior is to block transition to the next node if the threshold is unmet, but - not affect continued generation. + designated statuses being generated by a specific GenerationNode. Args: threshold: The threshold as an integer for this criterion. Ex: If we want to @@ -479,16 +613,6 @@ class MinTrials(TrialBasedCriterion): criterion threshold. transition_to: The name of the GenerationNode the GenerationStrategy should transition to when this criterion is met. 
- block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. - block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criterion - remain unmet. Ex: ``MinTrials`` has not been met yet, but - MinTrials has been reached. If this flag is set to true on MinTrials then - we will raise an error, otherwise we will continue to generate trials - until ``MinTrials`` is met (thus overriding MinTrials). use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. continue_trial_generation: A flag to indicate that all generation for a given @@ -506,8 +630,6 @@ def __init__( transition_to: str, only_in_statuses: list[TrialStatus] | None = None, not_in_statuses: list[TrialStatus] | None = None, - block_transition_if_unmet: bool | None = True, - block_gen_if_met: bool | None = False, use_all_trials_in_exp: bool | None = False, continue_trial_generation: bool | None = False, count_only_trials_with_data: bool = False, @@ -517,26 +639,11 @@ def __init__( transition_to=transition_to, only_in_statuses=only_in_statuses, not_in_statuses=not_in_statuses, - block_gen_if_met=block_gen_if_met, - block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, continue_trial_generation=continue_trial_generation, count_only_trials_with_data=count_only_trials_with_data, ) - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - """Raises the appropriate error (should only be called when the - ``GenerationNode`` is blocked from continued generation). For this - class, the exception is ``DataRequiredError``. 
- """ - assert self.block_gen_if_met # Sanity check. - raise DataRequiredError(DATA_REQUIRED_MSG.format(node_name=node_name)) - class AuxiliaryExperimentCheck(TransitionCriterion): """A class to transition from one GenerationNode to another by checking if certain @@ -563,14 +670,6 @@ class AuxiliaryExperimentCheck(TransitionCriterion): purpose we expect to not have. This can be helpful when need to transition out of a node based on AuxiliaryExperimentPurpose. Criterion is met when all inclusion and exclusion checks pass. - block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criteria - remain unmet. Defaults to False since auxiliary experiment checks are - typically used for transition logic rather than blocking generation. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. continue_trial_generation: A flag to indicate that all generation for a given trial is not completed, and thus even after transition, the next node will continue to generate arms for the same trial. 
Example usage: in @@ -587,14 +686,10 @@ def __init__( auxiliary_experiment_purposes_to_exclude: ( list[AuxiliaryExperimentPurpose] | None ) = None, - block_transition_if_unmet: bool | None = True, - block_gen_if_met: bool | None = False, continue_trial_generation: bool | None = False, ) -> None: super().__init__( transition_to=transition_to, - block_transition_if_unmet=block_transition_if_unmet, - block_gen_if_met=block_gen_if_met, continue_trial_generation=continue_trial_generation, ) @@ -651,11 +746,3 @@ def is_met( expected_aux_exp_purposes=self.auxiliary_experiment_purposes_to_exclude, ) return inclusion_check and exclusion_check - - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - pass diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 67f982ca1a1..d6e9d9f6fc1 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -159,12 +159,15 @@ class TestAxOrchestrator(TestCase): "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " "transition_criteria=[MinTrials(transition_to='GenerationStep_1_BoTorch'), " - "MinTrials(transition_to='GenerationStep_1_BoTorch')]), " + "MinTrials(transition_to='GenerationStep_1_BoTorch')], " + "generation_blocking_criteria=" + "[MaxTrialsAwaitingData(threshold=5)]), " "GenerationNode(name='GenerationStep_1_BoTorch', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MaxGenerationParallelism(" - "transition_to='GenerationStep_1_BoTorch')])]), " + "transition_criteria=None, " + "generation_blocking_criteria=" + "[MaxGenerationParallelism(threshold=3)])]), " "options=OrchestratorOptions(max_pending_trials=10, " "trial_type=, batch_size=None, " "total_trials=0, tolerated_trial_failure_rate=0.2, " @@ -1168,11 +1171,11 @@ def 
test_run_trials_and_yield_results_with_early_stopper(self) -> None: expected_num_polls = 2 self.assertEqual(len(res_list), expected_num_polls + 1) # Both trials in first batch of parallelism will be early stopped - # Extract max_parallelism from transition criteria + # Extract max_parallelism from generation_blocking_criteria node0_max_parallelism = None - for tc in self.two_sobol_steps_GS._nodes[0].transition_criteria: - if isinstance(tc, MaxGenerationParallelism): - node0_max_parallelism = tc.threshold + for bc in self.two_sobol_steps_GS._nodes[0].generation_blocking_criteria: + if isinstance(bc, MaxGenerationParallelism): + node0_max_parallelism = bc.threshold break self.assertEqual( len(res_list[0]["trials_early_stopped_so_far"]), @@ -2852,12 +2855,15 @@ class TestAxOrchestratorMultiTypeExperiment(TestAxOrchestrator): "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " "transition_criteria=[MinTrials(transition_to='GenerationStep_1_BoTorch'), " - "MinTrials(transition_to='GenerationStep_1_BoTorch')]), " + "MinTrials(transition_to='GenerationStep_1_BoTorch')], " + "generation_blocking_criteria=" + "[MaxTrialsAwaitingData(threshold=5)]), " "GenerationNode(name='GenerationStep_1_BoTorch', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=" - "[MaxGenerationParallelism(transition_to='GenerationStep_1_BoTorch')])]), " + "transition_criteria=None, " + "generation_blocking_criteria=" + "[MaxGenerationParallelism(threshold=3)])]), " "options=OrchestratorOptions(max_pending_trials=10, " "trial_type=, batch_size=None, " "total_trials=0, tolerated_trial_failure_rate=0.2, " diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 47a8f54ad36..7023692d495 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -865,11 +865,11 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: """ parallelism_settings = [] for node in 
self.generation_strategy._nodes: - # Extract max_parallelism from MaxGenerationParallelism criterion + # Check generation_blocking_criteria for max parallelism max_parallelism = None - for tc in node.transition_criteria: - if isinstance(tc, MaxGenerationParallelism): - max_parallelism = tc.threshold + for bc in node.generation_blocking_criteria: + if isinstance(bc, MaxGenerationParallelism): + max_parallelism = bc.threshold break # Try to get num_trials from the node. If there's no MinTrials # criterion (unlimited trials), num_trials will raise UserInputError. diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index afe81f05d01..b9334d53a63 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -59,7 +59,7 @@ from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( MaxGenerationParallelism, - MinTrials, + MaxTrialsAwaitingData, ) from ax.metrics.branin import branin, BraninMetric from ax.runners.synthetic import SyntheticRunner @@ -1606,22 +1606,25 @@ def test_keep_generating_without_data(self) -> None: {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], ) - # Check that enforce_num_trials is False by checking the MinTrials criterion - # has block_gen_if_met=False - node0_min_trials = [ - tc - for tc in ax_client.generation_strategy._nodes[0].transition_criteria - if isinstance(tc, MinTrials) + # Check that enforce_num_trials is False by verifying no + # MaxTrialsAwaitingData exists in generation_blocking_criteria + node0_blocking_criteria = [ + bc + for bc in ax_client.generation_strategy._nodes[ + 0 + ].generation_blocking_criteria + if isinstance(bc, MaxTrialsAwaitingData) ] - self.assertTrue(len(node0_min_trials) > 0) - self.assertFalse(node0_min_trials[0].block_gen_if_met) + self.assertEqual(len(node0_blocking_criteria), 0) # Check that max_parallelism is None by verifying no MaxGenerationParallelism - # criterion exists 
on node 1 + # criterion exists in generation_blocking_criteria node1_max_parallelism = [ - tc - for tc in ax_client.generation_strategy._nodes[1].transition_criteria - if isinstance(tc, MaxGenerationParallelism) + bc + for bc in ax_client.generation_strategy._nodes[ + 1 + ].generation_blocking_criteria + if isinstance(bc, MaxGenerationParallelism) ] self.assertEqual(len(node1_max_parallelism), 0) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..c08beca8742 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -16,7 +16,7 @@ from inspect import isclass from io import StringIO from logging import Logger -from typing import Any +from typing import Any, TypeVar import numpy as np import pandas as pd @@ -43,7 +43,11 @@ GenerationStrategy, ) from ax.generation_strategy.generator_spec import GeneratorSpec -from ax.generation_strategy.transition_criterion import MinTrials, TransitionCriterion +from ax.generation_strategy.transition_criterion import ( + GenerationBlockingCriterion, + MinTrials, + TransitionCriterion, +) from ax.generators.torch.botorch_modular.generator import BoTorchGenerator from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.generators.torch.botorch_modular.utils import ModelConfig @@ -70,7 +74,7 @@ from botorch.utils.types import DEFAULT from pyre_extensions import assert_is_instance, none_throws - +T = TypeVar("T") logger: Logger = get_logger(__name__) @@ -297,11 +301,7 @@ def object_from_json( object_json["outcome_transform_options"] = ( outcome_transform_options_json ) - elif ( - isclass(_class) - and issubclass(_class, TransitionCriterion) - and _class is not TransitionCriterion # TransitionCriterion is abstract - ): + elif isclass(_class) and issubclass(_class, TransitionCriterion): # TransitionCriterion may contain nested Ax objects (TrialStatus, etc.) # that need recursive deserialization via object_from_json. 
return transition_criterion_from_json( @@ -309,6 +309,13 @@ def object_from_json( object_json=object_json, **vars(registry_kwargs), ) + elif isclass(_class) and issubclass(_class, GenerationBlockingCriterion): + # GenerationBlockingCriterion is similar to TransitionCriterion + return generation_blocking_criterion_from_json( + blocking_criterion_class=_class, + object_json=object_json, + **vars(registry_kwargs), + ) elif isclass(_class) and issubclass(_class, SerializationMixin): # Special handling for Data backward compatibility if _class is Data: @@ -447,19 +454,37 @@ def generator_run_from_json( return generator_run +def _criterion_from_json( + criterion_class: type[T], + object_json: dict[str, Any], + decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, + class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, +) -> T: + """Generic helper to load criterion objects from JSON. + + Handles recursive deserialization of nested Ax objects and filters + to valid constructor arguments for backwards compatibility. + """ + decoded = { + key: object_from_json( + object_json=value, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + for key, value in object_json.items() + } + init_args = extract_init_args(args=decoded, class_=criterion_class) + # pyre-ignore[45]: Class passed is always a concrete subclass. + return criterion_class(**init_args) + + def transition_criterion_from_json( transition_criterion_class: type[TransitionCriterion], object_json: dict[str, Any], decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, ) -> TransitionCriterion: - """Load TransitionCriterion from JSON. - - TransitionCriterion subclasses may contain nested Ax objects (like TrialStatus - enums and AuxiliaryExperimentPurpose) that need recursive deserialization via - object_from_json. 
We also use extract_init_args for backwards compatibility, - filtering to only valid constructor arguments. - """ + """Load TransitionCriterion from JSON.""" # Handle deprecated MinimumTrialsInStatus -> MinTrials conversion if transition_criterion_class is MinTrials and "status" in object_json: logger.warning( @@ -478,20 +503,27 @@ def transition_criterion_from_json( use_all_trials_in_exp=True, ) - decoded = { - key: object_from_json( - object_json=value, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for key, value in object_json.items() - } + return _criterion_from_json( + criterion_class=transition_criterion_class, + object_json=object_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) - # filter to only valid constructor args (backwards compatibility) - init_args = extract_init_args(args=decoded, class_=transition_criterion_class) - # pyre-ignore[45]: Class passed is always a concrete subclass. - return transition_criterion_class(**init_args) +def generation_blocking_criterion_from_json( + blocking_criterion_class: type[GenerationBlockingCriterion], + object_json: dict[str, Any], + decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, + class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, +) -> GenerationBlockingCriterion: + """Load GenerationBlockingCriterion from JSON.""" + return _criterion_from_json( + criterion_class=blocking_criterion_class, + object_json=object_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) def search_space_from_json( @@ -848,13 +880,47 @@ def generation_node_from_json( # if needed during _validate_and_set_step_sequence. generation_node_json.pop("step_index", None) - # Backwards compatibility: For transition criteria with transition_to=None - # set transition_to to point to itself. 
- transition_criteria_json = generation_node_json.pop("transition_criteria") + transition_criteria_json = generation_node_json.pop("transition_criteria", None) + generation_blocking_criteria_json = generation_node_json.pop( + "generation_blocking_criteria", None + ) + + # Backwards compatibility: For old experiments, TransitionCriterion with + # block_gen_if_met=True need to be migrated to GenerationBlockingCriterion. + # Also, transition_to=None needs to be handled for transition criteria, these + # now point to self. if transition_criteria_json is not None: + migrated_blocking_criteria = [] + migrated_transition_criteria_json = [] for tc_json in transition_criteria_json: + tc_type = tc_json.get("__type", "") + + if tc_json.get("block_gen_if_met") is True: + if tc_type == "MaxGenerationParallelism": + migrated_blocking_criteria.append(tc_json) + elif tc_type == "MinTrials": + # Copy and update type + blocking_json = tc_json.copy() + blocking_json["__type"] = "MaxTrialsAwaitingData" + migrated_blocking_criteria.append(blocking_json) + + # Only keep in transition_criteria if block_transition_if_unmet=True + if tc_json.get("block_transition_if_unmet") is not True: + continue # Skip adding to transition_criteria + + # handle transition_to=None if tc_json.get("transition_to") is None: tc_json["transition_to"] = name + migrated_transition_criteria_json.append(tc_json) + + # Merge migrated criteria with any existing generation_blocking_criteria + if generation_blocking_criteria_json is None: + generation_blocking_criteria_json = migrated_blocking_criteria + else: + generation_blocking_criteria_json = ( + migrated_blocking_criteria + generation_blocking_criteria_json + ) + transition_criteria_json = migrated_transition_criteria_json return GenerationNode( name=name, @@ -874,6 +940,11 @@ def generation_node_from_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ), + generation_blocking_criteria=object_from_json( + 
object_json=generation_blocking_criteria_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ), input_constructors=decoded_input_constructors, previous_node_name=( generation_node_json.pop("previous_node_name") diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index bfd6157f129..ca7267cb02c 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -50,7 +50,10 @@ GenerationStrategy, ) from ax.generation_strategy.generator_spec import GeneratorSpec -from ax.generation_strategy.transition_criterion import TransitionCriterion +from ax.generation_strategy.transition_criterion import ( + GenerationBlockingCriterion, + TransitionCriterion, +) from ax.generators.torch.botorch_modular.generator import BoTorchGenerator from ax.generators.torch.botorch_modular.surrogate import Surrogate from ax.generators.winsorization_config import WinsorizationConfig @@ -408,6 +411,7 @@ def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]: "best_model_selector": generation_node.best_model_selector, "should_deduplicate": generation_node.should_deduplicate, "transition_criteria": generation_node.transition_criteria, + "generation_blocking_criteria": generation_node.generation_blocking_criteria, "generator_spec_to_gen_from": generation_node._generator_spec_to_gen_from, "previous_node_name": generation_node._previous_node_name, "trial_type": generation_node._trial_type, @@ -445,6 +449,15 @@ def transition_criterion_to_dict(criterion: TransitionCriterion) -> dict[str, An return properties +def generation_blocking_criterion_to_dict( + criterion: GenerationBlockingCriterion, +) -> dict[str, Any]: + """Convert Ax GenerationBlockingCriterion to a dictionary.""" + properties = serialize_init_args(obj=criterion) + properties["__type"] = criterion.__class__.__name__ + return properties + + def generator_spec_to_dict(generator_spec: GeneratorSpec) -> dict[str, Any]: 
"""Convert Ax model spec to a dictionary.""" return { diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 842a90968be..561a4d09c49 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -77,6 +77,7 @@ AuxiliaryExperimentCheck, IsSingleObjective, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, TransitionCriterion, ) @@ -125,6 +126,7 @@ derived_parameter_to_dict, experiment_to_dict, fixed_parameter_to_dict, + generation_blocking_criterion_to_dict, generation_node_to_dict, generation_strategy_to_dict, generator_run_to_dict, @@ -214,7 +216,8 @@ L2NormMetric: metric_to_dict, LogNormalPrior: botorch_component_to_dict, MapMetric: metric_to_dict, - MaxGenerationParallelism: transition_criterion_to_dict, + MaxGenerationParallelism: generation_blocking_criterion_to_dict, + MaxTrialsAwaitingData: generation_blocking_criterion_to_dict, Metric: metric_to_dict, MinTrials: transition_criterion_to_dict, AuxiliaryExperimentCheck: transition_criterion_to_dict, @@ -341,6 +344,7 @@ "MapMetric": MapMetric, "MaxTrials": MinTrials, "MaxGenerationParallelism": MaxGenerationParallelism, + "MaxTrialsAwaitingData": MaxTrialsAwaitingData, "Metric": Metric, "MinTrials": MinTrials, # DEPRECATED; backward compatibility for MinimumTrialsInStatus -> MinTrials diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index c36f0fb5b5e..ebf908e9fdf 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -1236,6 +1236,130 @@ def test_generation_node_backwards_compatibility(self) -> None: ) self.assertEqual(node.generator_specs[0].cv_kwargs, {"test_cv_kwarg": True}) + def test_block_gen_if_met_migration(self) -> None: + """Test that TransitionCriteria with block_gen_if_met=True are migrated + to GenerationBlockingCriterion during deserialization.""" + with 
self.subTest("MaxGenerationParallelism_with_block_gen_if_met"): + # MaxGenerationParallelism with block_gen_if_met=True should be + # migrated to generation_blocking_criteria and removed from + # transition_criteria + json = { + "node_name": "test_node", + "model_specs": [ + { + "__type": "GeneratorSpec", + "model_enum": {"__type": "Generators", "name": "SOBOL"}, + "model_kwargs": {}, + "model_gen_kwargs": {}, + "model_cv_kwargs": {}, + } + ], + "best_model_selector": None, + "should_deduplicate": False, + "transition_criteria": [ + { + "__type": "MaxGenerationParallelism", + "threshold": 3, + "only_in_statuses": [ + {"__type": "TrialStatus", "name": "RUNNING"} + ], + "block_gen_if_met": True, + "transition_to": None, + } + ], + } + node = generation_node_from_json(json) + self.assertEqual(len(node.transition_criteria), 0) + self.assertEqual(len(node.generation_blocking_criteria), 1) + blocking = node.generation_blocking_criteria[0] + self.assertEqual(blocking.__class__.__name__, "MaxGenerationParallelism") + # pyre-ignore[16]: Attribute exists on MaxGenerationParallelism + self.assertEqual(blocking.threshold, 3) + + with self.subTest("MinTrials_with_block_gen_if_met_only"): + # MinTrials with block_gen_if_met=True only should be migrated to + # MaxTrialsAwaitingData and removed from transition_criteria + json = { + "node_name": "test_node", + "model_specs": [ + { + "__type": "GeneratorSpec", + "model_enum": {"__type": "Generators", "name": "SOBOL"}, + "model_kwargs": {}, + "model_gen_kwargs": {}, + "model_cv_kwargs": {}, + } + ], + "best_model_selector": None, + "should_deduplicate": False, + "transition_criteria": [ + { + "__type": "MinTrials", + "threshold": 5, + "only_in_statuses": [ + {"__type": "TrialStatus", "name": "RUNNING"} + ], + "not_in_statuses": None, + "use_all_trials_in_exp": False, + "block_gen_if_met": True, + "block_transition_if_unmet": False, + "transition_to": None, + } + ], + } + node = generation_node_from_json(json) + 
self.assertEqual(len(node.transition_criteria), 0) + self.assertEqual(len(node.generation_blocking_criteria), 1) + blocking = node.generation_blocking_criteria[0] + self.assertEqual(blocking.__class__.__name__, "MaxTrialsAwaitingData") + self.assertEqual(blocking.threshold, 5) + + with self.subTest("MinTrials_with_block_gen_if_met_and_block_transition"): + # MinTrials with both block_gen_if_met=True and + # block_transition_if_unmet=True should create + # MaxTrialsAwaitingData AND keep in transition_criteria + json = { + "node_name": "test_node", + "model_specs": [ + { + "__type": "GeneratorSpec", + "model_enum": {"__type": "Generators", "name": "SOBOL"}, + "model_kwargs": {}, + "model_gen_kwargs": {}, + "model_cv_kwargs": {}, + } + ], + "best_model_selector": None, + "should_deduplicate": False, + "transition_criteria": [ + { + "__type": "MinTrials", + "threshold": 5, + "only_in_statuses": None, + "not_in_statuses": None, + "use_all_trials_in_exp": True, + "block_gen_if_met": True, + "block_transition_if_unmet": True, + "transition_to": "next_node", + } + ], + } + node = generation_node_from_json(json) + # Should have both + self.assertEqual(len(node.transition_criteria), 1) + self.assertEqual(len(node.generation_blocking_criteria), 1) + tc = node.transition_criteria[0] + self.assertEqual(tc.__class__.__name__, "MinTrials") + # pyre-ignore[16]: Attribute exists on MinTrials + self.assertEqual(tc.threshold, 5) + self.assertEqual(tc.transition_to, "next_node") + blocking = node.generation_blocking_criteria[0] + self.assertEqual(blocking.__class__.__name__, "MaxTrialsAwaitingData") + # pyre-ignore[16]: threshold exists on MaxTrialsAwaitingData + self.assertEqual(blocking.threshold, 5) + # pyre-ignore[16]: use_all_trials_in_exp exists on MaxTrialsAwaitingData + self.assertTrue(blocking.use_all_trials_in_exp) + def test_SobolQMCNormalSampler(self) -> None: # This fails default equality checks, so testing it separately. 
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index d0a9df77ce4..4e277da9d2a 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -55,13 +55,17 @@ ) from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy -from ax.generation_strategy.transition_criterion import MaxGenerationParallelism +from ax.generation_strategy.transition_criterion import ( + MaxGenerationParallelism, + MaxTrialsAwaitingData, + MinTrials, +) from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.metrics.branin import BraninMetric from ax.runners.synthetic import SyntheticRunner from ax.storage.json_store.decoder import ( + generation_blocking_criterion_from_json, generation_node_from_json, - transition_criterion_from_json, ) from ax.storage.json_store.registry import ( CORE_CLASS_DECODER_REGISTRY, @@ -3397,7 +3401,10 @@ def test_load_candidate_source_auxiliary_experiments(self) -> None: def test_transition_criterion_deserialize_with_extra_fields(self) -> None: """Test that deserialization gracefully handles extra/unknown fields - ie this validates that backwards compatibility is maintained""" + i.e., this validates that backwards compatibility is maintained for + MaxGenerationParallelism, which is now a GenerationBlockingCriterion. + Old serialized experiments may have transition_to, block_gen_if_met, etc. 
+ fields that are now deprecated and should be ignored.""" # Simulate old serialized format with extra fields that no longer exist old_format_json = { "threshold": 5, @@ -3411,10 +3418,12 @@ def test_transition_criterion_deserialize_with_extra_fields(self) -> None: "some_deprecated_field": "should_be_ignored", } - # Should not raise, extra field should be ignored + # Should not raise, extra fields should be ignored. + # Note: MaxGenerationParallelism is now a GenerationBlockingCriterion, + # so we use generation_blocking_criterion_from_json. criterion = assert_is_instance( - transition_criterion_from_json( - transition_criterion_class=MaxGenerationParallelism, + generation_blocking_criterion_from_json( + blocking_criterion_class=MaxGenerationParallelism, object_json=old_format_json, decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, @@ -3422,13 +3431,13 @@ def test_transition_criterion_deserialize_with_extra_fields(self) -> None: MaxGenerationParallelism, ) self.assertEqual(criterion.threshold, 5) - self.assertEqual(criterion.transition_to, "test_node") def test_gen_node_deserialize_with_tc_transition_to_none( self, ) -> None: - """Test backwards compatibility when loading a MaxGenerationParallelism - that was stored with transition_to=None + """Test backwards compatibility when loading an old GenerationNode that + has MaxGenerationParallelism stored in transition_criteria. The decoder + should automatically migrate it to generation_blocking_criteria. 
""" old_format_node_json = { "__type": "GenerationNode", @@ -3446,7 +3455,8 @@ def test_gen_node_deserialize_with_tc_transition_to_none( "__type": "MaxGenerationParallelism", "threshold": 3, "only_in_statuses": [{"__type": "TrialStatus", "name": "RUNNING"}], - "transition_to": None, # Old default + "block_gen_if_met": True, + "transition_to": None, # Old default - should be ignored } ], } @@ -3457,11 +3467,65 @@ def test_gen_node_deserialize_with_tc_transition_to_none( class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) self.assertEqual(node.name, "test_node") - self.assertEqual(len(node.transition_criteria), 1) + # MaxGenerationParallelism should have been migrated from + # transition_criteria to generation_blocking_criteria + self.assertEqual(len(node.transition_criteria), 0) + self.assertEqual(len(node.generation_blocking_criteria), 1) criterion = assert_is_instance( - node.transition_criteria[0], + node.generation_blocking_criteria[0], MaxGenerationParallelism, ) self.assertEqual(criterion.threshold, 3) - # transition_to should now be set to the node name (pointing to itself) - self.assertEqual(criterion.transition_to, "test_node") + + def test_block_gen_if_met_mintrials_migration(self) -> None: + """Test backwards compatibility when loading an old GenerationNode that + has MinTrials with block_gen_if_met=True. The decoder should + automatically migrate it to MaxTrialsAwaitingData in + generation_blocking_criteria. 
+ """ + # MinTrials with both block_gen_if_met=True and + # block_transition_if_unmet=True should create MaxTrialsAwaitingData + # AND keep in transition_criteria + old_format_node_json = { + "__type": "GenerationNode", + "name": "test_node", + "generator_specs": [ + { + "__type": "GeneratorSpec", + "generator_enum": {"__type": "Generators", "name": "SOBOL"}, + "generator_kwargs": {}, + "generator_gen_kwargs": {}, + } + ], + "transition_criteria": [ + { + "__type": "MinTrials", + "threshold": 5, + "only_in_statuses": None, + "not_in_statuses": None, + "use_all_trials_in_exp": True, + "block_gen_if_met": True, + "block_transition_if_unmet": True, + "transition_to": "next_node", + } + ], + } + + node = generation_node_from_json( + generation_node_json=old_format_node_json, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + self.assertEqual(node.name, "test_node") + # Should have both, one to represent the blocking criterion and + # one to represent the transition criterion + self.assertEqual(len(node.transition_criteria), 1) + self.assertEqual(len(node.generation_blocking_criteria), 1) + tc = assert_is_instance(node.transition_criteria[0], MinTrials) + self.assertEqual(tc.threshold, 5) + self.assertEqual(tc.transition_to, "next_node") + blocking = assert_is_instance( + node.generation_blocking_criteria[0], MaxTrialsAwaitingData + ) + self.assertEqual(blocking.threshold, 5) + self.assertTrue(blocking.use_all_trials_in_exp) diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index c28bfad5399..9535e1b5747 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -88,7 +88,9 @@ from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, + GenerationBlockingCriterion, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, TransitionCriterion, ) @@ -179,6 +181,7 @@ def 
get_experiment_with_map_data_type() -> Experiment: def get_trial_based_criterion() -> list[TransitionCriterion]: + """Returns a list of trial-based TransitionCriteria for testing.""" return [ MinTrials( threshold=3, @@ -186,16 +189,21 @@ def get_trial_based_criterion() -> list[TransitionCriterion]: only_in_statuses=[TrialStatus.RUNNING, TrialStatus.COMPLETED], not_in_statuses=None, ), + AutoTransitionAfterGen( + transition_to="next_node", + ), + ] + + +def get_generation_blocking_criterion() -> list[GenerationBlockingCriterion]: + """Returns a list of GenerationBlockingCriteria for testing.""" + return [ MaxGenerationParallelism( threshold=5, only_in_statuses=None, not_in_statuses=[ TrialStatus.RUNNING, ], - transition_to="Sobol", - ), - AutoTransitionAfterGen( - transition_to="next_node", ), ] @@ -2960,14 +2968,12 @@ def get_online_sobol_mbm_generation_strategy() -> GenerationStrategy: MinTrials( threshold=1, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), MinTrials( threshold=1, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=[ TrialStatus.RUNNING, TrialStatus.COMPLETED, @@ -2985,9 +2991,25 @@ def get_online_sobol_mbm_generation_strategy() -> GenerationStrategy: generator_kwargs=step_generator_kwargs, generator_gen_kwargs={}, ) + sobol_blocking_criteria = [ + MaxTrialsAwaitingData( + threshold=1, + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ), + MaxTrialsAwaitingData( + threshold=1, + only_in_statuses=[ + TrialStatus.RUNNING, + TrialStatus.COMPLETED, + TrialStatus.EARLY_STOPPED, + ], + ), + ] sobol_node = GenerationNode( name="sobol_node", transition_criteria=sobol_criterion, + generation_blocking_criteria=sobol_blocking_criteria, generator_specs=[sobol_generator_spec], input_constructors={InputConstructorPurpose.N: NodeInputConstructors.ALL_N}, ) diff --git a/ax/utils/testing/modeling_stubs.py 
b/ax/utils/testing/modeling_stubs.py index 6e11278a8b6..75f7a73407b 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -41,6 +41,7 @@ AutoTransitionAfterGen, IsSingleObjective, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, ) from ax.generators.torch.botorch_modular.surrogate import ( @@ -184,7 +185,6 @@ def sobol_gpei_generation_node_gs( MinTrials( threshold=5, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) @@ -195,7 +195,6 @@ def sobol_gpei_generation_node_gs( MinTrials( threshold=2, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), @@ -204,17 +203,20 @@ def sobol_gpei_generation_node_gs( MinTrials( threshold=0, transition_to="MBM_node", - block_gen_if_met=False, only_in_statuses=[TrialStatus.CANDIDATE], not_in_statuses=None, ), + ] + sobol_blocking_criteria = [ + MaxTrialsAwaitingData( + threshold=5, + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ), MaxGenerationParallelism( threshold=1000, - transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=[TrialStatus.RUNNING], not_in_statuses=None, - continue_trial_generation=False, ), ] auto_mbm_criterion = [AutoTransitionAfterGen(transition_to="MBM_node")] @@ -235,6 +237,7 @@ def sobol_gpei_generation_node_gs( sobol_node = GenerationNode( name="sobol_node", transition_criteria=sobol_criterion, + generation_blocking_criteria=sobol_blocking_criteria, generator_specs=[sobol_generator_spec], ) if with_model_selection: diff --git a/tutorials/external_generation_node/external_generation_node.ipynb b/tutorials/external_generation_node/external_generation_node.ipynb index 928a551a921..8e5933086e8 100644 --- a/tutorials/external_generation_node/external_generation_node.ipynb +++ 
b/tutorials/external_generation_node/external_generation_node.ipynb @@ -47,7 +47,7 @@ "from ax.generation_strategy.generation_node import GenerationNode\n", "from ax.generation_strategy.generation_strategy import GenerationStrategy\n", "from ax.generation_strategy.generator_spec import GeneratorSpec\n", - "from ax.generation_strategy.transition_criterion import MinTrials\n", + "from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials\n", "\n", "from sklearn.ensemble import RandomForestRegressor" ] @@ -231,10 +231,16 @@ " # This specifies the maximum number of trials to generate from this node,\n", " # and the next node in the strategy.\n", " threshold=25,\n", - " block_transition_if_unmet=True,\n", " transition_to=\"RandomForest\",\n", " )\n", " ],\n", + " generation_blocking_criteria=[\n", + " MaxTrialsAwaitingData(\n", + " # Block generation once 25 trials exist that are not FAILED or ABANDONED.\n", + " threshold=25,\n", + " not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],\n", + " )\n", + " ],\n", " ),\n", " RandomForestGenerationNode(num_samples=128, regressor_options={}),\n", " ],\n", @@ -388,6 +394,8 @@ ], "metadata": { "fileHeader": "", + "fileUid": "0712ebfb-bf0d-43ad-9aec-e0447e101ece", + "isAdHoc": false, "kernelspec": { "display_name": "python3", "language": "python",