Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow optimas to run with latest Ax version #239

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dependencies:
- pip
- pip:
- -e ..
- ax-platform == 0.4.0
- ax-platform >= 0.4.0
- autodoc_pydantic >= 2.0.1
- ipykernel
- matplotlib
Expand Down
75 changes: 69 additions & 6 deletions optimas/generators/ax/developer/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

import numpy as np
import torch
from packaging import version

from ax.version import version as ax_version
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.multi_type_experiment import MultiTypeExperiment
Expand All @@ -22,8 +20,15 @@
from ax.core.observation import ObservationFeatures
from ax.core.generator_run import GeneratorRun
from ax.storage.json_store.save import save_experiment
from ax.storage.metric_registry import register_metric
from ax.modelbridge.factory import get_MTGP_LEGACY as get_MTGP
from ax.storage.metric_registry import register_metrics

from ax.modelbridge.registry import Models, MT_MTGP_trans
from ax.core.experiment import Experiment
from ax.core.data import Data
from ax.modelbridge.transforms.convert_metric_names import (
tconfig_from_mt_experiment,
)
from ax.utils.common.typeutils import checked_cast

from optimas.generators.ax.base import AxGenerator
from optimas.core import (
Expand All @@ -37,13 +42,69 @@
)
from .ax_metric import AxMetric


# Define generator states.
NOT_STARTED = "not_started"
LOFI_RETURNED = "lofi_returned"
HIFI_RETURNED = "hifi_returned"


# `get_MTGP` was removed from the Ax codebase in Ax 0.4.1 by this PR:
# https://github.com/facebook/Ax/pull/2508
# We therefore vendor the `get_MTGP` implementation from the Ax tutorial at
# https://ax.dev/tutorials/multi_task.html#4b.-Multi-task-Bayesian-optimization
def get_MTGP(
    experiment: Experiment,
    data: Data,
    search_space: Optional[SearchSpace] = None,
    trial_index: Optional[int] = None,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.double,
) -> TorchModelBridge:
    """Build a multi-task Gaussian process (MTGP) model bridge.

    The returned model generates candidate points with EI. Vendored from the
    Ax multi-task tutorial, since ``get_MTGP`` was removed from Ax itself.

    Parameters
    ----------
    experiment : Experiment
        The (multi-type) experiment whose trials provide the training data.
    data : Data
        Observations used to fit the model.
    search_space : SearchSpace, optional
        Search space to generate in; defaults to the experiment's own.
    trial_index : int, optional
        Trial whose status quo arm anchors the model. Defaults to the last
        trial of the experiment.
    device : torch.device, optional
        Torch device on which to fit the model (CPU by default).
    dtype : torch.dtype, optional
        Torch dtype used for fitting (double by default).

    Raises
    ------
    ValueError
        If ``trial_index`` exceeds the number of experiment trials.
    """
    # Map each trial to its type so the TrialAsTask transform can treat
    # trial types as tasks.
    type_by_index = {
        trial.index: trial.trial_type
        for trial in experiment.trials.values()
    }
    transform_configs = {
        "TrialAsTask": {"trial_level_map": {"trial_type": type_by_index}},
        "ConvertMetricNames": tconfig_from_mt_experiment(experiment),
    }

    # Pick the trial supplying the status quo: fall back to the last trial
    # when no index was given; reject out-of-range indices.
    if trial_index is None:
        trial_index = len(experiment.trials) - 1
    elif trial_index >= len(experiment.trials):
        raise ValueError(
            "trial_index is bigger than the number of experiment trials"
        )

    # Wrap the trial's status quo arm (if any) as observation features.
    status_quo = experiment.trials[trial_index].status_quo
    status_quo_features = None
    if status_quo is not None:
        status_quo_features = ObservationFeatures(
            parameters=status_quo.parameters,
            trial_index=trial_index,  # pyre-ignore[6]
        )

    model_bridge = Models.ST_MTGP(
        experiment=experiment,
        search_space=search_space or experiment.search_space,
        data=data,
        transforms=MT_MTGP_trans,
        transform_configs=transform_configs,
        torch_dtype=dtype,
        torch_device=device,
        status_quo_features=status_quo_features,
    )
    return checked_cast(TorchModelBridge, model_bridge)


class AxMultitaskGenerator(AxGenerator):
"""Multitask Bayesian optimization using the Ax developer API.

Expand Down Expand Up @@ -307,7 +368,9 @@ def _create_experiment(self) -> MultiTypeExperiment:
)

# Register metric in order to be able to save experiment to json file.
_, encoder_registry, decoder_registry = register_metric(AxMetric)
_, encoder_registry, decoder_registry = register_metrics(
{AxMetric, None}
)
self._encoder_registry = encoder_registry
self._decoder_registry = decoder_registry

Expand Down
2 changes: 1 addition & 1 deletion optimas/generators/ax/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _tell(self, trials: List[Trial]) -> None:
# i.e., min trials).
if isinstance(tc, (MinTrials, MaxTrials)):
tc.threshold -= 1
generation_strategy._maybe_move_to_next_step()
generation_strategy._maybe_transition_to_next_node()
finally:
if trial.ignored:
continue
Expand Down
3 changes: 0 additions & 3 deletions optimas/generators/ax/service/single_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ def _create_generation_steps(
else:
# Use a SAAS model with qNEI acquisition function.
MODEL_CLASS = Models.FULLYBAYESIAN
# Disable additional logs from fully Bayesian model.
bo_model_kwargs["disable_progbar"] = True
bo_model_kwargs["verbose"] = False
else:
if len(self.objectives) > 1:
# Use a model with qNEHVI acquisition function.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ test = [
'flake8',
'pytest',
'pytest-mpi',
'ax-platform == 0.4.0',
'ax-platform >= 0.4.0',
'matplotlib',
]
all = [
Expand Down
Loading