diff --git a/optimas/generators/ax/developer/multitask.py b/optimas/generators/ax/developer/multitask.py index 32cb49d6..4274d539 100644 --- a/optimas/generators/ax/developer/multitask.py +++ b/optimas/generators/ax/developer/multitask.py @@ -58,7 +58,6 @@ from optimas.generators.ax.base import AxGenerator from optimas.core import ( - TrialParameter, Task, Trial, TrialStatus, @@ -154,9 +153,6 @@ class AxMultitaskGenerator(AxGenerator): VOCS object defining variables, objectives, constraints, and observables. lofi_task, hifi_task : Task The low- and high-fidelity tasks. - analyzed_parameters : list of Parameter, optional - List of parameters to analyze at each trial, but which are not - optimization objectives. By default ``None``. use_cuda : bool, optional Whether to allow the generator to run on a CUDA GPU. By default ``False``. @@ -190,16 +186,6 @@ def __init__( model_save_period: Optional[int] = 5, model_history_dir: Optional[str] = "model_history", ) -> None: - - # As trial parameters these get written to history array - # Ax trial_index and arm toegther locate a point - # Multiple points (Optimas trials) can share the same Ax trial_index - # vocs interface note: These are not part of vocs. They are only stored - # to allow keeping track of them from previous runs. - custom_trial_parameters = [ - TrialParameter("arm_name", "ax_arm_name", dtype="U32"), - TrialParameter("ax_trial_id", "ax_trial_index", dtype=int), - ] self._check_inputs(vocs, lofi_task, hifi_task) super().__init__( @@ -210,7 +196,6 @@ def __init__( save_model=save_model, model_save_period=model_save_period, model_history_dir=model_history_dir, - custom_trial_parameters=custom_trial_parameters, ) self.lofi_task = lofi_task self.hifi_task = hifi_task @@ -226,6 +211,10 @@ def __init__( self.gr_lofi = None self._experiment = self._create_experiment() + # Internal mapping: _id -> (arm_name, ax_trial_id, trial_type) + self._id_mapping = {} + self._next_id = 0 + def get_gen_specs( self, sim_workers: int, run_params: Dict, sim_max: int ) -> Dict: @@ -285,11 +274,22 @@ def suggest(self, num_points: Optional[int]) -> List[dict]: if trial_param.name == "trial_type": point[trial_param.name] = trial_type - point["ax_trial_id"] = trial_index - point["arm_name"] = arm.name + # Generate unique _id and store mapping + current_id = self._next_id + self._id_mapping[current_id] = { + "ax_trial_id": trial_index, + "arm_name": arm.name, + } + point["_id"] = current_id + self._next_id += 1 points.append(point) return points + def _get_trial_mapping(self, gen_id: int) -> Tuple[int, str]: + """Get mapping information for a trial gen_id.""" + mapping = self._id_mapping[gen_id] + return mapping["ax_trial_id"], mapping["arm_name"] + def ingest(self, results: List[dict]) -> None: """Incorporate evaluated trials into experiment.""" # reconstruct Optimas trials @@ -304,32 +304,40 @@ def ingest(self, results: List[dict]) -> None: ) trials.append(trial) + # Apply _id mapping to all trials before processing + for trial in trials: + if trial.gen_id is not None: + if trial.gen_id not in self._id_mapping: + raise ValueError( + f"Trial has _id={trial.gen_id} which is not recognized by this generator." + ) + trial.ax_trial_id, trial.arm_name = self._get_trial_mapping( + trial.gen_id + ) + if self.gen_state == NOT_STARTED: self._incorporate_external_data(trials) else: self._complete_evaluations(trials) def _incorporate_external_data(self, trials: List[Trial]) -> None: - """Incorporate external data (e.g., from history) into experiment.""" - # Get trial indices. - trial_indices = [] - for trial in trials: - trial_indices.append(trial.ax_trial_id) - trial_indices = np.unique(np.array(trial_indices)) - - # Group trials by index. - grouped_trials = {} - for index in trial_indices: - grouped_trials[index] = [] + """Incorporate external data (e.g., from history) into experiment. + + Unknown/external points have no gen_id. We create new arms and add + observations directly to the experiment, then let the model use them + as if starting fresh. + """ + # Group by trial_type (default to hifi if not specified) + grouped_by_type = {} for trial in trials: - grouped_trials[trial.ax_trial_id].append(trial) - - # Add trials to experiment. - for index in trial_indices: - # Get all trials with current index. - trials_i = grouped_trials[index] - trial_type = trials_i[0].trial_type - # Create arms. + trial_type = getattr(trial, "trial_type", self.hifi_task.name) + if trial_type not in grouped_by_type: + grouped_by_type[trial_type] = [] + grouped_by_type[trial_type].append(trial) + + param_to_name = {} + arm_count = 0 + for trial_type, trials_i in grouped_by_type.items(): arms = [] for trial in trials_i: params = {} @@ -337,7 +345,15 @@ def _incorporate_external_data(self, trials: List[Trial]) -> None: trial.varying_parameters, trial.parameter_values ): params[var.name] = val - arms.append(Arm(parameters=params, name=trial.arm_name)) + arm = Arm(parameters=params) + if arm.signature not in param_to_name: + param_to_name[arm.signature] = f"external_{arm_count}" + arm_count += 1 + arms.append( + Arm(parameters=params, name=param_to_name[arm.signature]) + ) + self._next_id += 1 + # Create new batch trial. gr = GeneratorRun(arms=arms, weights=[1.0] * len(arms)) ax_trial = self._experiment.new_batch_trial( @@ -345,19 +361,20 @@ def _incorporate_external_data(self, trials: List[Trial]) -> None: ) ax_trial.run() # Incorporate observations. - for trial in trials_i: + for i, trial in enumerate(trials_i): + arm_name = ax_trial.arms[i].name if trial.status != TrialStatus.FAILED: objective_eval = {} oe = trial.objective_evaluations[0] objective_eval["f"] = (oe.value, oe.sem) - ax_trial.run_metadata[trial.arm_name] = objective_eval + ax_trial.run_metadata[arm_name] = objective_eval else: - ax_trial.mark_arm_abandoned(trial.arm_name) + ax_trial.mark_arm_abandoned(arm_name) # Mark batch trial as completed. ax_trial.mark_completed() # Keep track of high-fidelity trials. if trial_type == self.hifi_task.name: - self.hifi_trials.append(index) + self.hifi_trials.append(ax_trial.index) def _complete_evaluations(self, trials: List[Trial]) -> None: """Complete evaluated trials."""