diff --git a/ax/api/utils/generation_strategy_dispatch.py b/ax/api/utils/generation_strategy_dispatch.py index f3f60d2fa35..91c1f9576a2 100644 --- a/ax/api/utils/generation_strategy_dispatch.py +++ b/ax/api/utils/generation_strategy_dispatch.py @@ -21,7 +21,7 @@ GenerationStrategy, ) from ax.generation_strategy.generator_spec import GeneratorSpec -from ax.generation_strategy.transition_criterion import MinTrials +from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP @@ -70,21 +70,28 @@ def _get_sobol_node( MinTrials( # This represents the initialization budget. threshold=initialization_budget, transition_to="MBM", - block_gen_if_met=(not allow_exceeding_initialization_budget), - block_transition_if_unmet=True, use_all_trials_in_exp=use_existing_trials_for_initialization, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), MinTrials( # This represents minimum observed trials requirement. threshold=min_observed_initialization_trials, transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=True, only_in_statuses=[TrialStatus.COMPLETED], count_only_trials_with_data=True, ), ] + # If we want to enforce the initialization budget, add a generation blocking + # criterion that prevents exceeding the budget. 
+ generation_blocking_criteria = None + if not allow_exceeding_initialization_budget: + generation_blocking_criteria = [ + MaxTrialsAwaitingData( + threshold=initialization_budget, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + use_all_trials_in_exp=use_existing_trials_for_initialization, + ) + ] return GenerationNode( name="Sobol", generator_specs=[ @@ -94,6 +101,7 @@ def _get_sobol_node( ) ], transition_criteria=transition_criteria, + generation_blocking_criteria=generation_blocking_criteria, should_deduplicate=True, ) diff --git a/ax/api/utils/tests/test_generation_strategy_dispatch.py b/ax/api/utils/tests/test_generation_strategy_dispatch.py index 958eccc0a46..9e844acb660 100644 --- a/ax/api/utils/tests/test_generation_strategy_dispatch.py +++ b/ax/api/utils/tests/test_generation_strategy_dispatch.py @@ -19,7 +19,7 @@ from ax.exceptions.core import UserInputError from ax.generation_strategy.center_generation_node import CenterGenerationNode from ax.generation_strategy.dispatch_utils import get_derelativize_config -from ax.generation_strategy.transition_criterion import MinTrials +from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -99,16 +99,12 @@ def test_choose_gs_fast_with_options(self) -> None: MinTrials( threshold=2, transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=False, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), MinTrials( threshold=4, transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=True, only_in_statuses=[TrialStatus.COMPLETED], count_only_trials_with_data=True, @@ -390,7 +386,18 @@ def test_abandoned_and_failed_trials_excluded_from_initialization_budget( first_tc.not_in_statuses, 
[TrialStatus.FAILED, TrialStatus.ABANDONED] ) self.assertEqual(first_tc.threshold, 5) - self.assertTrue(first_tc.block_gen_if_met) + # Verify MaxTrialsAwaitingData is in generation_blocking_criteria + blocking_criteria = [ + bc + for bc in sobol_node._generation_blocking_criteria + if isinstance(bc, MaxTrialsAwaitingData) + ] + self.assertEqual(len(blocking_criteria), 1) + self.assertEqual(blocking_criteria[0].threshold, 5) + self.assertEqual( + blocking_criteria[0].not_in_statuses, + [TrialStatus.FAILED, TrialStatus.ABANDONED], + ) # Test the actual behavior: Generate 5 trials, mark 3 as ABANDONED, # verify that Sobol can still generate more trials diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index a3efebc8cd7..aa3c4237ef4 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -41,11 +41,12 @@ ) from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( - AutoTransitionAfterGen, + GenerationBlockingCriterion, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, TransitionCriterion, - TrialBasedCriterion, + TrialCountBlockingCriterion, ) from ax.utils.common.base import SortableBase from ax.utils.common.constants import Keys @@ -131,6 +132,7 @@ class GenerationNode(SerializationMixin, SortableBase): _generator_spec_to_gen_from: GeneratorSpec | None = None # TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping? 
_transition_criteria: Sequence[TransitionCriterion] + _generation_blocking_criteria: Sequence[GenerationBlockingCriterion] _input_constructors: TInputConstructorsByPurpose _previous_node_name: str | None = None _trial_type: str | None = None @@ -150,6 +152,8 @@ def __init__( name: str, generator_specs: list[GeneratorSpec], transition_criteria: Sequence[TransitionCriterion] | None = None, + generation_blocking_criteria: Sequence[GenerationBlockingCriterion] + | None = None, best_model_selector: BestModelSelector | None = None, should_deduplicate: bool = False, input_constructors: TInputConstructorsByPurpose | None = None, @@ -184,6 +188,7 @@ def __init__( self.best_model_selector = best_model_selector self.should_deduplicate = should_deduplicate self._transition_criteria = transition_criteria or [] + self._generation_blocking_criteria = generation_blocking_criteria or [] self._input_constructors = input_constructors or {} self._previous_node_name = previous_node_name self._trial_type = trial_type @@ -191,6 +196,10 @@ def __init__( self.fallback_specs = ( fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK ) + # Cache for trials_from_node property to avoid recomputation + # on every access. Invalidated when trial count changes. + self._trials_from_node_cache: set[int] | None = None + self._cached_trial_count: int = -1 @property def name(self) -> str: @@ -237,6 +246,17 @@ def transition_criteria(self) -> Sequence[TransitionCriterion]: """ return [] if self._transition_criteria is None else self._transition_criteria + @property + def generation_blocking_criteria(self) -> Sequence[GenerationBlockingCriterion]: + """Returns the sequence of GenerationBlockingCriteria that will be used to + block generation from this node without triggering a transition. 
+ """ + return ( + [] + if self._generation_blocking_criteria is None + else self._generation_blocking_criteria + ) + @property def input_constructors(self) -> TInputConstructorsByPurpose: """Returns the input constructors that will be used to determine any dynamic @@ -249,21 +269,6 @@ def experiment(self) -> Experiment: """Returns the experiment associated with this GenerationStrategy""" return self.generation_strategy.experiment - @property - def is_completed(self) -> bool: - """Returns True if this GenerationNode is complete and should transition to - the next node. - """ - # TODO: @mgarrard make this logic more robust and general - # We won't mark a node completed if it has an AutoTransitionAfterGen criterion - # as this is typically used in cyclic generation strategies - should_transition, _ = self.should_transition_to_next_node( - raise_data_required_error=False - ) - return should_transition and not any( - isinstance(tc, AutoTransitionAfterGen) for tc in self.transition_criteria - ) - @property def previous_node(self) -> GenerationNode | None: """Returns the previous ``GenerationNode``, if any.""" @@ -307,23 +312,22 @@ def generator_name(self) -> str: def num_trials(self) -> int: """Returns the number of trials this node should generate. - Extracts the threshold from the first `MinTrials` transition criterion - that has `block_transition_if_unmet=True`. This represents the minimum - number of trials that must be generated before transitioning. + Extracts the threshold from the first `MinTrials` transition criterion. + This represents the minimum number of trials that must be generated + before transitioning. Returns: The number of trials (threshold value). Raises: - UserInputError: If no `MinTrials` transition criterion with - `block_transition_if_unmet=True` is found. + UserInputError: If no `MinTrials` transition criterion is found. 
""" for tc in self.transition_criteria: - if isinstance(tc, MinTrials) and tc.block_transition_if_unmet: + if isinstance(tc, MinTrials): return tc.threshold raise UserInputError( "`num_trials` property is only supported when a `MinTrials` " - "transition criterion with `block_transition_if_unmet=True` is present." + "transition criterion is present." ) @property @@ -366,6 +370,10 @@ def __repr__(self) -> str: str_rep += ( f", transition_criteria={str(self._brief_transition_criteria_repr())}" ) + str_rep += ( + ", generation_blocking_criteria=" + f"{str(self._brief_generation_blocking_criteria_repr())}" + ) return f"{str_rep})" def _fit( @@ -740,17 +748,28 @@ def _pick_fitted_adapter_to_gen_from(self) -> GeneratorSpec: def trials_from_node(self) -> set[int]: """Returns a set containing the indices of trials generated by this node. + Results are cached and invalidated when the experiment's trial count changes. + Returns: Set[int]: A set containing all the indices of trials generated by this node. 
""" + current_trial_count = len(self.experiment.trials) + if ( + self._trials_from_node_cache is not None + and self._cached_trial_count == current_trial_count + ): + return self._trials_from_node_cache + + # (re)-build cache trials_from_node = set() - for _idx, trial in self.experiment.trials.items(): + for trial in self.experiment.trials.values(): for gr in trial.generator_runs: - if ( - gr._generation_node_name is not None - and gr._generation_node_name == self.name - ): + if gr._generation_node_name == self.name: trials_from_node.add(trial.index) + break + + self._trials_from_node_cache = trials_from_node + self._cached_trial_count = current_trial_count return trials_from_node @property @@ -808,76 +827,66 @@ def should_transition_to_next_node( and the name of the node to gen from (either the current or next node) """ # if no transition criteria are defined, this node can generate unlimited trials - if len(self.transition_criteria) == 0: + if ( + len(self.transition_criteria) == 0 + and len(self.generation_blocking_criteria) == 0 + ): return False, self.name # For each "transition edge" (set of all transition criteria that lead from # current node (e.g. "node A") to another specific node ("e.g. "node B") # in the node DAG: - # I. Check if all of the transition criteria along that edge are met; if so, + # Check if all of the transition criteria along that edge are met; if so, # transition to the next node defined by that edge. - # II. If we did not transition along this edge, but the edge has some - # "generation blocking" transition criteria (ex `MaxGenerationParallelism`) - # that are met, raise the error associated with that criterion. for next_node, all_tc in self.transition_edges.items(): - # I. Check if there are any TCs that block transition and whether all - # of them are met. If all of them are met, then we should transition. 
- transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet] - all_transition_blocking_met_should_transition = transition_blocking and all( + # Check if all TCs are met. If all of them are met, then we should + # transition. + all_tc_met_should_transition = all_tc and all( tc.is_met( experiment=self.experiment, curr_node=self, ) - for tc in transition_blocking + for tc in all_tc ) - if all_transition_blocking_met_should_transition: + if all_tc_met_should_transition: return True, next_node - # II. Raise any necessary generation errors: for any met criterion, - # call its `block_continued_generation_error` method if not all - # transition-blocking criteria are met. The method might not raise an - # error, depending on its implementation on given criterion, so the error - # from the first met one that does block continued generation, will raise. - if raise_data_required_error: - generation_blocking = [tc for tc in all_tc if tc.block_gen_if_met] - for tc in generation_blocking: - if tc.is_met(self.experiment, curr_node=self): - tc.block_continued_generation_error( - node_name=self.name, - experiment=self.experiment, - trials_from_node=self.trials_from_node, - ) - # TODO[@mgarrard, @drfreund] Try replacing `block_gen_if_met` with - # a self-transition and rework this error block. + # Only check generation blocking criteria if we're NOT transitioning. + # This ensures transition takes priority over blocking. 
+ if raise_data_required_error: + for criterion in self.generation_blocking_criteria: + if criterion.is_met(self.experiment, curr_node=self): + criterion.block_continued_generation_error( + node_name=self.name, + experiment=self.experiment, + trials_from_node=self.trials_from_node, + ) return False, self.name def new_trial_limit(self, raise_generation_errors: bool = False) -> int: - """How many trials can this generation strategy can currently produce - ``GeneratorRun``-s for (with potentially multiple generator runs produced for - each intended trial). + """How many trials this node can currently produce ``GeneratorRun``-s for. - NOTE: Only considers transition criteria that inherit from - ``TrialBasedCriterion``. + NOTE: Considers ``TrialCountBlockingCriterion`` subclasses for limiting + generation. Returns: The number of generator runs that can currently be produced, with -1 meaning unlimited generator runs. """ - # TODO: @mgarrard Should we consider returning `None` if there is no limit? - trial_based_gen_blocking_criteria = [ - criterion - for criterion in self.transition_criteria - if criterion.block_gen_if_met and isinstance(criterion, TrialBasedCriterion) - ] + # TODO: @mgarrard further improve and clarify this method # Cache trials_from_node to avoid repeated computation. 
trials_from_node = self.trials_from_node - gen_blocking_criterion_delta_from_threshold = [ - criterion.num_till_threshold( - experiment=self.experiment, trials_from_node=trials_from_node - ) - for criterion in trial_based_gen_blocking_criteria - ] + + # Compute limits from generation_blocking_criteria + gen_blocking_criterion_delta_from_threshold = [] + for criterion in self.generation_blocking_criteria: + if isinstance(criterion, TrialCountBlockingCriterion): + gen_blocking_criterion_delta_from_threshold.append( + criterion.num_till_threshold( + experiment=self.experiment, trials_from_node=trials_from_node + ) + ) # Raise any necessary generation errors: for any met criterion, # call its `block_continued_generation_error` method The method might not @@ -885,13 +894,10 @@ def new_trial_limit(self, raise_generation_errors: bool = False) -> int: # error from the first met one that does block continued generation, will be # raised. if raise_generation_errors: - for criterion in trial_based_gen_blocking_criteria: + for criterion in self.generation_blocking_criteria: # TODO[mgarrard]: Raise a group of all the errors, from each gen- # blocking transition criterion. - if criterion.is_met( - self.experiment, - curr_node=self, - ): + if criterion.is_met(self.experiment, curr_node=self): criterion.block_continued_generation_error( node_name=self.name, experiment=self.experiment, @@ -908,7 +914,7 @@ def _brief_transition_criteria_repr(self) -> str: Returns: str: A string representation of the transition criteria for this node. """ - if self.transition_criteria is None: + if not self._transition_criteria: return "None" tc_list = ", ".join( [ @@ -918,6 +924,25 @@ def _brief_transition_criteria_repr(self) -> str: ) return f"[{tc_list}]" + def _brief_generation_blocking_criteria_repr(self) -> str: + """Returns a brief string representation of the + generation blocking criteria for this node. + + Returns: + str: A string representation of the generation blocking criteria. 
+ """ + if self._generation_blocking_criteria is None: + return "None" + bc_list = ", ".join( + [ + f"{bc.__class__.__name__}(threshold={bc.threshold})" + if isinstance(bc, TrialCountBlockingCriterion) + else f"{bc.__class__.__name__}()" + for bc in self.generation_blocking_criteria + ] + ) + return f"[{bc_list}]" + def apply_input_constructors( self, experiment: Experiment, @@ -1086,6 +1111,7 @@ def __new__( # is set in `GenerationStrategy` constructor, because only then is the order # of the generation steps actually known. transition_criteria: list[TransitionCriterion] = [] + generation_blocking_criteria: list[GenerationBlockingCriterion] = [] # Placeholder - will be overwritten in _validate_and_set_step_sequence in GS placeholder_transition_to = f"GenerationStep_{str(index)}" @@ -1095,11 +1121,18 @@ def __new__( threshold=num_trials, transition_to=placeholder_transition_to, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], - block_gen_if_met=enforce_num_trials, - block_transition_if_unmet=True, use_all_trials_in_exp=use_all_trials_in_exp, ) ) + # If enforce_num_trials is True, add a blocking criterion + if enforce_num_trials: + generation_blocking_criteria.append( + MaxTrialsAwaitingData( + threshold=num_trials, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + use_all_trials_in_exp=use_all_trials_in_exp, + ) + ) if min_trials_observed > 0: transition_criteria.append( @@ -1110,19 +1143,14 @@ def __new__( TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED, ], - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=use_all_trials_in_exp, ) ) if max_parallelism is not None: - transition_criteria.append( + generation_blocking_criteria.append( MaxGenerationParallelism( threshold=max_parallelism, - transition_to=placeholder_transition_to, only_in_statuses=[TrialStatus.RUNNING], - block_gen_if_met=True, - block_transition_if_unmet=False, ) ) @@ -1137,6 +1165,7 @@ def __new__( generator_specs=[generator_spec], 
should_deduplicate=should_deduplicate, transition_criteria=transition_criteria, + generation_blocking_criteria=generation_blocking_criteria, ) # Store step index on the node for naming in GenerationStrategy. diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index 059d929080b..dcef5b08160 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -180,8 +180,20 @@ def last_generator_run(self) -> GeneratorRun | None: @property def optimization_complete(self) -> bool: - """Checks whether all nodes are completed in the generation strategy.""" - return all(node.is_completed for node in self._nodes) + """Checks whether optimization is complete. + + A strategy is complete when the current node's transition criteria + are met and point back to itself (self-transition). + + Nodes with no transition_criteria are infinite by design and never complete. + """ + if len(self._curr.transition_criteria) == 0: + return False + + can_transition, next_node = self._curr.should_transition_to_next_node( + raise_data_required_error=False + ) + return can_transition and next_node == self._curr.name def gen_single_trial( self, @@ -349,6 +361,10 @@ def _unset_non_persistent_state_fields(self) -> None: n._step_index = None if len(n.generator_specs) > 1: n._generator_spec_to_gen_from = None + # Reset cache fields that are used for performance optimization only + # and should not affect equality comparisons. + n._trials_from_node_cache = None + n._cached_trial_count = -1 # TODO: Deprecate `steps` argument fully in Q1'26. 
def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None: @@ -382,11 +398,7 @@ def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None: ) ) for tc in step.transition_criteria: - if tc.criterion_class == "MaxGenerationParallelism": - # MaxGenerationParallelism transitions to self (current step) - tc._transition_to = step.name - else: - tc._transition_to = next_step_name + tc._transition_to = next_step_name self._curr = steps[0] def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None: @@ -421,24 +433,9 @@ def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None: # Validate transition edges: # - All `transition_to` targets must exist in this GS # - All TCs on one edge must have the same `continue_trial_generation` setting - # All but `MaxGenerationParallelism` TCs must have a `transition_to` set for node in nodes: for next_node, tcs in node.transition_edges.items(): - if next_node is None: - # TODO[drfreund]: Handle the case of the last generation step not - # having any transition criteria. - # TODO[mgarrard]: Remove MaxGenerationParallelism check when - # we update TransitionCriterion always define `transition_to` - # NOTE: This is done in D86066476 - for tc in tcs: - if "MaxGenerationParallelism" not in tc.criterion_class: - raise GenerationStrategyMisconfiguredException( - error_info="Only MaxGenerationParallelism transition" - " criterion can have a null `transition_to` argument," - f" but {tc.criterion_class} does not define " - f"`transition_to` on {node.name}." 
- ) - elif next_node not in node_names: + if next_node not in node_names: raise GenerationStrategyMisconfiguredException( error_info=f"`transition_to` argument " f"{next_node} does not correspond to any node in" @@ -600,7 +597,6 @@ def _should_continue_gen_for_trial(self) -> bool: # if we will transition nodes, check if the transition criterion which define # the transition from this node to the next node indicate that we should # continue generating in the same trial, otherwise end the generation. - assert next_node is not None return all( tc.continue_trial_generation for tc in self._curr.transition_edges[next_node] @@ -612,13 +608,13 @@ def _maybe_transition_to_next_node( self, raise_data_required_error: bool = True, ) -> bool: - """Moves this generation strategy to next node if the current node is completed, - and it is not the last node in this generation strategy. This method is safe to - use both when generating candidates or simply checking how many generator runs - (to be made into trials) can currently be produced. + """Moves this generation strategy to next node if the current node's + transition criteria are met. This method is safe to use both when generating + candidates or simply checking how many generator runs (to be made into trials) + can currently be produced. - NOTE: this method raises ``GenerationStrategyCompleted`` error if the current - generation node is complete, but it is also the last in generation strategy. + NOTE: this method raises ``GenerationStrategyCompleted`` error if the + optimization is complete. Args: raise_data_required_error: Whether to raise ``DataRequiredError`` in the @@ -636,12 +632,5 @@ def _maybe_transition_to_next_node( f"Generation strategy {self} generated all the trials as " "specified in its nodes." ) - if next_node is None: - # If the last node did not specify which node to transition to, - # move to the next node in the list. 
- current_node_index = self._nodes.index(self._curr) - next_node = self._nodes[current_node_index + 1].name - for node in self._nodes: - if node.name == next_node: - self._curr = node + self._curr = self.nodes_by_name[next_node] return move_to_next_node diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index c382dd9fe3a..cd70d42a2ba 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -24,6 +24,7 @@ from ax.generation_strategy.generation_node import GenerationNode from ax.generation_strategy.transition_criterion import ( MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, ) from ax.generators.random.sobol import SobolGenerator @@ -615,18 +616,20 @@ def test_enforce_sequential_optimization(self) -> None: ).threshold, 5, ) - # Check that enforce_num_trials is True by verifying the MinTrials - # criterion has block_gen_if_met=True - node0_min_trials = assert_is_instance( - sobol_gpei._nodes[0].transition_criteria[0], MinTrials - ) - self.assertTrue(node0_min_trials.block_gen_if_met) + # Check that enforce_num_trials is True by verifying MaxTrialsAwaitingData + # exists in generation_blocking_criteria + node0_blocking_criteria = [ + bc + for bc in sobol_gpei._nodes[0].generation_blocking_criteria + if isinstance(bc, MaxTrialsAwaitingData) + ] + self.assertTrue(len(node0_blocking_criteria) > 0) # Check that max_parallelism is set by verifying MaxGenerationParallelism - # criterion exists on node 1 + # criterion exists in generation_blocking_criteria node1_max_parallelism = [ - tc - for tc in sobol_gpei._nodes[1].transition_criteria - if isinstance(tc, MaxGenerationParallelism) + bc + for bc in sobol_gpei._nodes[1].generation_blocking_criteria + if isinstance(bc, MaxGenerationParallelism) ] self.assertTrue(len(node1_max_parallelism) > 0) with self.subTest("False"): @@ -640,18 +643,20 @@ def 
test_enforce_sequential_optimization(self) -> None: ).threshold, 5, ) - # Check that enforce_num_trials is False by verifying the MinTrials - # criterion has block_gen_if_met=False - node0_min_trials = assert_is_instance( - sobol_gpei._nodes[0].transition_criteria[0], MinTrials - ) - self.assertFalse(node0_min_trials.block_gen_if_met) + # Check that enforce_num_trials is False by verifying no + # MaxTrialsAwaitingData exists in generation_blocking_criteria + node0_blocking_criteria = [ + bc + for bc in sobol_gpei._nodes[0].generation_blocking_criteria + if isinstance(bc, MaxTrialsAwaitingData) + ] + self.assertEqual(len(node0_blocking_criteria), 0) # Check that max_parallelism is None by verifying no - # MaxGenerationParallelism criterion exists on node 1 + # MaxGenerationParallelism criterion exists in generation_blocking_criteria node1_max_parallelism = [ - tc - for tc in sobol_gpei._nodes[1].transition_criteria - if isinstance(tc, MaxGenerationParallelism) + bc + for bc in sobol_gpei._nodes[1].generation_blocking_criteria + if isinstance(bc, MaxGenerationParallelism) ] self.assertEqual(len(node1_max_parallelism), 0) with self.subTest("False and max_parallelism_override"): @@ -818,10 +823,10 @@ def test_fixed_num_initialization_trials(self) -> None: ) def _get_max_parallelism(self, node: GenerationNode) -> int | None: - """Helper to extract max_parallelism from transition criteria.""" - for tc in node.transition_criteria: - if isinstance(tc, MaxGenerationParallelism): - return tc.threshold + """Helper to extract max_parallelism from generation_blocking_criteria.""" + for bc in node.generation_blocking_criteria: + if isinstance(bc, MaxGenerationParallelism): + return bc.threshold return None def test_max_parallelism_adjustments(self) -> None: diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 3cae54f0b27..6e10f15ed37 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ 
b/ax/generation_strategy/tests/test_generation_node.py @@ -335,7 +335,8 @@ def test_node_string_representation(self) -> None: "GenerationNode(name='test', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MinTrials(transition_to='next_node')])", + "transition_criteria=[MinTrials(transition_to='next_node')], " + "generation_blocking_criteria=[])", ) def test_single_fixed_features(self) -> None: @@ -445,8 +446,6 @@ def test_init(self) -> None: threshold=5, transition_to="GenerationStep_-1", # overwritten during GS init not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], - block_gen_if_met=True, - block_transition_if_unmet=True, use_all_trials_in_exp=False, ), ], @@ -471,8 +470,6 @@ def test_init(self) -> None: threshold=5, transition_to="GenerationStep_-1", # overwritten during GS init not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=True, ), MinTrials( @@ -482,8 +479,6 @@ def test_init(self) -> None: TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED, ], - block_gen_if_met=False, - block_transition_if_unmet=True, use_all_trials_in_exp=True, ), ], diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index a44c1a9af05..cb7f3e15862 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -215,7 +215,6 @@ def setUp(self) -> None: MinTrials( threshold=5, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) @@ -226,7 +225,6 @@ def setUp(self) -> None: # this self-pointing isn't representative of real-world, but is # useful for testing attributes likes repr etc transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, 
TrialStatus.ABANDONED], ) @@ -235,7 +233,6 @@ def setUp(self) -> None: MinTrials( threshold=1, transition_to="mbm", - block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], ) ] @@ -266,14 +263,12 @@ def setUp(self) -> None: self.mbm_to_sobol2_with_running_trial = MinTrials( threshold=1, transition_to="sobol_2", - block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], use_all_trials_in_exp=True, ) self.mbm_to_sobol2_with_completed_trial = MinTrials( threshold=1, transition_to="sobol_2", - block_transition_if_unmet=True, only_in_statuses=[TrialStatus.COMPLETED], use_all_trials_in_exp=True, ) @@ -336,13 +331,11 @@ def setUp(self) -> None: MinTrials( threshold=2, transition_to="sobol_4", - block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], use_all_trials_in_exp=True, ), AutoTransitionAfterGen( transition_to="mbm", - block_transition_if_unmet=True, continue_trial_generation=False, ), ], @@ -439,11 +432,13 @@ def test_string_representation(self) -> None: "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " "transition_criteria=" - "[MinTrials(transition_to='GenerationStep_1_BoTorch')]), " + "[MinTrials(transition_to='GenerationStep_1_BoTorch')], " + "generation_blocking_criteria=[MaxTrialsAwaitingData(threshold=5)]), " "GenerationNode(name='GenerationStep_1_BoTorch', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[])])" + "transition_criteria=None, " + "generation_blocking_criteria=[])])" ), ) gs2 = GenerationStrategy( @@ -456,7 +451,8 @@ def test_string_representation(self) -> None: "nodes=[GenerationNode(name='GenerationStep_0_Sobol', " "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " - "transition_criteria=[])])" + "transition_criteria=None, " + "generation_blocking_criteria=[])])" ), ) @@ -480,7 +476,8 @@ def test_string_representation(self) -> None: "name='test', " 
"generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " - "transition_criteria=[])])", + "transition_criteria=None, " + "generation_blocking_criteria=[])])", ) def test_equality(self) -> None: @@ -1220,30 +1217,29 @@ def test_gen_with_fixed_features( # ---------- Tests for GenerationStrategies composed of GenerationNodes -------- def test_gs_setup_with_nodes(self) -> None: """Test GS initialization and validation with nodes""" - node_1_criterion = [ + node_1_transition_criteria = [ MinTrials( threshold=4, - block_gen_if_met=False, transition_to="node_2", only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), MinTrials( - only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], threshold=2, + only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], transition_to="node_2", ), + ] + node_1_blocking_criteria = [ MaxGenerationParallelism( threshold=1, only_in_statuses=[TrialStatus.RUNNING], - block_gen_if_met=True, - block_transition_if_unmet=False, - transition_to="node_1", ), ] node_1 = GenerationNode( name="node_1", - transition_criteria=node_1_criterion, + transition_criteria=node_1_transition_criteria, + generation_blocking_criteria=node_1_blocking_criteria, generator_specs=[self.sobol_generator_spec], ) node_3 = GenerationNode( @@ -1311,7 +1307,6 @@ def test_gs_setup_with_nodes(self) -> None: transition_criteria=[ MinTrials( threshold=4, - block_gen_if_met=False, transition_to="node_2", only_in_statuses=None, not_in_statuses=[ @@ -1383,7 +1378,6 @@ def test_gs_with_suggested_n_is_zero(self) -> None: transition_criteria=[ AutoTransitionAfterGen( transition_to="sobol_2", - block_transition_if_unmet=True, continue_trial_generation=False, ), ], @@ -1440,7 +1434,6 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None: transition_criteria=[ AutoTransitionAfterGen( transition_to="sobol_1", - block_transition_if_unmet=True, continue_trial_generation=False, ), ], @@ 
-1599,15 +1592,11 @@ def test_gs_with_nodes_and_blocking_criteria(self) -> None: transition_criteria=[ MinTrials( threshold=3, - block_gen_if_met=True, - block_transition_if_unmet=True, transition_to="MBM_node", ), MinTrials( threshold=2, only_in_statuses=[TrialStatus.COMPLETED], - block_gen_if_met=False, - block_transition_if_unmet=True, transition_to="MBM_node", ), ], @@ -1752,7 +1741,6 @@ def test_node_gs_with_auto_transitions(self) -> None: transition_criteria=[ AutoTransitionAfterGen( transition_to="mbm", - block_transition_if_unmet=True, continue_trial_generation=False, ) ], @@ -1816,7 +1804,6 @@ def test_gs_with_fixed_features_constructor(self) -> None: MinTrials( threshold=1, transition_to="sobol_2", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) @@ -1942,7 +1929,6 @@ def test_gs_with_input_constructor(self) -> None: MinTrials( threshold=1, transition_to="sobol_2", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) @@ -2000,6 +1986,28 @@ def test_gs_with_input_constructor(self) -> None: self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_3") self.assertEqual(len(trial.generator_runs[1].arms), 8) + def test_optimization_complete_single_node_no_criteria(self) -> None: + """Test that a single node with no transition_criteria never completes.""" + exp = get_branin_experiment() + gs = GenerationStrategy( + nodes=[ + GenerationNode( + name="infinite sobol", + generator_specs=[self.sobol_generator_spec], + transition_criteria=[], # No criteria = infinite by design + ), + ] + ) + gs.experiment = exp + + # Generate many trials - never completes + for _ in range(3): + self.assertFalse(gs.optimization_complete) + gr = gs.gen_single_trial(experiment=exp) + exp.new_trial(generator_run=gr).mark_running(no_runner_required=True) + + self.assertFalse(gs.optimization_complete) + # ------------- Testing helpers (put tests above this 
line) ------------- def _run_GS_for_N_rounds( diff --git a/ax/generation_strategy/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py index c34850ca4c0..1d63d30e170 100644 --- a/ax/generation_strategy/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -11,7 +11,8 @@ from ax.adapter.registry import Generators from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.trial_status import TrialStatus -from ax.exceptions.core import UserInputError +from ax.exceptions.core import DataRequiredError, UserInputError +from ax.exceptions.generation_strategy import MaxParallelismReachedException from ax.generation_strategy.generation_strategy import ( GenerationNode, GenerationStep, @@ -23,6 +24,7 @@ AuxiliaryExperimentCheck, IsSingleObjective, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, ) from ax.utils.common.logger import get_logger @@ -167,7 +169,6 @@ def test_default_step_criterion_setup(self) -> None: step_0_expected_transition_criteria = [ MinTrials( threshold=3, - block_gen_if_met=True, transition_to="GenerationStep_1_BoTorch", only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], @@ -176,7 +177,6 @@ def test_default_step_criterion_setup(self) -> None: step_1_expected_transition_criteria = [ MinTrials( threshold=4, - block_gen_if_met=False, transition_to="GenerationStep_2_BoTorch", only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], @@ -186,12 +186,11 @@ def test_default_step_criterion_setup(self) -> None: threshold=2, transition_to="GenerationStep_2_BoTorch", ), + ] + step_1_expected_blocking_criteria = [ MaxGenerationParallelism( threshold=1, only_in_statuses=[TrialStatus.RUNNING], - block_gen_if_met=True, - block_transition_if_unmet=False, - transition_to="GenerationStep_1_BoTorch", ), ] step_2_expected_transition_criteria = [] @@ -201,6 +200,9 @@ def 
test_default_step_criterion_setup(self) -> None: self.assertEqual( gs._nodes[1].transition_criteria, step_1_expected_transition_criteria ) + self.assertEqual( + gs._nodes[1].generation_blocking_criteria, step_1_expected_blocking_criteria + ) self.assertEqual( gs._nodes[2].transition_criteria, step_2_expected_transition_criteria ) @@ -391,12 +393,9 @@ def test_trials_from_node_empty(self) -> None: max_criterion_with_status = MinTrials( threshold=2, transition_to="next_node", - block_gen_if_met=True, only_in_statuses=[TrialStatus.COMPLETED], ) - max_criterion = MinTrials( - threshold=2, transition_to="next_node", block_gen_if_met=True - ) + max_criterion = MinTrials(threshold=2, transition_to="next_node") self.assertFalse( max_criterion.is_met(experiment=experiment, curr_node=gs._nodes[0]) ) @@ -428,8 +427,6 @@ def test_repr(self) -> None: min_trials_criterion = MinTrials( threshold=5, transition_to="GenerationStep_1", - block_gen_if_met=True, - block_transition_if_unmet=False, only_in_statuses=[TrialStatus.COMPLETED], not_in_statuses=[TrialStatus.FAILED], ) @@ -439,8 +436,6 @@ def test_repr(self) -> None: + "'transition_to': 'GenerationStep_1', " + "'only_in_statuses': [.COMPLETED], " + "'not_in_statuses': [.FAILED], " - + "'block_transition_if_unmet': False, " - + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " + "'continue_trial_generation': False, " + "'count_only_trials_with_data': False})", @@ -449,8 +444,6 @@ def test_repr(self) -> None: threshold=0, transition_to="GenerationStep_2", only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], - block_gen_if_met=True, - block_transition_if_unmet=False, not_in_statuses=[TrialStatus.FAILED], ) self.assertEqual( @@ -460,8 +453,6 @@ def test_repr(self) -> None: + "'only_in_statuses': " + "[.COMPLETED, .EARLY_STOPPED], " + "'not_in_statuses': [.FAILED], " - + "'block_transition_if_unmet': False, " - + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " + 
"'continue_trial_generation': False, " + "'count_only_trials_with_data': False})", @@ -469,27 +460,54 @@ def test_repr(self) -> None: max_parallelism = MaxGenerationParallelism( only_in_statuses=[TrialStatus.EARLY_STOPPED], threshold=3, - transition_to="GenerationStep_2", - block_gen_if_met=True, - block_transition_if_unmet=False, not_in_statuses=[TrialStatus.FAILED], ) self.assertEqual( str(max_parallelism), "MaxGenerationParallelism({'threshold': 3, " - + "'transition_to': 'GenerationStep_2', " + "'only_in_statuses': " + "[.EARLY_STOPPED], " + "'not_in_statuses': [.FAILED], " - + "'block_transition_if_unmet': False, " - + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " - + "'continue_trial_generation': False})", + + "'count_only_trials_with_data': False})", ) auto_transition = AutoTransitionAfterGen(transition_to="GenerationStep_2") self.assertEqual( str(auto_transition), "AutoTransitionAfterGen({'transition_to': 'GenerationStep_2', " - + "'block_transition_if_unmet': True, " + "'continue_trial_generation': True})", ) + + +class TestGenerationBlockingCriterion(TestCase): + """Tests for new GenerationBlockingCriterion classes.""" + + def setUp(self) -> None: + super().setUp() + self.experiment = get_branin_experiment() + + def test_max_trials_awaiting_data(self) -> None: + with self.subTest("default_not_in_statuses"): + criterion = MaxTrialsAwaitingData(threshold=10) + self.assertEqual( + criterion.not_in_statuses, + [TrialStatus.FAILED, TrialStatus.ABANDONED], + ) + + with self.subTest("block_continued_generation_error"): + criterion = MaxTrialsAwaitingData(threshold=3) + with self.assertRaises(DataRequiredError): + criterion.block_continued_generation_error( + node_name="test", experiment=self.experiment, trials_from_node=set() + ) + + def test_max_generation_parallelism_block_error(self) -> None: + criterion = MaxGenerationParallelism( + threshold=2, only_in_statuses=[TrialStatus.RUNNING] + ) + with 
self.assertRaises(MaxParallelismReachedException): + criterion.block_continued_generation_error( + node_name="test", + experiment=self.experiment, + trials_from_node={0, 1, 2}, + ) diff --git a/ax/generation_strategy/transition_criterion.py b/ax/generation_strategy/transition_criterion.py index 7b4aceb6c46..41706325dab 100644 --- a/ax/generation_strategy/transition_criterion.py +++ b/ax/generation_strategy/transition_criterion.py @@ -32,6 +32,325 @@ ) +# ============================================================================ +# Trial Counting Utility Functions +# ============================================================================ + + +def get_trials_by_status( + experiment: Experiment, statuses: list[TrialStatus] +) -> set[int]: + """Get trial indices from the experiment with the specified statuses. + + Args: + experiment: The experiment to query. + statuses: The trial statuses to filter on. + + Returns: + Set of trial indices with the specified statuses. + """ + trials_with_statuses = set() + for status in statuses: + trials_with_statuses = trials_with_statuses.union( + experiment.trial_indices_by_status[status] + ) + return trials_with_statuses + + +def filter_trials_by_status( + experiment: Experiment, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, +) -> set[int]: + """Filter trial indices by status inclusion/exclusion. + + Args: + experiment: The experiment to query. + only_in_statuses: If provided, only include trials with these statuses. + not_in_statuses: If provided, exclude trials with these statuses. + + Returns: + Set of trial indices matching the filter criteria. 
+ """ + trials_to_check = set(experiment.trials.keys()) + if only_in_statuses is not None: + trials_to_check = get_trials_by_status( + experiment=experiment, statuses=only_in_statuses + ) + if not_in_statuses is not None: + trials_to_check -= get_trials_by_status( + experiment=experiment, statuses=not_in_statuses + ) + return trials_to_check + + +def count_trials_toward_threshold( + experiment: Experiment, + trials_from_node: set[int], + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool = False, + count_only_trials_with_data: bool = False, +) -> int: + """Count trials contributing toward a threshold. + + Args: + experiment: The experiment to query. + trials_from_node: Set of trial indices generated by the current node. + only_in_statuses: If provided, only count trials with these statuses. + not_in_statuses: If provided, exclude trials with these statuses. + use_all_trials_in_exp: If True, count all trials in the experiment. + Otherwise, only count trials from the current node. + count_only_trials_with_data: If True, only count trials with data. + + Returns: + The number of trials contributing to the threshold. 
+ """ + all_trials_to_check = filter_trials_by_status( + experiment=experiment, + only_in_statuses=only_in_statuses, + not_in_statuses=not_in_statuses, + ) + if count_only_trials_with_data: + data_trial_indices = get_trial_indices_with_required_metrics( + experiment=experiment, + df=experiment.lookup_data().df, + require_data_for_all_metrics=False, + ) + all_trials_to_check = all_trials_to_check.intersection(data_trial_indices) + + if use_all_trials_in_exp: + return len(all_trials_to_check) + + return len(trials_from_node.intersection(all_trials_to_check)) + + +# ============================================================================ +# GenerationBlockingCriterion - for blocking generation without transitioning +# ============================================================================ + + +class GenerationBlockingCriterion(SortableBase): + """A criterion that blocks generation from a GenerationNode without triggering + a transition to another node. + """ + + @abstractmethod + def is_met( + self, + experiment: Experiment, + curr_node: GenerationNode, + ) -> bool: + """Returns True if this criterion's condition is met.""" + pass + + @abstractmethod + def block_continued_generation_error( + self, + node_name: str, + experiment: Experiment, + trials_from_node: set[int], + ) -> None: + """Raises an appropriate error when generation is blocked.""" + pass + + @property + def criterion_class(self) -> str: + """Name of the class of this GenerationBlockingCriterion.""" + return self.__class__.__name__ + + def __repr__(self) -> str: + return f"{self.criterion_class}({serialize_init_args(obj=self)})" + + @property + def _unique_id(self) -> str: + """Unique id for this GenerationBlockingCriterion.""" + return str(self) + + +class TrialCountBlockingCriterion(GenerationBlockingCriterion): + """Abstract base class for blocking criteria based on trial count thresholds. + + This class provides shared logic for blocking criteria that count trials toward + a threshold. 
Subclasses only need to implement `block_continued_generation_error()` + to define the specific error raised when generation is blocked. + + Args: + threshold: The maximum number of trials allowed before blocking generation. + only_in_statuses: A list of trial statuses to filter on when checking the + criterion threshold. + not_in_statuses: A list of trial statuses to exclude when checking the + criterion threshold. + use_all_trials_in_exp: A flag to use all trials in the experiment, instead of + only those generated by the current GenerationNode. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. + """ + + def __init__( + self, + threshold: int, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool = False, + count_only_trials_with_data: bool = False, + ) -> None: + self.threshold = threshold + self.only_in_statuses = only_in_statuses + self.not_in_statuses = not_in_statuses + self.use_all_trials_in_exp = use_all_trials_in_exp + self.count_only_trials_with_data = count_only_trials_with_data + + def num_contributing_to_threshold( + self, experiment: Experiment, trials_from_node: set[int] + ) -> int: + """Returns the number of trials contributing to the threshold.""" + return count_trials_toward_threshold( + experiment=experiment, + trials_from_node=trials_from_node, + only_in_statuses=self.only_in_statuses, + not_in_statuses=self.not_in_statuses, + use_all_trials_in_exp=self.use_all_trials_in_exp, + count_only_trials_with_data=self.count_only_trials_with_data, + ) + + def num_till_threshold( + self, experiment: Experiment, trials_from_node: set[int] + ) -> int: + """Returns the number of trials available before hitting the threshold.""" + return self.threshold - self.num_contributing_to_threshold( + experiment=experiment, trials_from_node=trials_from_node + ) + + def is_met( + self, + experiment: 
Experiment, + curr_node: GenerationNode, + ) -> bool: + """Returns True if the trial count threshold has been reached.""" + return ( + self.num_contributing_to_threshold( + experiment=experiment, trials_from_node=curr_node.trials_from_node + ) + >= self.threshold + ) + + @abstractmethod + def block_continued_generation_error( + self, + node_name: str, + experiment: Experiment, + trials_from_node: set[int], + ) -> None: + """Raises an appropriate error when generation is blocked. + + Subclasses must implement this to define the specific error behavior. + """ + pass + + +class MaxGenerationParallelism(TrialCountBlockingCriterion): + """A GenerationBlockingCriterion that blocks generation after a maximum number + of trials have been generated for the current GenerationNode and are currently + running. + + Args: + threshold: The maximum number of trials allowed in the specified statuses. + only_in_statuses: A list of trial statuses to filter on when checking the + criterion threshold. + not_in_statuses: A list of trial statuses to exclude when checking the + criterion threshold. + use_all_trials_in_exp: A flag to use all trials in the experiment, instead of + only those generated by the current GenerationNode. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False. 
+ """ + + def __init__( + self, + threshold: int, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool = False, + count_only_trials_with_data: bool = False, + ) -> None: + super().__init__( + threshold=threshold, + only_in_statuses=only_in_statuses, + not_in_statuses=not_in_statuses, + use_all_trials_in_exp=use_all_trials_in_exp, + count_only_trials_with_data=count_only_trials_with_data, + ) + + def block_continued_generation_error( + self, + node_name: str, + experiment: Experiment, + trials_from_node: set[int], + ) -> None: + """Raises MaxParallelismReachedException.""" + raise MaxParallelismReachedException( + node_name=node_name, + num_running=self.num_contributing_to_threshold( + experiment=experiment, trials_from_node=trials_from_node + ), + ) + + +class MaxTrialsAwaitingData(TrialCountBlockingCriterion): + """A GenerationBlockingCriterion that blocks generation after a maximum number + of trials have been generated, waiting for data before allowing more generation. + + This criterion blocks generation from the associated GenerationNode when the + threshold is met, but does NOT trigger a transition to another node. Use this + when you want to enforce that a node generates at most a certain number of + trials before requiring data. + + Args: + threshold: The maximum number of trials allowed before blocking generation. + only_in_statuses: A list of trial statuses to filter on when checking the + criterion threshold. + not_in_statuses: A list of trial statuses to exclude when checking the + criterion threshold. Defaults to [FAILED, ABANDONED]. + use_all_trials_in_exp: A flag to use all trials in the experiment, instead of + only those generated by the current GenerationNode. Defaults to False. + count_only_trials_with_data: If set to True, only trials with data will be + counted towards the ``threshold``. Defaults to False.
+ """ + + def __init__( + self, + threshold: int, + only_in_statuses: list[TrialStatus] | None = None, + not_in_statuses: list[TrialStatus] | None = None, + use_all_trials_in_exp: bool = False, + count_only_trials_with_data: bool = False, + ) -> None: + if not_in_statuses is None: + not_in_statuses = [TrialStatus.FAILED, TrialStatus.ABANDONED] + super().__init__( + threshold=threshold, + only_in_statuses=only_in_statuses, + not_in_statuses=not_in_statuses, + use_all_trials_in_exp=use_all_trials_in_exp, + count_only_trials_with_data=count_only_trials_with_data, + ) + + def block_continued_generation_error( + self, + node_name: str, + experiment: Experiment, + trials_from_node: set[int], + ) -> None: + """Raises DataRequiredError when the trial threshold is reached.""" + raise DataRequiredError(DATA_REQUIRED_MSG.format(node_name=node_name)) + + +# ============================================================================ +# TransitionCriterion - for node transitions +# ============================================================================ + + class TransitionCriterion(SortableBase): """ Simple class to describe a condition which must be met for this GenerationNode to @@ -40,16 +359,6 @@ class TransitionCriterion(SortableBase): Args: transition_to: The name of the GenerationNode the GenerationStrategy should transition to when this criterion is met. - block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criterion - remain unmet. Ex: ``MinTrials`` has not been met yet, but - MinTrials has been reached. If this flag is set to true on MinTrials then - we will raise an error, otherwise we will continue to generate trials - until ``MinTrials`` is met (thus overriding MinTrials). - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. 
Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. continue_trial_generation: A flag to indicate that all generation for a given trial is not completed, and thus even after transition, the next node will continue to generate arms for the same trial. Example usage: in @@ -62,13 +371,9 @@ class TransitionCriterion(SortableBase): def __init__( self, transition_to: str, - block_transition_if_unmet: bool | None = True, - block_gen_if_met: bool | None = False, continue_trial_generation: bool | None = False, ) -> None: self._transition_to = transition_to - self.block_transition_if_unmet = block_transition_if_unmet - self.block_gen_if_met = block_gen_if_met self.continue_trial_generation = continue_trial_generation @property @@ -87,16 +392,6 @@ def is_met( """If the criterion of this TransitionCriterion is met, returns True.""" pass - @abstractmethod - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - """Error to be raised if the `block_gen_if_met` flag is set to True.""" - pass - @property def criterion_class(self) -> str: """Name of the class of this TransitionCriterion.""" @@ -117,10 +412,6 @@ class AutoTransitionAfterGen(TransitionCriterion): Args: transition_to: The name of the GenerationNode the GenerationStrategy should transition to next. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: This criterion defaults to - setting this to True to ensure we validate a GeneratorRun is generated by - the current GenerationNode. continue_trial_generation: A flag to indicate that all generation for a given trial is not completed, and thus even after transition, the next node will continue to generate arms for the same trial. 
Example usage: in @@ -131,12 +422,10 @@ class AutoTransitionAfterGen(TransitionCriterion): def __init__( self, transition_to: str, - block_transition_if_unmet: bool | None = True, continue_trial_generation: bool | None = True, ) -> None: super().__init__( transition_to=transition_to, - block_transition_if_unmet=block_transition_if_unmet, continue_trial_generation=continue_trial_generation, ) @@ -161,15 +450,6 @@ def is_met( else False ) - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - """Error to be raised if the `block_gen_if_met` flag is set to True.""" - pass - class IsSingleObjective(TransitionCriterion): """A class to initiate transition based on whether the experiment is optimizing @@ -178,10 +458,6 @@ class IsSingleObjective(TransitionCriterion): Args: transition_to: The name of the GenerationNode the GenerationStrategy should transition to next. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: This criterion defaults to - setting this to True to ensure we validate a GeneratorRun is generated by - the current GenerationNode. continue_trial_generation: A flag to indicate that all generation for a given trial is not completed, and thus even after transition, the next node will continue to generate arms for the same trial. 
Example usage: in @@ -192,12 +468,10 @@ class IsSingleObjective(TransitionCriterion): def __init__( self, transition_to: str, - block_transition_if_unmet: bool | None = True, continue_trial_generation: bool | None = False, ) -> None: super().__init__( transition_to=transition_to, - block_transition_if_unmet=block_transition_if_unmet, continue_trial_generation=continue_trial_generation, ) @@ -216,14 +490,6 @@ def is_met( else True ) - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - pass - class TrialBasedCriterion(TransitionCriterion): """Common class for transition criterion that are based on trial information. @@ -231,16 +497,6 @@ class TrialBasedCriterion(TransitionCriterion): Args: threshold: The threshold as an integer for this criterion. Ex: If we want to generate at most 3 trials, then the threshold is 3. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. - block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criterion - remain unmet. Ex: ``MinTrials`` has not been met yet, but - MinTrials has been reached. If this flag is set to true on MinTrials then - we will raise an error, otherwise we will continue to generate trials - until ``MinTrials`` is met (thus overriding MinTrials). only_in_statuses: A list of trial statuses to filter on when checking the criterion threshold. 
not_in_statuses: A list of trial statuses to exclude when checking the @@ -263,8 +519,6 @@ def __init__( self, threshold: int, transition_to: str, - block_transition_if_unmet: bool | None = True, - block_gen_if_met: bool | None = False, only_in_statuses: list[TrialStatus] | None = None, not_in_statuses: list[TrialStatus] | None = None, use_all_trials_in_exp: bool | None = False, @@ -278,72 +532,35 @@ def __init__( self.count_only_trials_with_data = count_only_trials_with_data super().__init__( transition_to=transition_to, - block_transition_if_unmet=block_transition_if_unmet, - block_gen_if_met=block_gen_if_met, continue_trial_generation=continue_trial_generation, ) def experiment_trials_by_status( self, experiment: Experiment, statuses: list[TrialStatus] ) -> set[int]: - """Get the trial indices from the entire experiment with the desired - statuses. - - Args: - experiment: The experiment associated with this GenerationStrategy. - statuses: The trial statuses to filter on. - Returns: - The trial indices in the experiment with the desired statuses. - """ - exp_trials_with_statuses = set() - for status in statuses: - exp_trials_with_statuses = exp_trials_with_statuses.union( - experiment.trial_indices_by_status[status] - ) - return exp_trials_with_statuses + """Get the trial indices from the experiment with the desired statuses.""" + return get_trials_by_status(experiment=experiment, statuses=statuses) def all_trials_to_check(self, experiment: Experiment) -> set[int]: - """All the trials to check from the entire experiment that meet - all the provided status filters. - - Args: - experiment: The experiment associated with this GenerationStrategy. 
- """ - trials_to_check = set(experiment.trials.keys()) - if self.only_in_statuses is not None: - trials_to_check = self.experiment_trials_by_status( - experiment=experiment, statuses=self.only_in_statuses - ) - # exclude the trials to those not in the specified statuses - if self.not_in_statuses is not None: - trials_to_check -= self.experiment_trials_by_status( - experiment=experiment, statuses=self.not_in_statuses - ) - return trials_to_check + """All the trials to check that meet the provided status filters.""" + return filter_trials_by_status( + experiment=experiment, + only_in_statuses=self.only_in_statuses, + not_in_statuses=self.not_in_statuses, + ) def num_contributing_to_threshold( self, experiment: Experiment, trials_from_node: set[int] ) -> int: - """Returns the number of trials contributing to the threshold. - - Args: - experiment: The experiment associated with this GenerationStrategy. - trials_from_node: The set of trials generated by this GenerationNode. - """ - all_trials_to_check = self.all_trials_to_check(experiment=experiment) - if self.count_only_trials_with_data: - data_trial_indices = get_trial_indices_with_required_metrics( - experiment=experiment, - df=experiment.lookup_data().df, - require_data_for_all_metrics=False, - ) - all_trials_to_check = all_trials_to_check.intersection(data_trial_indices) - # Some criteria may rely on experiment level data, instead of only trials - # generated from the node associated with the criterion. 
- if self.use_all_trials_in_exp: - return len(all_trials_to_check) - - return len(trials_from_node.intersection(all_trials_to_check)) + """Returns the number of trials contributing to the threshold.""" + return count_trials_toward_threshold( + experiment=experiment, + trials_from_node=trials_from_node, + only_in_statuses=self.only_in_statuses, + not_in_statuses=self.not_in_statuses, + use_all_trials_in_exp=bool(self.use_all_trials_in_exp), + count_only_trials_with_data=self.count_only_trials_with_data, + ) def num_till_threshold( self, experiment: Experiment, trials_from_node: set[int] @@ -382,93 +599,10 @@ def is_met( ) -class MaxGenerationParallelism(TrialBasedCriterion): - """Specific TransitionCriterion implementation which defines the maximum number - of trials that can simultaneously be in the designated trial statuses. The - default behavior is to block generation from the associated GenerationNode if the - threshold is met. This is configured via the `block_gen_if_met` flag being set to - True. This criterion defaults to not blocking transition to another node via the - `block_transition_if_unmet` flag being set to False. - - Args: - threshold: The threshold as an integer for this criterion. Ex: If we want to - generate at most 3 trials, then the threshold is 3. - only_in_statuses: A list of trial statuses to filter on when checking the - criterion threshold. - not_in_statuses: A list of trial statuses to exclude when checking the - criterion threshold. - transition_to: The name of the GenerationNode the GenerationStrategy should - transition to when this criterion is met, if it exists. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. 
- block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criterion - remain unmet. Ex: ``MinTrials`` has not been met yet, but - MinTrials has been reached. If this flag is set to true on MinTrials then - we will raise an error, otherwise we will continue to generate trials - until ``MinTrials`` is met (thus overriding MinTrials). - use_all_trials_in_exp: A flag to use all trials in the experiment, instead of - only those generated by the current GenerationNode. - continue_trial_generation: A flag to indicate that all generation for a given - trial is not completed, and thus even after transition, the next node will - continue to generate arms for the same trial. Example usage: in - ``BatchTrial``s we may enable generation of arms within a batch from - different ``GenerationNodes`` by setting this flag to True. Defaults to - False for MaxGenerationParallelism since this criterion isn't currently - used for node -> node or trial -> trial transition. - count_only_trials_with_data: If set to True, only trials with data will be - counted towards the ``threshold``. Defaults to False. 
- """ - - def __init__( - self, - threshold: int, - transition_to: str, - only_in_statuses: list[TrialStatus] | None = None, - not_in_statuses: list[TrialStatus] | None = None, - block_transition_if_unmet: bool | None = False, - block_gen_if_met: bool | None = True, - use_all_trials_in_exp: bool | None = False, - continue_trial_generation: bool | None = False, - ) -> None: - super().__init__( - threshold=threshold, - only_in_statuses=only_in_statuses, - not_in_statuses=not_in_statuses, - transition_to=transition_to, - block_gen_if_met=block_gen_if_met, - block_transition_if_unmet=block_transition_if_unmet, - use_all_trials_in_exp=use_all_trials_in_exp, - continue_trial_generation=continue_trial_generation, - ) - - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - """Raises the appropriate error (should only be called when the - ``GenerationNode`` is blocked from continued generation). For this - class, the exception is ``MaxParallelismReachedException``. - """ - assert self.block_gen_if_met # Sanity check. - raise MaxParallelismReachedException( - node_name=node_name, - num_running=self.num_contributing_to_threshold( - experiment=experiment, trials_from_node=trials_from_node - ), - ) - - class MinTrials(TrialBasedCriterion): """ Simple class to enforce a minimum threshold for the number of trials with the - designated statuses being generated by a specific GenerationNode. The default - behavior is to block transition to the next node if the threshold is unmet, but - not affect continued generation. + designated statuses being generated by a specific GenerationNode. Args: threshold: The threshold as an integer for this criterion. Ex: If we want to @@ -479,16 +613,6 @@ class MinTrials(TrialBasedCriterion): criterion threshold. transition_to: The name of the GenerationNode the GenerationStrategy should transition to when this criterion is met. 
- block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. - block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criterion - remain unmet. Ex: ``MinTrials`` has not been met yet, but - MinTrials has been reached. If this flag is set to true on MinTrials then - we will raise an error, otherwise we will continue to generate trials - until ``MinTrials`` is met (thus overriding MinTrials). use_all_trials_in_exp: A flag to use all trials in the experiment, instead of only those generated by the current GenerationNode. continue_trial_generation: A flag to indicate that all generation for a given @@ -506,8 +630,6 @@ def __init__( transition_to: str, only_in_statuses: list[TrialStatus] | None = None, not_in_statuses: list[TrialStatus] | None = None, - block_transition_if_unmet: bool | None = True, - block_gen_if_met: bool | None = False, use_all_trials_in_exp: bool | None = False, continue_trial_generation: bool | None = False, count_only_trials_with_data: bool = False, @@ -517,26 +639,11 @@ def __init__( transition_to=transition_to, only_in_statuses=only_in_statuses, not_in_statuses=not_in_statuses, - block_gen_if_met=block_gen_if_met, - block_transition_if_unmet=block_transition_if_unmet, use_all_trials_in_exp=use_all_trials_in_exp, continue_trial_generation=continue_trial_generation, count_only_trials_with_data=count_only_trials_with_data, ) - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - """Raises the appropriate error (should only be called when the - ``GenerationNode`` is blocked from continued generation). For this - class, the exception is ``DataRequiredError``. 
- """ - assert self.block_gen_if_met # Sanity check. - raise DataRequiredError(DATA_REQUIRED_MSG.format(node_name=node_name)) - class AuxiliaryExperimentCheck(TransitionCriterion): """A class to transition from one GenerationNode to another by checking if certain @@ -563,14 +670,6 @@ class AuxiliaryExperimentCheck(TransitionCriterion): purpose we expect to not have. This can be helpful when need to transition out of a node based on AuxiliaryExperimentPurpose. Criterion is met when all inclusion and exclusion checks pass. - block_gen_if_met: A flag to prevent continued generation from the - associated GenerationNode if this criterion is met but other criteria - remain unmet. Defaults to False since auxiliary experiment checks are - typically used for transition logic rather than blocking generation. - block_transition_if_unmet: A flag to prevent the node from completing and - being able to transition to another node. Ex: MaxGenerationParallelism - defaults to setting this to False since we can complete and move on from - this node without ever reaching its threshold. continue_trial_generation: A flag to indicate that all generation for a given trial is not completed, and thus even after transition, the next node will continue to generate arms for the same trial. 
Example usage: in @@ -587,14 +686,10 @@ def __init__( auxiliary_experiment_purposes_to_exclude: ( list[AuxiliaryExperimentPurpose] | None ) = None, - block_transition_if_unmet: bool | None = True, - block_gen_if_met: bool | None = False, continue_trial_generation: bool | None = False, ) -> None: super().__init__( transition_to=transition_to, - block_transition_if_unmet=block_transition_if_unmet, - block_gen_if_met=block_gen_if_met, continue_trial_generation=continue_trial_generation, ) @@ -651,11 +746,3 @@ def is_met( expected_aux_exp_purposes=self.auxiliary_experiment_purposes_to_exclude, ) return inclusion_check and exclusion_check - - def block_continued_generation_error( - self, - node_name: str, - experiment: Experiment, - trials_from_node: set[int], - ) -> None: - pass diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 67f982ca1a1..d6e9d9f6fc1 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -159,12 +159,15 @@ class TestAxOrchestrator(TestCase): "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " "transition_criteria=[MinTrials(transition_to='GenerationStep_1_BoTorch'), " - "MinTrials(transition_to='GenerationStep_1_BoTorch')]), " + "MinTrials(transition_to='GenerationStep_1_BoTorch')], " + "generation_blocking_criteria=" + "[MaxTrialsAwaitingData(threshold=5)]), " "GenerationNode(name='GenerationStep_1_BoTorch', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MaxGenerationParallelism(" - "transition_to='GenerationStep_1_BoTorch')])]), " + "transition_criteria=None, " + "generation_blocking_criteria=" + "[MaxGenerationParallelism(threshold=3)])]), " "options=OrchestratorOptions(max_pending_trials=10, " "trial_type=, batch_size=None, " "total_trials=0, tolerated_trial_failure_rate=0.2, " @@ -1168,11 +1171,11 @@ def 
test_run_trials_and_yield_results_with_early_stopper(self) -> None: expected_num_polls = 2 self.assertEqual(len(res_list), expected_num_polls + 1) # Both trials in first batch of parallelism will be early stopped - # Extract max_parallelism from transition criteria + # Extract max_parallelism from generation_blocking_criteria node0_max_parallelism = None - for tc in self.two_sobol_steps_GS._nodes[0].transition_criteria: - if isinstance(tc, MaxGenerationParallelism): - node0_max_parallelism = tc.threshold + for bc in self.two_sobol_steps_GS._nodes[0].generation_blocking_criteria: + if isinstance(bc, MaxGenerationParallelism): + node0_max_parallelism = bc.threshold break self.assertEqual( len(res_list[0]["trials_early_stopped_so_far"]), @@ -2852,12 +2855,15 @@ class TestAxOrchestratorMultiTypeExperiment(TestAxOrchestrator): "generator_specs=[GeneratorSpec(generator_enum=Sobol, " "generator_key_override=None)], " "transition_criteria=[MinTrials(transition_to='GenerationStep_1_BoTorch'), " - "MinTrials(transition_to='GenerationStep_1_BoTorch')]), " + "MinTrials(transition_to='GenerationStep_1_BoTorch')], " + "generation_blocking_criteria=" + "[MaxTrialsAwaitingData(threshold=5)]), " "GenerationNode(name='GenerationStep_1_BoTorch', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=" - "[MaxGenerationParallelism(transition_to='GenerationStep_1_BoTorch')])]), " + "transition_criteria=None, " + "generation_blocking_criteria=" + "[MaxGenerationParallelism(threshold=3)])]), " "options=OrchestratorOptions(max_pending_trials=10, " "trial_type=, batch_size=None, " "total_trials=0, tolerated_trial_failure_rate=0.2, " diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 47a8f54ad36..7023692d495 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -865,11 +865,11 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: """ parallelism_settings = [] for node in 
self.generation_strategy._nodes: - # Extract max_parallelism from MaxGenerationParallelism criterion + # Check generation_blocking_criteria for max parallelism max_parallelism = None - for tc in node.transition_criteria: - if isinstance(tc, MaxGenerationParallelism): - max_parallelism = tc.threshold + for bc in node.generation_blocking_criteria: + if isinstance(bc, MaxGenerationParallelism): + max_parallelism = bc.threshold break # Try to get num_trials from the node. If there's no MinTrials # criterion (unlimited trials), num_trials will raise UserInputError. diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index afe81f05d01..b9334d53a63 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -59,7 +59,7 @@ from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( MaxGenerationParallelism, - MinTrials, + MaxTrialsAwaitingData, ) from ax.metrics.branin import branin, BraninMetric from ax.runners.synthetic import SyntheticRunner @@ -1606,22 +1606,25 @@ def test_keep_generating_without_data(self) -> None: {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], ) - # Check that enforce_num_trials is False by checking the MinTrials criterion - # has block_gen_if_met=False - node0_min_trials = [ - tc - for tc in ax_client.generation_strategy._nodes[0].transition_criteria - if isinstance(tc, MinTrials) + # Check that enforce_num_trials is False by verifying no + # MaxTrialsAwaitingData exists in generation_blocking_criteria + node0_blocking_criteria = [ + bc + for bc in ax_client.generation_strategy._nodes[ + 0 + ].generation_blocking_criteria + if isinstance(bc, MaxTrialsAwaitingData) ] - self.assertTrue(len(node0_min_trials) > 0) - self.assertFalse(node0_min_trials[0].block_gen_if_met) + self.assertEqual(len(node0_blocking_criteria), 0) # Check that max_parallelism is None by verifying no MaxGenerationParallelism - # criterion exists 
on node 1 + # criterion exists in generation_blocking_criteria node1_max_parallelism = [ - tc - for tc in ax_client.generation_strategy._nodes[1].transition_criteria - if isinstance(tc, MaxGenerationParallelism) + bc + for bc in ax_client.generation_strategy._nodes[ + 1 + ].generation_blocking_criteria + if isinstance(bc, MaxGenerationParallelism) ] self.assertEqual(len(node1_max_parallelism), 0) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0beccd47c09..c08beca8742 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -16,7 +16,7 @@ from inspect import isclass from io import StringIO from logging import Logger -from typing import Any +from typing import Any, TypeVar import numpy as np import pandas as pd @@ -43,7 +43,11 @@ GenerationStrategy, ) from ax.generation_strategy.generator_spec import GeneratorSpec -from ax.generation_strategy.transition_criterion import MinTrials, TransitionCriterion +from ax.generation_strategy.transition_criterion import ( + GenerationBlockingCriterion, + MinTrials, + TransitionCriterion, +) from ax.generators.torch.botorch_modular.generator import BoTorchGenerator from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.generators.torch.botorch_modular.utils import ModelConfig @@ -70,7 +74,7 @@ from botorch.utils.types import DEFAULT from pyre_extensions import assert_is_instance, none_throws - +T = TypeVar("T") logger: Logger = get_logger(__name__) @@ -297,11 +301,7 @@ def object_from_json( object_json["outcome_transform_options"] = ( outcome_transform_options_json ) - elif ( - isclass(_class) - and issubclass(_class, TransitionCriterion) - and _class is not TransitionCriterion # TransitionCriterion is abstract - ): + elif isclass(_class) and issubclass(_class, TransitionCriterion): # TransitionCriterion may contain nested Ax objects (TrialStatus, etc.) # that need recursive deserialization via object_from_json. 
return transition_criterion_from_json( @@ -309,6 +309,13 @@ def object_from_json( object_json=object_json, **vars(registry_kwargs), ) + elif isclass(_class) and issubclass(_class, GenerationBlockingCriterion): + # GenerationBlockingCriterion is similar to TransitionCriterion + return generation_blocking_criterion_from_json( + blocking_criterion_class=_class, + object_json=object_json, + **vars(registry_kwargs), + ) elif isclass(_class) and issubclass(_class, SerializationMixin): # Special handling for Data backward compatibility if _class is Data: @@ -447,19 +454,37 @@ def generator_run_from_json( return generator_run +def _criterion_from_json( + criterion_class: type[T], + object_json: dict[str, Any], + decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, + class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, +) -> T: + """Generic helper to load criterion objects from JSON. + + Handles recursive deserialization of nested Ax objects and filters + to valid constructor arguments for backwards compatibility. + """ + decoded = { + key: object_from_json( + object_json=value, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + for key, value in object_json.items() + } + init_args = extract_init_args(args=decoded, class_=criterion_class) + # pyre-ignore[45]: Class passed is always a concrete subclass. + return criterion_class(**init_args) + + def transition_criterion_from_json( transition_criterion_class: type[TransitionCriterion], object_json: dict[str, Any], decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, ) -> TransitionCriterion: - """Load TransitionCriterion from JSON. - - TransitionCriterion subclasses may contain nested Ax objects (like TrialStatus - enums and AuxiliaryExperimentPurpose) that need recursive deserialization via - object_from_json. 
We also use extract_init_args for backwards compatibility, - filtering to only valid constructor arguments. - """ + """Load TransitionCriterion from JSON.""" # Handle deprecated MinimumTrialsInStatus -> MinTrials conversion if transition_criterion_class is MinTrials and "status" in object_json: logger.warning( @@ -478,20 +503,27 @@ def transition_criterion_from_json( use_all_trials_in_exp=True, ) - decoded = { - key: object_from_json( - object_json=value, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for key, value in object_json.items() - } + return _criterion_from_json( + criterion_class=transition_criterion_class, + object_json=object_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) - # filter to only valid constructor args (backwards compatibility) - init_args = extract_init_args(args=decoded, class_=transition_criterion_class) - # pyre-ignore[45]: Class passed is always a concrete subclass. - return transition_criterion_class(**init_args) +def generation_blocking_criterion_from_json( + blocking_criterion_class: type[GenerationBlockingCriterion], + object_json: dict[str, Any], + decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, + class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, +) -> GenerationBlockingCriterion: + """Load GenerationBlockingCriterion from JSON.""" + return _criterion_from_json( + criterion_class=blocking_criterion_class, + object_json=object_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) def search_space_from_json( @@ -848,13 +880,47 @@ def generation_node_from_json( # if needed during _validate_and_set_step_sequence. generation_node_json.pop("step_index", None) - # Backwards compatibility: For transition criteria with transition_to=None - # set transition_to to point to itself. 
- transition_criteria_json = generation_node_json.pop("transition_criteria") + transition_criteria_json = generation_node_json.pop("transition_criteria", None) + generation_blocking_criteria_json = generation_node_json.pop( + "generation_blocking_criteria", None + ) + + # Backwards compatibility: For old experiments, TransitionCriterion with + # block_gen_if_met=True need to be migrated to GenerationBlockingCriterion. + # Also, transition_to=None needs to be handled for transition criteria, these + # now point to self. if transition_criteria_json is not None: + migrated_blocking_criteria = [] + migrated_transition_criteria_json = [] for tc_json in transition_criteria_json: + tc_type = tc_json.get("__type", "") + + if tc_json.get("block_gen_if_met") is True: + if tc_type == "MaxGenerationParallelism": + migrated_blocking_criteria.append(tc_json) + elif tc_type == "MinTrials": + # Copy and update type + blocking_json = tc_json.copy() + blocking_json["__type"] = "MaxTrialsAwaitingData" + migrated_blocking_criteria.append(blocking_json) + + # Only keep in transition_criteria if block_transition_if_unmet=True + if tc_json.get("block_transition_if_unmet") is not True: + continue # Skip adding to transition_criteria + + # handle transition_to=None if tc_json.get("transition_to") is None: tc_json["transition_to"] = name + migrated_transition_criteria_json.append(tc_json) + + # Merge migrated criteria with any existing generation_blocking_criteria + if generation_blocking_criteria_json is None: + generation_blocking_criteria_json = migrated_blocking_criteria + else: + generation_blocking_criteria_json = ( + migrated_blocking_criteria + generation_blocking_criteria_json + ) + transition_criteria_json = migrated_transition_criteria_json return GenerationNode( name=name, @@ -874,6 +940,11 @@ def generation_node_from_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ), + generation_blocking_criteria=object_from_json( + 
object_json=generation_blocking_criteria_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ), input_constructors=decoded_input_constructors, previous_node_name=( generation_node_json.pop("previous_node_name") diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index bfd6157f129..ca7267cb02c 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -50,7 +50,10 @@ GenerationStrategy, ) from ax.generation_strategy.generator_spec import GeneratorSpec -from ax.generation_strategy.transition_criterion import TransitionCriterion +from ax.generation_strategy.transition_criterion import ( + GenerationBlockingCriterion, + TransitionCriterion, +) from ax.generators.torch.botorch_modular.generator import BoTorchGenerator from ax.generators.torch.botorch_modular.surrogate import Surrogate from ax.generators.winsorization_config import WinsorizationConfig @@ -408,6 +411,7 @@ def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]: "best_model_selector": generation_node.best_model_selector, "should_deduplicate": generation_node.should_deduplicate, "transition_criteria": generation_node.transition_criteria, + "generation_blocking_criteria": generation_node.generation_blocking_criteria, "generator_spec_to_gen_from": generation_node._generator_spec_to_gen_from, "previous_node_name": generation_node._previous_node_name, "trial_type": generation_node._trial_type, @@ -445,6 +449,15 @@ def transition_criterion_to_dict(criterion: TransitionCriterion) -> dict[str, An return properties +def generation_blocking_criterion_to_dict( + criterion: GenerationBlockingCriterion, +) -> dict[str, Any]: + """Convert Ax GenerationBlockingCriterion to a dictionary.""" + properties = serialize_init_args(obj=criterion) + properties["__type"] = criterion.__class__.__name__ + return properties + + def generator_spec_to_dict(generator_spec: GeneratorSpec) -> dict[str, Any]: 
"""Convert Ax model spec to a dictionary.""" return { diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 842a90968be..561a4d09c49 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -77,6 +77,7 @@ AuxiliaryExperimentCheck, IsSingleObjective, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, TransitionCriterion, ) @@ -125,6 +126,7 @@ derived_parameter_to_dict, experiment_to_dict, fixed_parameter_to_dict, + generation_blocking_criterion_to_dict, generation_node_to_dict, generation_strategy_to_dict, generator_run_to_dict, @@ -214,7 +216,8 @@ L2NormMetric: metric_to_dict, LogNormalPrior: botorch_component_to_dict, MapMetric: metric_to_dict, - MaxGenerationParallelism: transition_criterion_to_dict, + MaxGenerationParallelism: generation_blocking_criterion_to_dict, + MaxTrialsAwaitingData: generation_blocking_criterion_to_dict, Metric: metric_to_dict, MinTrials: transition_criterion_to_dict, AuxiliaryExperimentCheck: transition_criterion_to_dict, @@ -341,6 +344,7 @@ "MapMetric": MapMetric, "MaxTrials": MinTrials, "MaxGenerationParallelism": MaxGenerationParallelism, + "MaxTrialsAwaitingData": MaxTrialsAwaitingData, "Metric": Metric, "MinTrials": MinTrials, # DEPRECATED; backward compatibility for MinimumTrialsInStatus -> MinTrials diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index c36f0fb5b5e..ebf908e9fdf 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -1236,6 +1236,130 @@ def test_generation_node_backwards_compatibility(self) -> None: ) self.assertEqual(node.generator_specs[0].cv_kwargs, {"test_cv_kwarg": True}) + def test_block_gen_if_met_migration(self) -> None: + """Test that TransitionCriteria with block_gen_if_met=True are migrated + to GenerationBlockingCriterion during deserialization.""" + with 
self.subTest("MaxGenerationParallelism_with_block_gen_if_met"): + # MaxGenerationParallelism with block_gen_if_met=True should be + # migrated to generation_blocking_criteria and removed from + # transition_criteria + json = { + "node_name": "test_node", + "model_specs": [ + { + "__type": "GeneratorSpec", + "model_enum": {"__type": "Generators", "name": "SOBOL"}, + "model_kwargs": {}, + "model_gen_kwargs": {}, + "model_cv_kwargs": {}, + } + ], + "best_model_selector": None, + "should_deduplicate": False, + "transition_criteria": [ + { + "__type": "MaxGenerationParallelism", + "threshold": 3, + "only_in_statuses": [ + {"__type": "TrialStatus", "name": "RUNNING"} + ], + "block_gen_if_met": True, + "transition_to": None, + } + ], + } + node = generation_node_from_json(json) + self.assertEqual(len(node.transition_criteria), 0) + self.assertEqual(len(node.generation_blocking_criteria), 1) + blocking = node.generation_blocking_criteria[0] + self.assertEqual(blocking.__class__.__name__, "MaxGenerationParallelism") + # pyre-ignore[16]: Attribute exists on MaxGenerationParallelism + self.assertEqual(blocking.threshold, 3) + + with self.subTest("MinTrials_with_block_gen_if_met_only"): + # MinTrials with block_gen_if_met=True only should be migrated to + # MaxTrialsAwaitingData and removed from transition_criteria + json = { + "node_name": "test_node", + "model_specs": [ + { + "__type": "GeneratorSpec", + "model_enum": {"__type": "Generators", "name": "SOBOL"}, + "model_kwargs": {}, + "model_gen_kwargs": {}, + "model_cv_kwargs": {}, + } + ], + "best_model_selector": None, + "should_deduplicate": False, + "transition_criteria": [ + { + "__type": "MinTrials", + "threshold": 5, + "only_in_statuses": [ + {"__type": "TrialStatus", "name": "RUNNING"} + ], + "not_in_statuses": None, + "use_all_trials_in_exp": False, + "block_gen_if_met": True, + "block_transition_if_unmet": False, + "transition_to": None, + } + ], + } + node = generation_node_from_json(json) + 
self.assertEqual(len(node.transition_criteria), 0) + self.assertEqual(len(node.generation_blocking_criteria), 1) + blocking = node.generation_blocking_criteria[0] + self.assertEqual(blocking.__class__.__name__, "MaxTrialsAwaitingData") + self.assertEqual(blocking.threshold, 5) + + with self.subTest("MinTrials_with_block_gen_if_met_and_block_transition"): + # MinTrials with both block_gen_if_met=True and + # block_transition_if_unmet=True should create + # MaxTrialsAwaitingData AND keep in transition_criteria + json = { + "node_name": "test_node", + "model_specs": [ + { + "__type": "GeneratorSpec", + "model_enum": {"__type": "Generators", "name": "SOBOL"}, + "model_kwargs": {}, + "model_gen_kwargs": {}, + "model_cv_kwargs": {}, + } + ], + "best_model_selector": None, + "should_deduplicate": False, + "transition_criteria": [ + { + "__type": "MinTrials", + "threshold": 5, + "only_in_statuses": None, + "not_in_statuses": None, + "use_all_trials_in_exp": True, + "block_gen_if_met": True, + "block_transition_if_unmet": True, + "transition_to": "next_node", + } + ], + } + node = generation_node_from_json(json) + # Should have both + self.assertEqual(len(node.transition_criteria), 1) + self.assertEqual(len(node.generation_blocking_criteria), 1) + tc = node.transition_criteria[0] + self.assertEqual(tc.__class__.__name__, "MinTrials") + # pyre-ignore[16]: Attribute exists on MinTrials + self.assertEqual(tc.threshold, 5) + self.assertEqual(tc.transition_to, "next_node") + blocking = node.generation_blocking_criteria[0] + self.assertEqual(blocking.__class__.__name__, "MaxTrialsAwaitingData") + # pyre-ignore[16]: threshold exists on MaxTrialsAwaitingData + self.assertEqual(blocking.threshold, 5) + # pyre-ignore[16]: use_all_trials_in_exp exists on MaxTrialsAwaitingData + self.assertTrue(blocking.use_all_trials_in_exp) + def test_SobolQMCNormalSampler(self) -> None: # This fails default equality checks, so testing it separately. 
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2])) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index d0a9df77ce4..4e277da9d2a 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -55,13 +55,17 @@ ) from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy -from ax.generation_strategy.transition_criterion import MaxGenerationParallelism +from ax.generation_strategy.transition_criterion import ( + MaxGenerationParallelism, + MaxTrialsAwaitingData, + MinTrials, +) from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.metrics.branin import BraninMetric from ax.runners.synthetic import SyntheticRunner from ax.storage.json_store.decoder import ( + generation_blocking_criterion_from_json, generation_node_from_json, - transition_criterion_from_json, ) from ax.storage.json_store.registry import ( CORE_CLASS_DECODER_REGISTRY, @@ -3397,7 +3401,10 @@ def test_load_candidate_source_auxiliary_experiments(self) -> None: def test_transition_criterion_deserialize_with_extra_fields(self) -> None: """Test that deserialization gracefully handles extra/unknown fields - ie this validates that backwards compatibility is maintained""" + ie this validates that backwards compatibility is maintained for + MaxGenerationParallelism, which is now a GenerationBlockingCriterion. + Old serialized experiments may have transition_to, block_gen_if_met, etc. 
+ fields that are now deprecated and should be ignored.""" # Simulate old serialized format with extra fields that no longer exist old_format_json = { "threshold": 5, @@ -3411,10 +3418,12 @@ def test_transition_criterion_deserialize_with_extra_fields(self) -> None: "some_deprecated_field": "should_be_ignored", } - # Should not raise, extra field should be ignored + # Should not raise, extra fields should be ignored. + # Note: MaxGenerationParallelism is now a GenerationBlockingCriterion, + # so we use generation_blocking_criterion_from_json. criterion = assert_is_instance( - transition_criterion_from_json( - transition_criterion_class=MaxGenerationParallelism, + generation_blocking_criterion_from_json( + blocking_criterion_class=MaxGenerationParallelism, object_json=old_format_json, decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, @@ -3422,13 +3431,13 @@ def test_transition_criterion_deserialize_with_extra_fields(self) -> None: MaxGenerationParallelism, ) self.assertEqual(criterion.threshold, 5) - self.assertEqual(criterion.transition_to, "test_node") def test_gen_node_deserialize_with_tc_transition_to_none( self, ) -> None: - """Test backwards compatibility when loading a MaxGenerationParallelism - that was stored with transition_to=None + """Test backwards compatibility when loading an old GenerationNode that + has MaxGenerationParallelism stored in transition_criteria. The decoder + should automatically migrate it to generation_blocking_criteria. 
""" old_format_node_json = { "__type": "GenerationNode", @@ -3446,7 +3455,8 @@ def test_gen_node_deserialize_with_tc_transition_to_none( "__type": "MaxGenerationParallelism", "threshold": 3, "only_in_statuses": [{"__type": "TrialStatus", "name": "RUNNING"}], - "transition_to": None, # Old default + "block_gen_if_met": True, + "transition_to": None, # Old default - should be ignored } ], } @@ -3457,11 +3467,65 @@ def test_gen_node_deserialize_with_tc_transition_to_none( class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) self.assertEqual(node.name, "test_node") - self.assertEqual(len(node.transition_criteria), 1) + # MaxGenerationParallelism should have been migrated from + # transition_criteria to generation_blocking_criteria + self.assertEqual(len(node.transition_criteria), 0) + self.assertEqual(len(node.generation_blocking_criteria), 1) criterion = assert_is_instance( - node.transition_criteria[0], + node.generation_blocking_criteria[0], MaxGenerationParallelism, ) self.assertEqual(criterion.threshold, 3) - # transition_to should now be set to the node name (pointing to itself) - self.assertEqual(criterion.transition_to, "test_node") + + def test_block_gen_if_met_mintrials_migration(self) -> None: + """Test backwards compatibility when loading an old GenerationNode that + has MinTrials with block_gen_if_met=True. The decoder should + automatically migrate it to MaxTrialsAwaitingData in + generation_blocking_criteria. 
+ """ + # MinTrials with both block_gen_if_met=True and + # block_transition_if_unmet=True should create MaxTrialsAwaitingData + # AND keep in transition_criteria + old_format_node_json = { + "__type": "GenerationNode", + "name": "test_node", + "generator_specs": [ + { + "__type": "GeneratorSpec", + "generator_enum": {"__type": "Generators", "name": "SOBOL"}, + "generator_kwargs": {}, + "generator_gen_kwargs": {}, + } + ], + "transition_criteria": [ + { + "__type": "MinTrials", + "threshold": 5, + "only_in_statuses": None, + "not_in_statuses": None, + "use_all_trials_in_exp": True, + "block_gen_if_met": True, + "block_transition_if_unmet": True, + "transition_to": "next_node", + } + ], + } + + node = generation_node_from_json( + generation_node_json=old_format_node_json, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + self.assertEqual(node.name, "test_node") + # Should have both, one to represent the blocking criterion and + # one to represent the transition criterion + self.assertEqual(len(node.transition_criteria), 1) + self.assertEqual(len(node.generation_blocking_criteria), 1) + tc = assert_is_instance(node.transition_criteria[0], MinTrials) + self.assertEqual(tc.threshold, 5) + self.assertEqual(tc.transition_to, "next_node") + blocking = assert_is_instance( + node.generation_blocking_criteria[0], MaxTrialsAwaitingData + ) + self.assertEqual(blocking.threshold, 5) + self.assertTrue(blocking.use_all_trials_in_exp) diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index c28bfad5399..9535e1b5747 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -88,7 +88,9 @@ from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, + GenerationBlockingCriterion, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, TransitionCriterion, ) @@ -179,6 +181,7 @@ def 
get_experiment_with_map_data_type() -> Experiment: def get_trial_based_criterion() -> list[TransitionCriterion]: + """Returns a list of trial-based TransitionCriteria for testing.""" return [ MinTrials( threshold=3, @@ -186,16 +189,21 @@ def get_trial_based_criterion() -> list[TransitionCriterion]: only_in_statuses=[TrialStatus.RUNNING, TrialStatus.COMPLETED], not_in_statuses=None, ), + AutoTransitionAfterGen( + transition_to="next_node", + ), + ] + + +def get_generation_blocking_criterion() -> list[GenerationBlockingCriterion]: + """Returns a list of GenerationBlockingCriteria for testing.""" + return [ MaxGenerationParallelism( threshold=5, only_in_statuses=None, not_in_statuses=[ TrialStatus.RUNNING, ], - transition_to="Sobol", - ), - AutoTransitionAfterGen( - transition_to="next_node", ), ] @@ -2960,14 +2968,12 @@ def get_online_sobol_mbm_generation_strategy() -> GenerationStrategy: MinTrials( threshold=1, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), MinTrials( threshold=1, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=[ TrialStatus.RUNNING, TrialStatus.COMPLETED, @@ -2985,9 +2991,25 @@ def get_online_sobol_mbm_generation_strategy() -> GenerationStrategy: generator_kwargs=step_generator_kwargs, generator_gen_kwargs={}, ) + sobol_blocking_criteria = [ + MaxTrialsAwaitingData( + threshold=1, + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ), + MaxTrialsAwaitingData( + threshold=1, + only_in_statuses=[ + TrialStatus.RUNNING, + TrialStatus.COMPLETED, + TrialStatus.EARLY_STOPPED, + ], + ), + ] sobol_node = GenerationNode( name="sobol_node", transition_criteria=sobol_criterion, + generation_blocking_criteria=sobol_blocking_criteria, generator_specs=[sobol_generator_spec], input_constructors={InputConstructorPurpose.N: NodeInputConstructors.ALL_N}, ) diff --git a/ax/utils/testing/modeling_stubs.py 
b/ax/utils/testing/modeling_stubs.py index 6e11278a8b6..75f7a73407b 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -41,6 +41,7 @@ AutoTransitionAfterGen, IsSingleObjective, MaxGenerationParallelism, + MaxTrialsAwaitingData, MinTrials, ) from ax.generators.torch.botorch_modular.surrogate import ( @@ -184,7 +185,6 @@ def sobol_gpei_generation_node_gs( MinTrials( threshold=5, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ) @@ -195,7 +195,6 @@ def sobol_gpei_generation_node_gs( MinTrials( threshold=2, transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=None, not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], ), @@ -204,17 +203,20 @@ def sobol_gpei_generation_node_gs( MinTrials( threshold=0, transition_to="MBM_node", - block_gen_if_met=False, only_in_statuses=[TrialStatus.CANDIDATE], not_in_statuses=None, ), + ] + sobol_blocking_criteria = [ + MaxTrialsAwaitingData( + threshold=5, + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ), MaxGenerationParallelism( threshold=1000, - transition_to="MBM_node", - block_gen_if_met=True, only_in_statuses=[TrialStatus.RUNNING], not_in_statuses=None, - continue_trial_generation=False, ), ] auto_mbm_criterion = [AutoTransitionAfterGen(transition_to="MBM_node")] @@ -235,6 +237,7 @@ def sobol_gpei_generation_node_gs( sobol_node = GenerationNode( name="sobol_node", transition_criteria=sobol_criterion, + generation_blocking_criteria=sobol_blocking_criteria, generator_specs=[sobol_generator_spec], ) if with_model_selection: diff --git a/tutorials/external_generation_node/external_generation_node.ipynb b/tutorials/external_generation_node/external_generation_node.ipynb index 928a551a921..8e5933086e8 100644 --- a/tutorials/external_generation_node/external_generation_node.ipynb +++ 
b/tutorials/external_generation_node/external_generation_node.ipynb @@ -47,7 +47,7 @@ "from ax.generation_strategy.generation_node import GenerationNode\n", "from ax.generation_strategy.generation_strategy import GenerationStrategy\n", "from ax.generation_strategy.generator_spec import GeneratorSpec\n", - "from ax.generation_strategy.transition_criterion import MinTrials\n", + "from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials\n", "\n", "from sklearn.ensemble import RandomForestRegressor" ] @@ -231,10 +231,16 @@ " # This specifies the maximum number of trials to generate from this node,\n", " # and the next node in the strategy.\n", " threshold=25,\n", - " block_transition_if_unmet=True,\n", " transition_to=\"RandomForest\",\n", " )\n", " ],\n", + " generation_blocking_criteria=[\n", + " MaxTrialsAwaitingData(\n", + " # Block generation once we have 25 trials with data.\n", + " threshold=25,\n", + " not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],\n", + " )\n", + " ],\n", " ),\n", " RandomForestGenerationNode(num_samples=128, regressor_options={}),\n", " ],\n", @@ -388,6 +394,8 @@ ], "metadata": { "fileHeader": "", + "fileUid": "0712ebfb-bf0d-43ad-9aec-e0447e101ece", + "isAdHoc": false, "kernelspec": { "display_name": "python3", "language": "python",