Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions ax/api/utils/generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
GenerationStrategy,
)
from ax.generation_strategy.generator_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import MinTrials
from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials
from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
Expand Down Expand Up @@ -70,21 +70,28 @@ def _get_sobol_node(
MinTrials( # This represents the initialization budget.
threshold=initialization_budget,
transition_to="MBM",
block_gen_if_met=(not allow_exceeding_initialization_budget),
block_transition_if_unmet=True,
use_all_trials_in_exp=use_existing_trials_for_initialization,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
),
MinTrials( # This represents minimum observed trials requirement.
threshold=min_observed_initialization_trials,
transition_to="MBM",
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=True,
only_in_statuses=[TrialStatus.COMPLETED],
count_only_trials_with_data=True,
),
]
# If we want to enforce the initialization budget, add a generation blocking
# criterion that prevents exceeding the budget.
generation_blocking_criteria = None
if not allow_exceeding_initialization_budget:
generation_blocking_criteria = [
MaxTrialsAwaitingData(
threshold=initialization_budget,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
use_all_trials_in_exp=use_existing_trials_for_initialization,
)
]
return GenerationNode(
name="Sobol",
generator_specs=[
Expand All @@ -94,6 +101,7 @@ def _get_sobol_node(
)
],
transition_criteria=transition_criteria,
generation_blocking_criteria=generation_blocking_criteria,
should_deduplicate=True,
)

Expand Down
19 changes: 13 additions & 6 deletions ax/api/utils/tests/test_generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ax.exceptions.core import UserInputError
from ax.generation_strategy.center_generation_node import CenterGenerationNode
from ax.generation_strategy.dispatch_utils import get_derelativize_config
from ax.generation_strategy.transition_criterion import MinTrials
from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials
from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -99,16 +99,12 @@ def test_choose_gs_fast_with_options(self) -> None:
MinTrials(
threshold=2,
transition_to="MBM",
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=False,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
),
MinTrials(
threshold=4,
transition_to="MBM",
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=True,
only_in_statuses=[TrialStatus.COMPLETED],
count_only_trials_with_data=True,
Expand Down Expand Up @@ -390,7 +386,18 @@ def test_abandoned_and_failed_trials_excluded_from_initialization_budget(
first_tc.not_in_statuses, [TrialStatus.FAILED, TrialStatus.ABANDONED]
)
self.assertEqual(first_tc.threshold, 5)
self.assertTrue(first_tc.block_gen_if_met)
# Verify MaxTrialsAwaitingData is in generation_blocking_criteria
blocking_criteria = [
bc
for bc in sobol_node._generation_blocking_criteria
if isinstance(bc, MaxTrialsAwaitingData)
]
self.assertEqual(len(blocking_criteria), 1)
self.assertEqual(blocking_criteria[0].threshold, 5)
self.assertEqual(
blocking_criteria[0].not_in_statuses,
[TrialStatus.FAILED, TrialStatus.ABANDONED],
)

# Test the actual behavior: Generate 5 trials, mark 3 as ABANDONED,
# verify that Sobol can still generate more trials
Expand Down
Loading