Skip to content

[GENERAL SUPPORT]: How to use a specific acquisition function from BoTorch? #4833

@Yangliu-SY

Description

@Yangliu-SY

Question

I want to use the acquisition function `LogConstrainedExpectedImprovement` from `botorch.acquisition.analytic`. I followed the example in `modular_botorch.ipynb`, but got the following error:

TypeError: construct_inputs_logcei() missing 1 required positional argument: 'objective_index'

I asked Copilot, but the LLMs seem stuck on an older version of Ax that still uses `ax.models` and `ax.modelbridge`. I also searched the GitHub issues, but some of the linked documentation pages are broken, e.g. https://ax.dev/docs/tutorials/modular_botax.html#appendix-2-default-surrogate-models-and-acquisition-functions

How can I use this acquisition function from `botorch.acquisition.analytic`? It seems that some acquisition functions in `botorch.acquisition.analytic` cannot be used with Ax, because Ax does not call their input constructors with the required arguments.

Please provide any relevant code snippet if applicable.

# queue_mmc.py

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Compute the value of Erlang C with log-domain to avoid overflow
def erlang_c_log(arrival_rate, service_rate, num_servers):
    """Return the Erlang C waiting probability of an M/M/c queue.

    All factorials and sums are evaluated in the log domain (``math.lgamma``
    and ``np.logaddexp``) so that large server counts or offered loads do not
    overflow.

    Args:
        arrival_rate: Arrival rate (lambda).
        service_rate: Service rate of a single server (mu).
        num_servers: Number of servers (c).

    Returns:
        Probability that an arriving job has to wait. Returns ``1.0`` when the
        utilization ``lambda / (c * mu)`` is at least 1 (unstable system).
    """
    utilization = arrival_rate / (num_servers * service_rate)
    if utilization >= 1:
        return 1.0  # unstable system: treat every arrival as waiting

    # log of the offered load a = lambda / mu, computed once for all terms.
    log_load = np.log(arrival_rate / service_rate)

    # log( sum_{k < c} a^k / k! )
    log_head = np.logaddexp.reduce(
        [k * log_load - math.lgamma(k + 1) for k in range(num_servers)]
    )
    # log( a^c / c! * 1 / (1 - rho) )
    log_tail = (
        (num_servers * log_load)
        - math.lgamma(num_servers + 1)
        - np.log(1 - utilization)
    )

    # P_wait = tail / (head + tail), i.e. tail * P0, in log space.
    log_p_wait = log_tail - np.logaddexp(log_head, log_tail)
    return np.exp(log_p_wait)

# Compute the response time in an M/M/c queue using Erlang C function above
def mmc_response_time(arrival_rate, service_rate, num_servers):
    """Return the mean response time (queueing + service) of an M/M/c queue.

    Args:
        arrival_rate: Arrival rate (lambda).
        service_rate: Service rate of a single server (mu).
        num_servers: Number of servers (c).

    Returns:
        ``W = Wq + 1/mu``, where ``Wq = P_wait * (1/mu) / (c * (1 - rho))`` and
        ``P_wait`` comes from ``erlang_c_log``. Returns the sentinel ``9999.0``
        when the utilization ``rho = lambda / (c * mu)`` is at least 1.

    NOTE(review): the large finite sentinel is what an optimizer sees for
    unstable configurations; it may dominate any surrogate fitted to
    ``response_time`` — confirm this is intended.
    """
    utilization = arrival_rate / (num_servers * service_rate)
    if utilization >= 1:
        return 9999.0  # sentinel for an unstable system

    mean_service_time = 1 / service_rate
    p_wait = erlang_c_log(arrival_rate, service_rate, num_servers)
    mean_queue_wait = (p_wait * mean_service_time) / (num_servers * (1 - utilization))
    return mean_queue_wait + mean_service_time

# Main codes

import random

from ax.api.client import Client
from ax.api.configs import RangeParameterConfig
from ax.api.protocols.metric import IMetric

import queue_mmc

# Workload constants shared with the evaluation and the trial loop.
service_rate = 100  # requests per second handled by one server
demand_per_job = 20  # requests generated by each job

# Quiet Ax logging and Python warnings so the trial printout stays readable.
import logging
import warnings

from ax.utils.common.logger import set_ax_logger_levels

set_ax_logger_levels(logging.CRITICAL)

# NOTE(review): this silences *every* warning process-wide, including
# deprecation warnings from Ax/BoTorch that can explain API mismatches.
# Consider narrowing or removing it while debugging.
warnings.filterwarnings('ignore')

# ## Step 2: Initialize the Client
client = Client()

# ## Step 3: Configure the Experiment
# Search space: an integer server count and a continuous arrival rate.
parameters = [
    RangeParameterConfig(name="num_servers", parameter_type="int", bounds=(1, 100)),
    RangeParameterConfig(
        name="arrival_rate_given", parameter_type="float", bounds=(1000, 9000)
    ),
]
client.configure_experiment(
    name="Server Optimization Experiment",
    parameters=parameters,
)

# ## Step 4: Configure Optimization
# Minimize the reported server count, subject to an upper bound on response time.
client.configure_optimization(
    objective="-number_of_servers",
    outcome_constraints=["response_time<=0.011"],
)
from ax.generation_strategy.center_generation_node import CenterGenerationNode
from ax.generation_strategy.transition_criterion import MinTrials
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.generation_strategy.generation_node import GenerationNode
from ax.generation_strategy.generator_spec import GeneratorSpec
from ax.adapter.registry import Generators

def construct_generation_strategy(
    generator_spec: GeneratorSpec, node_name: str,
) -> GenerationStrategy:
    """Build a Center -> Sobol -> Modular BoTorch ``GenerationStrategy``.

    Args:
        generator_spec: Spec used by the final (Modular BoTorch) node.
        node_name: Name of that final node; also embedded in the strategy name.

    Returns:
        A ``GenerationStrategy`` whose nodes run in the order center, Sobol,
        then the caller-supplied node.
    """
    # Final node: whatever generator the caller configured.
    botorch_node = GenerationNode(name=node_name, generator_specs=[generator_spec])

    # Quasi-random exploration, seeded through model_kwargs.
    sobol_spec = GeneratorSpec(generator_enum=Generators.SOBOL, model_kwargs={"seed": 0})
    # Hand over to the BoTorch node once the experiment holds 2 trials
    # (counting all trials on the experiment, not only Sobol's).
    hand_off_to_botorch = MinTrials(
        threshold=2,
        transition_to=botorch_node.name,
        use_all_trials_in_exp=True,
    )
    sobol_node = GenerationNode(
        name="Sobol",
        generator_specs=[sobol_spec],
        transition_criteria=[hand_off_to_botorch],
    )

    # Center node: simplified logic with a built-in criterion that moves on
    # after generating once.
    center_node = CenterGenerationNode(next_node_name=sobol_node.name)

    ordered_nodes = [center_node, sobol_node, botorch_node]
    return GenerationStrategy(name=f"Center+Sobol+{node_name}", nodes=ordered_nodes)

from botorch.acquisition.analytic import LogConstrainedExpectedImprovement

# Modular BoTorch node that asks for BoTorch's LogConstrainedExpectedImprovement.
# NOTE(review): the issue text reports this failing with
#   TypeError: construct_inputs_logcei() missing 1 required positional
#   argument: 'objective_index'
# i.e. Ax invokes the acquisition's input constructor without an objective
# index. That argument (and presumably the constraint specification) has to
# be supplied through the generator's acquisition/input-constructor options.
# The exact option name depends on the installed Ax version — confirm it.
# Also verify which model output index corresponds to each metric, and whether
# constraint bounds passed that way are interpreted on the raw
# `response_time<=0.011` scale or in the model's transformed outcome space.
generator_spec = GeneratorSpec(
    generator_enum=Generators.BOTORCH_MODULAR,
    model_kwargs=dict(botorch_acqf_class=LogConstrainedExpectedImprovement),
)

# Center -> Sobol -> BoTorch(LogCEI), installed on the client.
generation_strategy = construct_generation_strategy(
    node_name="BoTorch w/ Log CEI",
    generator_spec=generator_spec,
)
client.set_generation_strategy(generation_strategy=generation_strategy)

# ## Step 5: Run Trials

def evaluation(num_servers, arrival_rate_given):
    """Evaluate one trial and return its metrics as ``{name: (value, 0.0)}``.

    ``number_of_servers`` echoes the ``num_servers`` parameter, and
    ``response_time`` is the M/M/c mean response time computed with the
    module-level ``service_rate``. The second element of each pair is 0.0
    (presumably the standard error in Ax's raw-data format).
    """
    response_time = queue_mmc.mmc_response_time(
        arrival_rate_given, service_rate, num_servers
    )
    zero_sem = 0.0
    return {
        "number_of_servers": (num_servers, zero_sem),
        "response_time": (response_time, zero_sem),
    }

# Outer loop: 20 arrival-rate regimes; inner loop: 20 sequential trials each.
for time_index in range(20):
    # Job arrival rate drawn from {50, 70, ..., 450}; multiplied by
    # demand_per_job it spans 1000..9000, matching the parameter bounds.
    arrival_rate_job = 50 + 20 * random.randint(0, 20)
    arrival_rate = arrival_rate_job * demand_per_job
    print(f"Time index {time_index}, arrival rate: {arrival_rate}")

    for _ in range(20):
        # Pin the arrival rate so only `num_servers` is chosen by the generator.
        trials = client.get_next_trials(
            max_trials=1, fixed_parameters={"arrival_rate_given": arrival_rate}
        )

        # Use `trial_parameters`, not `parameters`, so the module-level
        # search-space config list is not overwritten by each trial's dict.
        for trial_index, trial_parameters in trials.items():
            num_servers = trial_parameters["num_servers"]
            arrival_rate_given = trial_parameters["arrival_rate_given"]

            result = evaluation(num_servers, arrival_rate_given)
            print(f"Trial {trial_index} with num_servers={num_servers} for {arrival_rate_given} resulted in {result}")

            # Report the observed metrics back to Ax.
            client.complete_trial(trial_index=trial_index, raw_data=result)

Code of Conduct

  • I agree to follow Ax's Code of Conduct

Metadata

Metadata

Assignees

No one assigned

    Labels

    question (Further information is requested)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions