From 70c4f6a59224b978d7b96d32b2d1e182446403e4 Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Fri, 13 Feb 2026 15:08:52 -0800 Subject: [PATCH] Add in-sample candidate generation support for MBG Summary: Enable in-sample candidate generation in the Modular BoTorch Generator (MBG), allowing optimization to select from existing training data rather than optimizing over the full search space. This is achieved by adding 2 optional kwargs in `Acquisition.optimize`: `candidate_set` and `sampling_strategy`. This kills 4 birds with 1 stone - the same mechanism may support: - Model-based/Contextual bandits - Bake-off/Best arm selection - In-sample preference learning (in-sample PBO and BOPE preference game) - LILO (for LLM to label observed points) Supports multiple selection methods including GP Thompson Sampling, Top-Two Thompson Sampling (TTTS), greedy acquisition (qSimpleRegret), q-batched acquisition (qLogNEI), random selection, and Boltzmann sampling via `model_gen_options`. This is partially inspired by the [Support for SamplingStrategy](https://docs.google.com/document/d/19mLXg88bjA_NzYCz59KmY7Zq10XnqSD4HWkFkNbhYR4/edit?usp=sharing) design doc but with a wider range of applications in mind. 
Reviewed By: Balandat Differential Revision: D92124823 --- .../torch/botorch_modular/acquisition.py | 144 ++++++++++++++++++ .../torch/botorch_modular/generator.py | 111 ++++++++++++++ ax/generators/torch/botorch_modular/utils.py | 4 + ax/generators/torch/tests/test_acquisition.py | 95 ++++++++++++ ax/generators/torch/tests/test_generator.py | 119 +++++++++++++++ 5 files changed, 473 insertions(+) diff --git a/ax/generators/torch/botorch_modular/acquisition.py b/ax/generators/torch/botorch_modular/acquisition.py index 4a8fb74849b..ad5789554f7 100644 --- a/ax/generators/torch/botorch_modular/acquisition.py +++ b/ax/generators/torch/botorch_modular/acquisition.py @@ -50,6 +50,7 @@ from botorch.acquisition.multioutput_acquisition import MultiOutputAcquisitionFunction from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform from botorch.exceptions.errors import BotorchError, InputDataError +from botorch.generation.sampling import SamplingStrategy from botorch.models.model import Model from botorch.optim.optimize import ( optimize_acqf, @@ -465,6 +466,123 @@ def objective_weights(self) -> Tensor | None: """The objective weights for all outcomes.""" return self._full_objective_weights + def select_from_candidate_set( + self, + n: int, + candidate_set: Tensor, + sampling_strategy: SamplingStrategy | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + """Select n candidates from a discrete set with optional weight allocation. + + This method selects candidates from ``candidate_set`` using either a + ``SamplingStrategy`` (e.g., Thompson Sampling with win-counting for weight + allocation) or greedy acquisition function optimization. + + ``candidate_set`` is the stable interface for any candidate generation + method. Any method that produces candidates (in-sample training data, + pathwise TS optimization, user-provided sets, etc.) feeds into this + parameter. The selection/weight-allocation logic is agnostic to how + candidates were generated. 
+ + Args: + n: The number of candidates to select. + candidate_set: A ``(num_choices, d)`` tensor of discrete candidate + points to select from. + sampling_strategy: An optional BoTorch ``SamplingStrategy`` instance + (e.g., ``MaxPosteriorSampling`` for Thompson Sampling, or + ``BoltzmannSampling`` for acquisition-weighted sampling). When + provided, candidates are selected by sampling from ``candidate_set`` + according to the strategy. When ``num_samples > n``, win-counting + mode is used: many posterior samples are drawn, wins are counted + per candidate, and the top-n candidates are returned with weights + proportional to their win probability (normalized to sum to 1). + If not provided, greedy acquisition function selection is used via + ``optimize_acqf_discrete``. + + Returns: + A three-element tuple containing an ``n x d``-dim tensor of selected + candidates, a tensor with the associated acquisition values, and a + tensor with the weight for each candidate (normalized to sum to 1 + for win-counting mode, or uniform for direct/greedy selection). + + Raises: + ValueError: If ``candidate_set`` is empty or has fewer points than + ``n``. + """ + if candidate_set.shape[0] == 0: + raise ValueError( + "`candidate_set` is empty. Provide a non-empty set of candidates." + ) + if candidate_set.shape[0] < n: + raise ValueError( + f"`candidate_set` has {candidate_set.shape[0]} candidates, " + f"but {n} were requested. Provide at least {n} candidates." + ) + + if sampling_strategy is not None: + # Check if this is a win-counting strategy (e.g., Thompson Sampling) + # or a direct selection strategy (e.g., Boltzmann Sampling). + # If num_samples is explicitly set and > n, use win-counting mode. + # Otherwise, use direct selection mode. 
+ num_samples_attr = getattr(sampling_strategy, "num_samples", None) + num_samples: int | None = ( + int(num_samples_attr) if num_samples_attr is not None else None + ) + + if num_samples is not None and num_samples > n: + # Win-counting mode: sample many times, count wins, return top-n + # with weights proportional to win counts (normalized to sum to 1). + sampled_candidates = sampling_strategy( + candidate_set.unsqueeze(0), num_samples=num_samples + ).squeeze(0) # (num_samples, d) + + # Count wins for each unique candidate + unique_candidates, inverse_indices = torch.unique( + sampled_candidates, dim=0, return_inverse=True + ) + counts = torch.bincount( + inverse_indices, minlength=unique_candidates.shape[0] + ) + + # Select top-n candidates by win count. + # When num_unique < n (fewer unique winners than requested), + # we return all unique winners. The caller should handle + # candidates.shape[0] <= n, consistent with + # optimize_acqf_discrete which may also return fewer than n. + num_unique = unique_candidates.shape[0] + top_n = min(n, num_unique) + top_counts, top_indices = torch.topk(counts, top_n) + + candidates = unique_candidates[top_indices] + arm_weights = top_counts.to(dtype=self.dtype, device=self.device) + arm_weights = arm_weights / arm_weights.sum() + else: + # Direct selection mode: sample exactly n candidates with equal + # weights. Used for strategies like BoltzmannSampling where + # weighting is built into the selection process. + sampled_candidates = sampling_strategy( + candidate_set.unsqueeze(0), num_samples=n + ).squeeze(0) # (n, d) + candidates = sampled_candidates + arm_weights = torch.ones(n, dtype=self.dtype, device=self.device) + + acqf_values = self.evaluate(candidates.unsqueeze(1)).view(-1) + return candidates, acqf_values, arm_weights + + # Greedy selection from provided discrete candidate set via acqf. 
+ # optimize_acqf_discrete may return fewer than n candidates when + # there are fewer feasible choices; arm_weights matches actual count. + candidates, acqf_values = optimize_acqf_discrete( + acq_function=self.acqf, + q=n, + choices=candidate_set, + unique=True, + ) + arm_weights = torch.ones( + candidates.shape[0], dtype=self.dtype, device=self.device + ) + return candidates, acqf_values, arm_weights + def optimize( self, n: int, @@ -473,6 +591,8 @@ def optimize( fixed_features: dict[int, float] | None = None, rounding_func: Callable[[Tensor], Tensor] | None = None, optimizer_options: dict[str, Any] | None = None, + candidate_set: Tensor | None = None, + sampling_strategy: SamplingStrategy | None = None, ) -> tuple[Tensor, Tensor, Tensor]: """Generate a set of candidates via multi-start optimization. Obtains candidates and their associated acquisition function values. @@ -498,12 +618,36 @@ def optimize( that typically only exist in MBM, such as BoTorch transforms. See the docstring of `TorchOptConfig` for more information on passing down these options while constructing a generation strategy. + candidate_set: An optional tensor of shape `(num_choices, d)` containing + discrete candidate points to select from instead of optimizing over + the search space. When provided, selection is delegated to + ``select_from_candidate_set``. This enables in-sample candidate + generation when set to the training data (X_observed). + sampling_strategy: An optional BoTorch ``SamplingStrategy`` instance + (e.g., ``MaxPosteriorSampling`` for Thompson Sampling, or + ``BoltzmannSampling`` for acquisition-weighted sampling). + Passed to ``select_from_candidate_set`` when ``candidate_set`` + is provided. Requires ``candidate_set`` to be provided. Returns: A three-element tuple containing an `n x d`-dim tensor of generated candidates, a tensor with the associated acquisition values, and a tensor with the weight for each candidate. 
""" + # Dispatch to candidate set selection if candidate_set or + # sampling_strategy is provided. + if sampling_strategy is not None or candidate_set is not None: + if candidate_set is None: + raise ValueError( + "`candidate_set` is required when using `sampling_strategy`. " + "Provide the discrete set of candidates to sample from." + ) + return self.select_from_candidate_set( + n=n, + candidate_set=candidate_set, + sampling_strategy=sampling_strategy, + ) + # Options that would need to be passed in the transformed space are # disallowed, since this would be very difficult for an end user to do # directly, and someone who uses BoTorch at this level of detail would diff --git a/ax/generators/torch/botorch_modular/generator.py b/ax/generators/torch/botorch_modular/generator.py index 062f5433c8a..8df7aa5dd56 100644 --- a/ax/generators/torch/botorch_modular/generator.py +++ b/ax/generators/torch/botorch_modular/generator.py @@ -36,6 +36,8 @@ from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import get_logger from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.objective import ScalarizedPosteriorTransform +from botorch.generation.sampling import BoltzmannSampling, MaxPosteriorSampling from botorch.models.deterministic import FixedSingleSampleModel from botorch.settings import validate_input_scaling from botorch.utils.datasets import SupervisedDataset @@ -45,6 +47,103 @@ logger: Logger = get_logger(__name__) +def _build_candidate_generation_options( + model_gen_options: dict[str, Any], + torch_opt_config: TorchOptConfig, + surrogate: Surrogate, + acqf: Acquisition, +) -> tuple[Tensor | None, Any | None]: + """Build candidate_set and sampling_strategy for discrete candidate generation. + + This helper function processes model_gen_options to construct: + - candidate_set: The discrete set of candidates to select from. 
+ Any method that produces candidates (in-sample training data, or future + methods like pathwise TS optimization) feeds into this parameter. + The selection/weight-allocation logic downstream is agnostic to how + candidates were generated. + - sampling_strategy: The BoTorch SamplingStrategy to use for selection + + Args: + model_gen_options: Dictionary containing candidate generation options: + - in_sample: If True, use training data as candidate_set + - sampling_strategy_class: Class of SamplingStrategy to instantiate + - sampling_strategy_kwargs: kwargs to pass to the strategy + torch_opt_config: Configuration containing objective weights + surrogate: The surrogate model providing training data + acqf: The Acquisition object providing the acquisition function + + Returns: + A tuple of (candidate_set, sampling_strategy) where either or both may + be None if not configured in model_gen_options. + + Note: + When in_sample=True without a sampling_strategy, candidates are selected + via optimize_acqf_discrete, which uses sequential greedy selection with + X_pending for q > 1. This requires the acquisition function to support + X_pending (e.g., qSimpleRegret, qLogNEI). Analytic acquisition functions + like PosteriorMean do not support X_pending and will only work for q=1. + """ + candidate_set = None + sampling_strategy = None + + # Determine candidate set for in-sample generation + if model_gen_options.get("in_sample", False): + if surrogate.Xs: + candidate_set = surrogate.Xs[0] + else: + raise ValueError( + "in_sample=True requires training data, but no data is available." 
+ ) + + # Build sampling strategy if requested + sampling_strategy_class = model_gen_options.get("sampling_strategy_class") + if sampling_strategy_class is not None: + strategy_kwargs = dict(model_gen_options.get("sampling_strategy_kwargs", {})) + # Extract num_samples to set on strategy instance after construction + # This allows acquisition.optimize() to use getattr() to retrieve it + num_samples = strategy_kwargs.pop("num_samples", None) + + if issubclass(sampling_strategy_class, MaxPosteriorSampling): + # Thompson Sampling: sample from model posterior + # For minimization objectives (objective_weights=-1), we need to pass + # a posterior_transform that negates the values so MaxPosteriorSampling + # correctly finds the arm with lowest predicted value. + + # Get objective weights - for minimization this is -1, maximization is 1 + objective_weights = torch_opt_config.objective_weights + # Only use non-zero weights (actual objectives, not constraints) + obj_mask = objective_weights.nonzero().view(-1) + posterior_transform = ScalarizedPosteriorTransform( + weights=objective_weights[obj_mask] + ) + + sampling_strategy = sampling_strategy_class( + model=surrogate.model, + posterior_transform=posterior_transform, + **strategy_kwargs, + ) + elif issubclass(sampling_strategy_class, BoltzmannSampling): + # Boltzmann Sampling: sample weighted by acquisition values + sampling_strategy = sampling_strategy_class( + acq_func=acqf.acqf, + **strategy_kwargs, + ) + else: + # Generic SamplingStrategy - try to instantiate with provided kwargs + # User is responsible for providing appropriate kwargs + sampling_strategy = sampling_strategy_class(**strategy_kwargs) + + # Set num_samples on the strategy instance if provided. + # This is used by select_from_candidate_set() to determine + # win-counting vs. direct selection mode. The attribute is ephemeral + # (not part of nn.Module state_dict) since the strategy is created + # fresh each gen() call and never serialized. 
+ if num_samples is not None: + sampling_strategy.num_samples = num_samples + + return candidate_set, sampling_strategy + + class BoTorchGenerator(TorchGenerator, Base): """**All classes in 'botorch_modular' directory are under construction, incomplete, and should be treated as alpha @@ -281,6 +380,16 @@ def gen( acqf = none_throws(self._acquisition) botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func) + + # Handle candidate generation via model_gen_options + model_gen_options = torch_opt_config.model_gen_options or {} + candidate_set, sampling_strategy = _build_candidate_generation_options( + model_gen_options=model_gen_options, + torch_opt_config=torch_opt_config, + surrogate=self.surrogate, + acqf=acqf, + ) + candidates, expected_acquisition_value, weights = acqf.optimize( n=n, search_space_digest=search_space_digest, @@ -293,6 +402,8 @@ def gen( opt_options, dict, ), + candidate_set=candidate_set, + sampling_strategy=sampling_strategy, ) gen_metadata = self._get_gen_metadata_from_acqf( acqf=acqf, diff --git a/ax/generators/torch/botorch_modular/utils.py b/ax/generators/torch/botorch_modular/utils.py index e30703455ea..d0466e87280 100644 --- a/ax/generators/torch/botorch_modular/utils.py +++ b/ax/generators/torch/botorch_modular/utils.py @@ -510,6 +510,10 @@ def construct_acquisition_and_optimizer_options( Keys.OPTIMIZER_KWARGS.value, Keys.ACQF_KWARGS.value, Keys.AX_ACQUISITION_KWARGS.value, + # Keys for candidate generation + "in_sample", + "sampling_strategy_class", + "sampling_strategy_kwargs", } ) > 0 diff --git a/ax/generators/torch/tests/test_acquisition.py b/ax/generators/torch/tests/test_acquisition.py index 56df0e1960a..0606715df4a 100644 --- a/ax/generators/torch/tests/test_acquisition.py +++ b/ax/generators/torch/tests/test_acquisition.py @@ -674,6 +674,98 @@ def test_optimize_discrete_single_candidate(self) -> None: expected = torch.tensor([all_choices[7]], **self.tkwargs) self.assertTrue(torch.equal(candidates, expected)) + def 
test_select_from_candidate_set(self) -> None: + """Test all select_from_candidate_set paths and optimize dispatch.""" + from botorch.generation.sampling import SamplingStrategy + + acquisition = self.get_acquisition_function() + + with self.subTest("validation_too_few_candidates"): + with self.assertRaisesRegex(ValueError, "but 3 were requested"): + acquisition.select_from_candidate_set( + n=3, + candidate_set=torch.tensor([[1.0, 2.0, 3.0]], **self.tkwargs), + ) + + with self.subTest("validation_empty_candidate_set"): + with self.assertRaisesRegex(ValueError, "empty"): + acquisition.select_from_candidate_set( + n=1, + candidate_set=torch.empty(0, 3, **self.tkwargs), + ) + + with self.subTest("win_counting_normalized"): + candidate_set = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + **self.tkwargs, + ) + + class _AlternatingWinStrategy(SamplingStrategy): + """Candidate 0 wins 75% of the time, candidate 1 wins 25%.""" + + num_samples: int = 0 + + def forward(self, X: Tensor, num_samples: int = 1) -> Tensor: + n_first = int(num_samples * 0.75) + n_second = num_samples - n_first + first = X[..., 0:1, :].expand(*X.shape[:-2], n_first, X.shape[-1]) + second = X[..., 1:2, :].expand(*X.shape[:-2], n_second, X.shape[-1]) + return torch.cat([first, second], dim=-2) + + strategy = _AlternatingWinStrategy() + strategy.num_samples = 100 + + candidates, _, weights = acquisition.select_from_candidate_set( + n=2, + candidate_set=candidate_set, + sampling_strategy=strategy, + ) + self.assertEqual(candidates.shape[0], 2) + self.assertAlmostEqual(weights.sum().item(), 1.0, places=4) + self.assertAlmostEqual(weights[0].item(), 0.75, places=4) + self.assertAlmostEqual(weights[1].item(), 0.25, places=4) + + with self.subTest("direct_selection_without_num_samples"): + candidate_set = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + **self.tkwargs, + ) + + class _DirectStrategy(SamplingStrategy): + """Always returns the first n candidates.""" 
+ + def forward(self, X: Tensor, num_samples: int = 1) -> Tensor: + return X[..., :num_samples, :] + + candidates, _, weights = acquisition.select_from_candidate_set( + n=2, + candidate_set=candidate_set, + sampling_strategy=_DirectStrategy(), + ) + self.assertEqual(candidates.shape[0], 2) + self.assertTrue(torch.all(weights == 1.0)) + self.assertEqual(weights.shape, (2,)) + + with self.subTest("greedy_via_optimize_acqf_discrete"): + candidate_set = torch.rand(10, 3, **self.tkwargs) + candidates, _, weights = acquisition.select_from_candidate_set( + n=1, + candidate_set=candidate_set, + ) + self.assertEqual(candidates.shape, (1, 3)) + self.assertEqual(weights.shape, (1,)) + self.assertAlmostEqual(weights[0].item(), 1.0, places=6) + self.assertTrue((candidate_set == candidates[0]).all(dim=-1).any()) + + with self.subTest("optimize_raises_strategy_without_candidate_set"): + strategy = Mock(spec=SamplingStrategy) + with self.assertRaisesRegex(ValueError, "candidate_set.*required"): + acquisition.optimize( + n=1, + search_space_digest=self.search_space_digest, + sampling_strategy=strategy, + ) + # mock `optimize_acqf_discrete_local_search` because it isn't handled by # `mock_botorch_optimize` @mock.patch( @@ -1948,6 +2040,9 @@ def test_optimize_mixed(self) -> None: def test_optimize_acqf_mixed_alternating(self) -> None: pass + def test_select_from_candidate_set(self) -> None: + pass + # Mock so that we can check that arguments are passed correctly. 
@mock.patch(f"{ACQUISITION_PATH}._get_X_pending_and_observed") @mock.patch( diff --git a/ax/generators/torch/tests/test_generator.py b/ax/generators/torch/tests/test_generator.py index c034506df42..ef03c2252f8 100644 --- a/ax/generators/torch/tests/test_generator.py +++ b/ax/generators/torch/tests/test_generator.py @@ -602,6 +602,8 @@ def _test_gen( fixed_features=self.fixed_features, rounding_func=None, optimizer_options=self.optimizer_options, + candidate_set=None, + sampling_strategy=None, ) # make sure ACQF_KWARGS are passed properly self.assertEqual(none_throws(model._acquisition).acqf._eta, 3.0) @@ -620,6 +622,123 @@ def _test_gen( ) self.assertTrue(torch.isfinite(gen_results.points).all()) + def test_build_candidate_generation_options(self) -> None: + """Test _build_candidate_generation_options across all code paths.""" + from ax.generators.torch.botorch_modular.generator import ( + _build_candidate_generation_options, + ) + from botorch.generation.sampling import BoltzmannSampling, MaxPosteriorSampling + + mock_acqf = Mock() + mock_acqf.acqf = Mock() + + with self.subTest("no_options_returns_none"): + candidate_set, sampling_strategy = _build_candidate_generation_options( + model_gen_options={}, + torch_opt_config=self.torch_opt_config, + surrogate=Mock(), + acqf=mock_acqf, + ) + self.assertIsNone(candidate_set) + self.assertIsNone(sampling_strategy) + + with self.subTest("in_sample_no_data_raises"): + mock_surrogate = Mock() + mock_surrogate.Xs = [] + with self.assertRaisesRegex(ValueError, "no data is available"): + _build_candidate_generation_options( + model_gen_options={"in_sample": True}, + torch_opt_config=self.torch_opt_config, + surrogate=mock_surrogate, + acqf=mock_acqf, + ) + + train_X = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + mock_surrogate = Mock() + mock_surrogate.Xs = [train_X] + mock_surrogate.model = Mock() + + with self.subTest("in_sample_only"): + candidate_set, sampling_strategy = _build_candidate_generation_options( + 
model_gen_options={"in_sample": True}, + torch_opt_config=self.torch_opt_config, + surrogate=mock_surrogate, + acqf=mock_acqf, + ) + self.assertIsNotNone(candidate_set) + self.assertTrue(torch.equal(candidate_set, train_X)) + self.assertIsNone(sampling_strategy) + + with self.subTest("max_posterior_sampling"): + candidate_set, sampling_strategy = _build_candidate_generation_options( + model_gen_options={ + "in_sample": True, + "sampling_strategy_class": MaxPosteriorSampling, + "sampling_strategy_kwargs": {"num_samples": 100}, + }, + torch_opt_config=self.torch_opt_config, + surrogate=mock_surrogate, + acqf=mock_acqf, + ) + self.assertIsNotNone(candidate_set) + self.assertTrue(torch.equal(candidate_set, train_X)) + self.assertIsInstance(sampling_strategy, MaxPosteriorSampling) + self.assertEqual(sampling_strategy.num_samples, 100) + + with self.subTest("boltzmann_sampling"): + candidate_set, sampling_strategy = _build_candidate_generation_options( + model_gen_options={ + "sampling_strategy_class": BoltzmannSampling, + "sampling_strategy_kwargs": {"eta": 1.0}, + }, + torch_opt_config=self.torch_opt_config, + surrogate=mock_surrogate, + acqf=mock_acqf, + ) + self.assertIsNone(candidate_set) + self.assertIsInstance(sampling_strategy, BoltzmannSampling) + + with self.subTest("generic_strategy"): + + class _CustomStrategy: + def __init__(self, temperature: float = 1.0) -> None: + self.temperature = temperature + + candidate_set, sampling_strategy = _build_candidate_generation_options( + model_gen_options={ + "sampling_strategy_class": _CustomStrategy, + "sampling_strategy_kwargs": {"temperature": 2.0}, + }, + torch_opt_config=self.torch_opt_config, + surrogate=mock_surrogate, + acqf=mock_acqf, + ) + self.assertIsNone(candidate_set) + self.assertIsInstance(sampling_strategy, _CustomStrategy) + self.assertEqual(sampling_strategy.temperature, 2.0) + + with self.subTest("no_mutation_of_input_dict"): + + class _DummyStrategy: + def __init__(self, **kwargs: object) -> None: + 
pass + + model_gen_options = { + "in_sample": True, + "sampling_strategy_class": _DummyStrategy, + "sampling_strategy_kwargs": {"num_samples": 50}, + } + _build_candidate_generation_options( + model_gen_options=model_gen_options, + torch_opt_config=self.torch_opt_config, + surrogate=mock_surrogate, + acqf=mock_acqf, + ) + self.assertEqual( + model_gen_options["sampling_strategy_kwargs"], + {"num_samples": 50}, + ) + def test_gen_SingleTaskGP(self) -> None: self._test_gen( botorch_model_class=SingleTaskGP,