diff --git a/doc/environment.yaml b/doc/environment.yaml index 90e70aba..a72082a4 100644 --- a/doc/environment.yaml +++ b/doc/environment.yaml @@ -5,7 +5,7 @@ dependencies: - pip - pip: - -e .. - - ax-platform == 0.4.0 + - ax-platform > 0.4.0 - autodoc_pydantic >= 2.0.1 - ipykernel - matplotlib diff --git a/optimas/generators/ax/developer/multitask.py b/optimas/generators/ax/developer/multitask.py index de4c76ab..0e04fdfd 100644 --- a/optimas/generators/ax/developer/multitask.py +++ b/optimas/generators/ax/developer/multitask.py @@ -6,9 +6,7 @@ import numpy as np import torch -from packaging import version -from ax.version import version as ax_version from ax.core.arm import Arm from ax.core.batch_trial import BatchTrial from ax.core.multi_type_experiment import MultiTypeExperiment @@ -22,8 +20,15 @@ from ax.core.observation import ObservationFeatures from ax.core.generator_run import GeneratorRun from ax.storage.json_store.save import save_experiment -from ax.storage.metric_registry import register_metric -from ax.modelbridge.factory import get_MTGP_LEGACY as get_MTGP +from ax.storage.metric_registry import register_metrics + +from ax.modelbridge.registry import Models, MT_MTGP_trans +from ax.core.experiment import Experiment +from ax.core.data import Data +from ax.modelbridge.transforms.convert_metric_names import ( + tconfig_from_mt_experiment, +) +from ax.utils.common.typeutils import checked_cast from optimas.generators.ax.base import AxGenerator from optimas.core import ( @@ -37,13 +42,69 @@ ) from .ax_metric import AxMetric - # Define generator states. 
NOT_STARTED = "not_started" LOFI_RETURNED = "lofi_returned" HIFI_RETURNED = "hifi_returned" +# get_MTGP is not part of the Ax codebase, as of Ax 0.4.1, due to this PR: +# https://github.com/facebook/Ax/pull/2508 +# Here we use `get_MTGP` from https://ax.dev/tutorials/multi_task.html#4b.-Multi-task-Bayesian-optimization +def get_MTGP( + experiment: Experiment, + data: Data, + search_space: Optional[SearchSpace] = None, + trial_index: Optional[int] = None, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.double, +) -> TorchModelBridge: + """Instantiates a Multi-task Gaussian Process (MTGP) model that generates + points with EI. + """ + trial_index_to_type = { + t.index: t.trial_type for t in experiment.trials.values() + } + transforms = MT_MTGP_trans + transform_configs = { + "TrialAsTask": {"trial_level_map": {"trial_type": trial_index_to_type}}, + "ConvertMetricNames": tconfig_from_mt_experiment(experiment), + } + + # Choose the status quo features for the experiment from the selected trial. + # If trial_index is None, we will look for a status quo from the last + # experiment trial to use as a status quo for the experiment. 
+ if trial_index is None: + trial_index = len(experiment.trials) - 1 + elif trial_index >= len(experiment.trials): + raise ValueError( + "trial_index is bigger than the number of experiment trials" + ) + + status_quo = experiment.trials[trial_index].status_quo + if status_quo is None: + status_quo_features = None + else: + status_quo_features = ObservationFeatures( + parameters=status_quo.parameters, + trial_index=trial_index, # pyre-ignore[6] + ) + + return checked_cast( + TorchModelBridge, + Models.ST_MTGP( + experiment=experiment, + search_space=search_space or experiment.search_space, + data=data, + transforms=transforms, + transform_configs=transform_configs, + torch_dtype=dtype, + torch_device=device, + status_quo_features=status_quo_features, + ), + ) + + class AxMultitaskGenerator(AxGenerator): """Multitask Bayesian optimization using the Ax developer API. @@ -307,7 +368,9 @@ def _create_experiment(self) -> MultiTypeExperiment: ) # Register metric in order to be able to save experiment to json file. - _, encoder_registry, decoder_registry = register_metric(AxMetric) + _, encoder_registry, decoder_registry = register_metrics( + {AxMetric: None} + ) self._encoder_registry = encoder_registry self._decoder_registry = decoder_registry diff --git a/optimas/generators/ax/import_error_dummy_generator.py b/optimas/generators/ax/import_error_dummy_generator.py index 0e1a3215..436476dd 100644 --- a/optimas/generators/ax/import_error_dummy_generator.py +++ b/optimas/generators/ax/import_error_dummy_generator.py @@ -12,5 +12,5 @@ def __init__(self, *args, **kwargs) -> None: raise RuntimeError( "You need to install ax-platform, in order " "to use Ax-based generators in optimas.\n" - "e.g. with `pip install ax-platform >= 0.4.0`" + "e.g. 
with `pip install 'ax-platform>0.4.0'`" ) diff --git a/optimas/generators/ax/service/base.py b/optimas/generators/ax/service/base.py index 3d2bf42b..02ec2b8c 100644 --- a/optimas/generators/ax/service/base.py +++ b/optimas/generators/ax/service/base.py @@ -205,7 +205,7 @@ def _tell(self, trials: List[Trial]) -> None: # i.e., min trials). if isinstance(tc, (MinTrials, MaxTrials)): tc.threshold -= 1 - generation_strategy._maybe_move_to_next_step() + generation_strategy._maybe_transition_to_next_node() finally: if trial.ignored: continue diff --git a/optimas/generators/ax/service/single_fidelity.py b/optimas/generators/ax/service/single_fidelity.py index 3ec93dc3..c5179a22 100644 --- a/optimas/generators/ax/service/single_fidelity.py +++ b/optimas/generators/ax/service/single_fidelity.py @@ -141,9 +141,6 @@ def _create_generation_steps( else: # Use a SAAS model with qNEI acquisition function. MODEL_CLASS = Models.FULLYBAYESIAN - # Disable additional logs from fully Bayesian model. - bo_model_kwargs["disable_progbar"] = True - bo_model_kwargs["verbose"] = False else: if len(self.objectives) > 1: # Use a model with qNEHVI acquisition function. diff --git a/pyproject.toml b/pyproject.toml index 1c61e28d..8a82eb3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,11 +34,11 @@ test = [ 'flake8', 'pytest', 'pytest-mpi', - 'ax-platform == 0.4.0', + 'ax-platform > 0.4.0', 'matplotlib', ] all = [ - 'ax-platform == 0.4.0', + 'ax-platform > 0.4.0', ] [project.urls]