Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions ax/generators/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from botorch.acquisition.multioutput_acquisition import MultiOutputAcquisitionFunction
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.exceptions.errors import BotorchError, InputDataError
from botorch.generation.sampling import SamplingStrategy
from botorch.models.model import Model
from botorch.optim.optimize import (
optimize_acqf,
Expand Down Expand Up @@ -465,6 +466,123 @@ def objective_weights(self) -> Tensor | None:
"""The objective weights for all outcomes."""
return self._full_objective_weights

def select_from_candidate_set(
    self,
    n: int,
    candidate_set: Tensor,
    sampling_strategy: SamplingStrategy | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
    """Pick ``n`` candidates out of a discrete set, with per-arm weights.

    Two selection regimes are supported:

    - **Strategy-based**: when ``sampling_strategy`` is given, candidates are
      drawn from ``candidate_set`` by the strategy. If the strategy carries a
      ``num_samples`` attribute larger than ``n``, win-counting is used: many
      posterior draws are taken, wins per unique candidate are tallied, and
      the top-``n`` winners are returned with weights proportional to their
      win counts (normalized to sum to 1). Otherwise exactly ``n`` draws are
      taken and each arm gets weight 1.
    - **Greedy**: with no strategy, ``optimize_acqf_discrete`` picks the
      ``n`` best points under ``self.acqf``; each arm gets weight 1.

    ``candidate_set`` is deliberately agnostic to how the candidates were
    produced (training data, pathwise TS optimization, user-supplied sets).

    Args:
        n: Number of candidates to select.
        candidate_set: ``(num_choices, d)`` tensor of discrete choices.
        sampling_strategy: Optional BoTorch ``SamplingStrategy`` (e.g.
            ``MaxPosteriorSampling`` for Thompson Sampling, or
            ``BoltzmannSampling`` for acquisition-weighted sampling).

    Returns:
        Tuple of (selected candidates ``n x d``, their acquisition values,
        per-candidate weights). Weights are normalized to sum to 1 in
        win-counting mode and uniform otherwise. Fewer than ``n`` rows may
        be returned when fewer than ``n`` distinct winners/choices exist.

    Raises:
        ValueError: If ``candidate_set`` is empty or smaller than ``n``.
    """
    num_choices = candidate_set.shape[0]
    if num_choices == 0:
        raise ValueError(
            "`candidate_set` is empty. Provide a non-empty set of candidates."
        )
    if num_choices < n:
        raise ValueError(
            f"`candidate_set` has {num_choices} candidates, "
            f"but {n} were requested. Provide at least {n} candidates."
        )

    if sampling_strategy is None:
        # Greedy selection via the acquisition function. Note that
        # optimize_acqf_discrete may legitimately return fewer than n rows
        # when there are fewer feasible choices, so arm_weights is sized to
        # the actual number of candidates returned.
        candidates, acqf_values = optimize_acqf_discrete(
            acq_function=self.acqf,
            q=n,
            choices=candidate_set,
            unique=True,
        )
        arm_weights = torch.ones(
            candidates.shape[0], dtype=self.dtype, device=self.device
        )
        return candidates, acqf_values, arm_weights

    # A `num_samples` attribute on the strategy (if any) decides the mode:
    # num_samples > n means win-counting, anything else is direct selection.
    raw_num_samples = getattr(sampling_strategy, "num_samples", None)
    requested_samples: int | None = (
        None if raw_num_samples is None else int(raw_num_samples)
    )

    batched_choices = candidate_set.unsqueeze(0)
    if requested_samples is not None and requested_samples > n:
        # Win-counting: draw many samples, tally wins per unique candidate,
        # and keep the top-n with win-proportional weights.
        draws = sampling_strategy(
            batched_choices, num_samples=requested_samples
        ).squeeze(0)  # (num_samples, d)

        winners, inverse = torch.unique(draws, dim=0, return_inverse=True)
        win_counts = torch.bincount(inverse, minlength=winners.shape[0])

        # If fewer unique winners than n exist, all of them are returned;
        # callers already handle <= n rows (optimize_acqf_discrete can also
        # come up short), so no special-casing is needed here.
        keep = min(n, winners.shape[0])
        top_counts, top_idx = torch.topk(win_counts, keep)

        candidates = winners[top_idx]
        arm_weights = top_counts.to(dtype=self.dtype, device=self.device)
        arm_weights = arm_weights / arm_weights.sum()
    else:
        # Direct selection: exactly n draws, uniform weights. Strategies
        # like BoltzmannSampling bake the weighting into the draw itself.
        candidates = sampling_strategy(
            batched_choices, num_samples=n
        ).squeeze(0)  # (n, d)
        arm_weights = torch.ones(n, dtype=self.dtype, device=self.device)

    acqf_values = self.evaluate(candidates.unsqueeze(1)).view(-1)
    return candidates, acqf_values, arm_weights

def optimize(
self,
n: int,
Expand All @@ -473,6 +591,8 @@ def optimize(
fixed_features: dict[int, float] | None = None,
rounding_func: Callable[[Tensor], Tensor] | None = None,
optimizer_options: dict[str, Any] | None = None,
candidate_set: Tensor | None = None,
sampling_strategy: SamplingStrategy | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
"""Generate a set of candidates via multi-start optimization. Obtains
candidates and their associated acquisition function values.
Expand All @@ -498,12 +618,36 @@ def optimize(
that typically only exist in MBM, such as BoTorch transforms.
See the docstring of `TorchOptConfig` for more information on passing
down these options while constructing a generation strategy.
candidate_set: An optional tensor of shape `(num_choices, d)` containing
discrete candidate points to select from instead of optimizing over
the search space. When provided, selection is delegated to
``select_from_candidate_set``. This enables in-sample candidate
generation when set to the training data (X_observed).
sampling_strategy: An optional BoTorch ``SamplingStrategy`` instance
(e.g., ``MaxPosteriorSampling`` for Thompson Sampling, or
``BoltzmannSampling`` for acquisition-weighted sampling).
Passed to ``select_from_candidate_set`` when ``candidate_set``
is provided. Requires ``candidate_set`` to be provided.

Returns:
A three-element tuple containing an `n x d`-dim tensor of generated
candidates, a tensor with the associated acquisition values, and a tensor
with the weight for each candidate.
"""
# Dispatch to candidate set selection if candidate_set or
# sampling_strategy is provided.
if sampling_strategy is not None or candidate_set is not None:
if candidate_set is None:
raise ValueError(
"`candidate_set` is required when using `sampling_strategy`. "
"Provide the discrete set of candidates to sample from."
)
return self.select_from_candidate_set(
n=n,
candidate_set=candidate_set,
sampling_strategy=sampling_strategy,
)

# Options that would need to be passed in the transformed space are
# disallowed, since this would be very difficult for an end user to do
# directly, and someone who uses BoTorch at this level of detail would
Expand Down
111 changes: 111 additions & 0 deletions ax/generators/torch/botorch_modular/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.generation.sampling import BoltzmannSampling, MaxPosteriorSampling
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.settings import validate_input_scaling
from botorch.utils.datasets import SupervisedDataset
Expand All @@ -45,6 +47,103 @@
logger: Logger = get_logger(__name__)


def _build_candidate_generation_options(
model_gen_options: dict[str, Any],
torch_opt_config: TorchOptConfig,
surrogate: Surrogate,
acqf: Acquisition,
) -> tuple[Tensor | None, Any | None]:
"""Build candidate_set and sampling_strategy for discrete candidate generation.

This helper function processes model_gen_options to construct:
- candidate_set: The discrete set of candidates to select from.
Any method that produces candidates (in-sample training data, or future
methods like pathwise TS optimization) feeds into this parameter.
The selection/weight-allocation logic downstream is agnostic to how
candidates were generated.
- sampling_strategy: The BoTorch SamplingStrategy to use for selection

Args:
model_gen_options: Dictionary containing candidate generation options:
- in_sample: If True, use training data as candidate_set
- sampling_strategy_class: Class of SamplingStrategy to instantiate
- sampling_strategy_kwargs: kwargs to pass to the strategy
torch_opt_config: Configuration containing objective weights
surrogate: The surrogate model providing training data
acqf: The Acquisition object providing the acquisition function

Returns:
A tuple of (candidate_set, sampling_strategy) where either or both may
be None if not configured in model_gen_options.

Note:
When in_sample=True without a sampling_strategy, candidates are selected
via optimize_acqf_discrete, which uses sequential greedy selection with
X_pending for q > 1. This requires the acquisition function to support
X_pending (e.g., qSimpleRegret, qLogNEI). Analytic acquisition functions
like PosteriorMean do not support X_pending and will only work for q=1.
"""
candidate_set = None
sampling_strategy = None

# Determine candidate set for in-sample generation
if model_gen_options.get("in_sample", False):
if surrogate.Xs:
candidate_set = surrogate.Xs[0]
else:
raise ValueError(
"in_sample=True requires training data, but no data is available."
)

# Build sampling strategy if requested
sampling_strategy_class = model_gen_options.get("sampling_strategy_class")
if sampling_strategy_class is not None:
strategy_kwargs = dict(model_gen_options.get("sampling_strategy_kwargs", {}))
# Extract num_samples to set on strategy instance after construction
# This allows acquisition.optimize() to use getattr() to retrieve it
num_samples = strategy_kwargs.pop("num_samples", None)

if issubclass(sampling_strategy_class, MaxPosteriorSampling):
# Thompson Sampling: sample from model posterior
# For minimization objectives (objective_weights=-1), we need to pass
# a posterior_transform that negates the values so MaxPosteriorSampling
# correctly finds the arm with lowest predicted value.

# Get objective weights - for minimization this is -1, maximization is 1
objective_weights = torch_opt_config.objective_weights
# Only use non-zero weights (actual objectives, not constraints)
obj_mask = objective_weights.nonzero().view(-1)
posterior_transform = ScalarizedPosteriorTransform(
weights=objective_weights[obj_mask]
)

sampling_strategy = sampling_strategy_class(
model=surrogate.model,
posterior_transform=posterior_transform,
**strategy_kwargs,
)
elif issubclass(sampling_strategy_class, BoltzmannSampling):
# Boltzmann Sampling: sample weighted by acquisition values
sampling_strategy = sampling_strategy_class(
acq_func=acqf.acqf,
**strategy_kwargs,
)
else:
# Generic SamplingStrategy - try to instantiate with provided kwargs
# User is responsible for providing appropriate kwargs
sampling_strategy = sampling_strategy_class(**strategy_kwargs)

# Set num_samples on the strategy instance if provided.
# This is used by select_from_candidate_set() to determine
# win-counting vs. direct selection mode. The attribute is ephemeral
# (not part of nn.Module state_dict) since the strategy is created
# fresh each gen() call and never serialized.
if num_samples is not None:
sampling_strategy.num_samples = num_samples

return candidate_set, sampling_strategy


class BoTorchGenerator(TorchGenerator, Base):
"""**All classes in 'botorch_modular' directory are under
construction, incomplete, and should be treated as alpha
Expand Down Expand Up @@ -281,6 +380,16 @@ def gen(
acqf = none_throws(self._acquisition)

botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func)

# Handle candidate generation via model_gen_options
model_gen_options = torch_opt_config.model_gen_options or {}
candidate_set, sampling_strategy = _build_candidate_generation_options(
model_gen_options=model_gen_options,
torch_opt_config=torch_opt_config,
surrogate=self.surrogate,
acqf=acqf,
)

candidates, expected_acquisition_value, weights = acqf.optimize(
n=n,
search_space_digest=search_space_digest,
Expand All @@ -293,6 +402,8 @@ def gen(
opt_options,
dict,
),
candidate_set=candidate_set,
sampling_strategy=sampling_strategy,
)
gen_metadata = self._get_gen_metadata_from_acqf(
acqf=acqf,
Expand Down
4 changes: 4 additions & 0 deletions ax/generators/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,10 @@ def construct_acquisition_and_optimizer_options(
Keys.OPTIMIZER_KWARGS.value,
Keys.ACQF_KWARGS.value,
Keys.AX_ACQUISITION_KWARGS.value,
# Keys for candidate generation
"in_sample",
"sampling_strategy_class",
"sampling_strategy_kwargs",
}
)
> 0
Expand Down
Loading