Update Ax multitask generator for Ax > 0.4.0 #1508

Merged (8 commits) on Feb 14, 2025
.github/workflows/extra.yml (2 changes: 1 addition & 1 deletion)

@@ -94,7 +94,7 @@ jobs:
          conda env update --file install/gen_deps_environment.yml

      - name: Install gpcam
-       if: matrix.python-version != '3.12'
+       if: matrix.python-version <= '3.13'
        run: |
          pip install gpcam

install/gen_deps_environment.yml (1 change: 1 addition & 0 deletions)

@@ -12,3 +12,4 @@ dependencies:
  - mumps-mpi
  - DFO-LS
  - mpmath
+ - ax-platform
install/ubuntu_no312.txt (2 changes: 1 addition & 1 deletion)

@@ -1,4 +1,4 @@
- gpcam==8.1.6
+ gpcam==8.1.12
scikit-build==0.18.1
packaging==24.1
git+https://github.com/bandframework/surmise.git
libensemble/gen_funcs/persistent_ax_multitask.py (126 changes: 116 additions & 10 deletions)

@@ -8,41 +8,145 @@
This `gen_f` is meant to be used with the `alloc_f` function
`only_persistent_gens`

- This test currently requires ax-platform<=0.4.0
+ Ax notes:
+ Each arm = a set of simulation inputs (a sim_id).
+ Each trial = a batch of simulations.
+ The metric = the recorded simulation output (f) that Ax optimizes.
+ The Ax runner handles the execution of trials; AxRunner wraps Runner so that
+ tr.run() submits simulations through libE.
"""

import os
from copy import deepcopy
from typing import Optional
from pyre_extensions import assert_is_instance
import warnings

import numpy as np
import pandas as pd
import torch

from ax import Metric, Runner
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.objective import Objective
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import AxParameterWarning
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.registry import Models, ST_MTGP_trans
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transforms.convert_metric_names import tconfig_from_mt_experiment
from ax.storage.metric_registry import register_metrics
from ax.runners import SyntheticRunner
from ax.storage.json_store.save import save_experiment
- from ax.storage.metric_registry import register_metric
from ax.storage.runner_registry import register_runner
from ax.utils.common.result import Ok

try:
-   from ax.modelbridge.factory import get_MTGP
+   # For Ax >= 0.5.0
+   from ax.modelbridge.transforms.derelativize import Derelativize
+   from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
+   from ax.modelbridge.transforms.trial_as_task import TrialAsTask
+   from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY
+   from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice
+   from ax.modelbridge.registry import MBM_X_trans
+
+   MT_MTGP_trans = MBM_X_trans + [
+       Derelativize,
+       ConvertMetricNames,
+       TrialAsTask,
+       StratifiedStandardizeY,
+       TaskChoiceToIntTaskChoice,
+   ]

except ImportError:
-   # For Ax >= 0.3.4
-   from ax.modelbridge.factory import get_MTGP_LEGACY as get_MTGP
+   # For Ax < 0.5.0
+   from ax.modelbridge.registry import MT_MTGP_trans

from libensemble.message_numbers import EVAL_GEN_TAG, FINISHED_PERSISTENT_GEN_TAG, PERSIS_STOP, STOP_TAG
from libensemble.tools.persistent_support import PersistentSupport

__all__ = ["persistent_gp_mt_ax_gen_f"]

warnings.filterwarnings(
    "ignore",
    message="`cache_root` is only supported for GPyTorchModels",
    category=RuntimeWarning,
)

warnings.filterwarnings(
    "ignore",
    message="Changing `is_ordered` to `True` for `ChoiceParameter`",
    category=AxParameterWarning,
)


# get_MTGP based on https://ax.dev/docs/tutorials/multi_task/
def get_MTGP(
    experiment: Experiment,
    data: Data,
    search_space: Optional[SearchSpace] = None,
    trial_index: Optional[int] = None,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.double,
) -> TorchModelBridge:
    """Instantiates a Multi-task Gaussian Process (MTGP) model that generates
    points with EI.

    If the input experiment is a MultiTypeExperiment, then a
    Multi-type Multi-task GP model will be instantiated.
    Otherwise, the model will be a Single-type Multi-task GP.
    """

    if isinstance(experiment, MultiTypeExperiment):
        trial_index_to_type = {
            t.index: t.trial_type for t in experiment.trials.values()
        }
        transforms = MT_MTGP_trans
        transform_configs = {
            "TrialAsTask": {"trial_level_map": {"trial_type": trial_index_to_type}},
            "ConvertMetricNames": tconfig_from_mt_experiment(experiment),
        }
    else:
        # Set transforms for a Single-type MTGP model.
        transforms = ST_MTGP_trans
        transform_configs = None

    # Choose the status quo features for the experiment from the selected trial.
    # If trial_index is None, we will look for a status quo from the last
    # experiment trial to use as a status quo for the experiment.
    if trial_index is None:
        trial_index = len(experiment.trials) - 1
    elif trial_index >= len(experiment.trials):
        raise ValueError("trial_index is bigger than the number of experiment trials")

    status_quo = experiment.trials[trial_index].status_quo
    if status_quo is None:
        status_quo_features = None
    else:
        status_quo_features = ObservationFeatures(
            parameters=status_quo.parameters,
            trial_index=trial_index,  # pyre-ignore[6]
        )

    return assert_is_instance(
        Models.ST_MTGP(
            experiment=experiment,
            search_space=search_space or experiment.search_space,
            data=data,
            transforms=transforms,
            transform_configs=transform_configs,
            torch_dtype=dtype,
            torch_device=device,
            status_quo_features=status_quo_features,
        ),
        TorchModelBridge,
    )
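For orientation (not part of the diff): once fitted, the returned model bridge proposes candidates through the standard ModelBridge.gen() interface. A hypothetical call is sketched below; `exp`, `n`, and `hifi_trial_index` are placeholders, and pinning the target task through `fixed_features` follows the Ax multi-task tutorial this helper is based on.

```python
# Fit the multi-task GP on all data collected so far.
m = get_MTGP(experiment=exp, data=exp.fetch_data(), trial_index=hifi_trial_index)

# Ask for n candidate arms; fixed_features pins generation to the target trial/task.
gr = m.gen(
    n=n,
    optimization_config=exp.optimization_config,
    fixed_features=ObservationFeatures(parameters={}, trial_index=hifi_trial_index),
)
```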


def persistent_gp_mt_ax_gen_f(H, persis_info, gen_specs, libE_info):
"""
Expand Down Expand Up @@ -99,6 +203,7 @@ def persistent_gp_mt_ax_gen_f(H, persis_info, gen_specs, libE_info):
        optimization_config=opt_config,
    )

+   # hifi_task has been added as the default, but we need to add the lofi task and link them.
    exp.add_trial_type(lofi_task, ax_runner)
    exp.add_tracking_metric(metric=lofi_objective, trial_type=lofi_task, canonical_name="hifi_metric")

@@ -143,7 +248,7 @@ def persistent_gp_mt_ax_gen_f(H, persis_info, gen_specs, libE_info):

        # But launch them at low fidelity.
        tr = exp.new_batch_trial(trial_type=lofi_task, generator_run=gr)
-       tr.run()
+       tr.run()  # Runs sims via libE (see AxRunner.run below)
        tr.mark_completed()
        tag = tr.run_metadata["tag"]
        if tag in [STOP_TAG, PERSIS_STOP]:
@@ -159,7 +264,7 @@ def persistent_gp_mt_ax_gen_f(H, persis_info, gen_specs, libE_info):
        # Select max-utility points from the low fidelity batch to generate a high fidelity batch.
        gr = max_utility_from_GP(n=n_opt_hifi, m=m, gr=gr, hifi_task=hifi_task)
        tr = exp.new_batch_trial(trial_type=hifi_task, generator_run=gr)
-       tr.run()
+       tr.run()  # Runs sims via libE (see AxRunner.run below)
        tr.mark_completed()
        tag = tr.run_metadata["tag"]
        if tag in [STOP_TAG, PERSIS_STOP]:
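max_utility_from_GP is defined further down this file, outside the displayed hunks. Based on the Ax multi-task tutorial this generator follows, the selection step looks roughly like the sketch below; treat it as an illustration under that assumption, not the verbatim implementation. The metric name "hifi_metric" matches the canonical name used earlier in this file.

```python
from copy import deepcopy

import numpy as np
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures


def max_utility_from_GP(n, m, gr, hifi_task):
    """Select the n arms of generator run `gr` with the best predicted
    high-fidelity objective under the fitted MTGP model bridge `m`."""
    obsf = []
    for arm in gr.arms:
        params = deepcopy(arm.parameters)
        params["trial_type"] = hifi_task  # query the model at the hifi task
        obsf.append(ObservationFeatures(parameters=params))
    f, _cov = m.predict(obsf)  # posterior means, keyed by metric name
    utility = -np.array(f["hifi_metric"])  # minimization: lower f = higher utility
    best = np.flip(np.argsort(utility))[:n]
    return GeneratorRun(arms=[gr.arms[i] for i in best], weights=[1.0] * n)
```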
@@ -171,7 +276,9 @@ def persistent_gp_mt_ax_gen_f(H, persis_info, gen_specs, libE_info):
    if not os.path.exists("model_history"):
        os.mkdir("model_history")
    # Register metric and runner in order to be able to save to json.
-   _, encoder_registry, decoder_registry = register_metric(AxMetric)
+   _, encoder_registry, decoder_registry = register_metrics(
+       {AxMetric: None}
+   )
    _, encoder_registry, decoder_registry = register_runner(
        AxRunner,
        encoder_registry=encoder_registry,
@@ -224,9 +331,8 @@ def run(self, trial):
            for j in range(n_param):
                param_array[j] = params[f"x{j}"]
            H_o["x"][i] = param_array
-           H_o["resource_sets"][i] = 1
+           H_o["resource_sets"][i] = 1  # one is the default, but could differ for hifi/lofi
            H_o["task"][i] = task

        tag, Work, calc_in = self.ps.send_recv(H_o)

        trial_metadata["tag"] = tag
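The storage hunk above reflects the Ax >= 0.5 registry API: register_metric(AxMetric) became register_metrics({AxMetric: None}), and both variants return updated encoder/decoder registries that are threaded into register_runner and the JSON save. A minimal sketch of the full chain is below; the save path is illustrative, and `exp` stands for the experiment built earlier in the generator.

```python
# Register the custom metric and runner classes so the experiment can be
# serialized; each call returns updated encoder/decoder registries.
_, encoder_registry, decoder_registry = register_metrics({AxMetric: None})
_, encoder_registry, decoder_registry = register_runner(
    AxRunner,
    encoder_registry=encoder_registry,
    decoder_registry=decoder_registry,
)
# Persist the experiment state to JSON (path is illustrative).
save_experiment(exp, "model_history/experiment.json", encoder_registry=encoder_registry)
```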
libensemble/tests/regression_tests/test_gpCAM.py (1 change: 0 additions & 1 deletion)

@@ -19,7 +19,6 @@
# TESTSUITE_COMMS: mpi local
# TESTSUITE_NPROCS: 4
# TESTSUITE_EXTRA: true
- # TESTSUITE_EXCLUDE: true

import sys
import warnings
libensemble/tests/regression_tests/test_persistent_gp_multitask_ax.py

@@ -2,24 +2,24 @@
Example of multi-fidelity optimization using a persistent GP gen_func (calling
Ax).

- Execute via one of the following commands (e.g. 5 workers):
-    mpiexec -np 5 python test_persistent_gp_multitask_ax.py
-    python test_persistent_gp_multitask_ax.py --nworkers 4 --comms local
-    python test_persistent_gp_multitask_ax.py --nworkers 4 --comms tcp
+ Test is set to use the gen_on_manager option (persistent generator runs on
+ a thread). Therefore nworkers is the number of simulation workers.
+
+ Execute via one of the following commands:
+    mpiexec -np 4 python test_persistent_gp_multitask_ax.py
+    python test_persistent_gp_multitask_ax.py --nworkers 3 --comms local
+    python test_persistent_gp_multitask_ax.py --nworkers 3 --comms tcp

When running with the above commands, the number of concurrent evaluations of
- the objective function will be 3, as one of the three workers will be the
- persistent generator.
+ the objective function will be 3.

Requires numpy<2.
"""

# Do not change these lines - they are parsed by run-tests.sh
# TESTSUITE_COMMS: local mpi
- # TESTSUITE_NPROCS: 5
+ # TESTSUITE_NPROCS: 4
# TESTSUITE_EXTRA: true
# TESTSUITE_OS_SKIP: OSX
- # TESTSUITE_EXCLUDE: true

import warnings

@@ -50,6 +50,7 @@ def run_simulation(H, persis_info, sim_specs, libE_info):
        z = 8
    elif task == "cheap_model":
        z = 1
+   print('in sim', task)
Member:
Do you want this print to remain? (It's fine with me, but I wanted to point it out.)

Member (Author):
I did originally leave that print in by accident, but then I thought it's actually kind of useful to see that it's doing both models.


libE_output = np.zeros(1, dtype=sim_specs["out"])
calc_status = WORKER_DONE
@@ -63,6 +64,7 @@ def run_simulation(H, persis_info, sim_specs, libE_info):
# Main block is necessary only when using local comms with spawn start method (default on macOS and Windows).
if __name__ == "__main__":
    nworkers, is_manager, libE_specs, _ = parse_args()
+   libE_specs["gen_on_manager"] = True

    mt_params = {
        "name_hifi": "expensive_model",
libensemble/tests/run_tests.py (3 changes: 2 additions & 1 deletion)

@@ -360,7 +360,8 @@ def run_regression_tests(root_dir, python_exec, args, current_os):
    user_comms_list = ["mpi", "local", "tcp"]

    print_heading(f"Running regression tests (comms: {', '.join(user_comms_list)})")
-   build_forces(root_dir)  # Build forces.x before running tests
+   if not REG_LIST_TESTS_ONLY:
+       build_forces(root_dir)  # Build forces.x before running tests

    reg_test_list = REG_TEST_LIST
    reg_test_files = []