[beta] Enable compiling optimizer step (tested with AdamW) (#103)
Also fixes single-node training on Augusta.
epwalsh authored Nov 14, 2024
1 parent fdbb76e commit b9e9193
Showing 6 changed files with 69 additions and 21 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -10,12 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added `olmo_core.distributed.checkpoint.get_checkpoint_metadata()` function.
- (BETA) Added flag to compile the optimizer step. So far only tested with AdamW. May not work with other optimizers.

### Fixed

- Old ephemeral checkpoints won't be removed until after the latest ephemeral checkpoint is saved successfully.
- Made GCS uploads more robust.
- numpy.random.dirichlet() does not always sum to 1.0, so allow for a small tolerance in validating domain weights.
- Fixed single-node training on Google Augusta cluster.
- `numpy.random.dirichlet()` does not always sum to 1.0, so allow for a small tolerance in validating domain weights.

## [v1.6.2](https://github.com/allenai/OLMo-core/releases/tag/v1.6.2) - 2024-11-08

15 changes: 8 additions & 7 deletions src/olmo_core/distributed/utils.py
@@ -52,6 +52,8 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
elif "pluto" in get_node_hostname():
set_env_var("NCCL_IB_HCA", "^=mlx5_1,mlx5_2")
elif "augusta" in get_node_hostname():
# NOTE: For single-node training we still need all of these settings and we also
# need host networking enabled so that the ethernet interface names don't change.
set_env_var("NCCL_CROSS_NIC", "0")
set_env_var("NCCL_ALGO", "Ring,Tree")
set_env_var("NCCL_PROTO", "Simple")
@@ -68,11 +70,6 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
set_env_var("NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING", "0")
set_env_var("NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS", "600000")
set_env_var("NCCL_NVLS_ENABLE", "0")
set_env_var("NCCL_FASTRAK_CTRL_DEV", "enp0s12")
set_env_var(
"NCCL_FASTRAK_IFNAME",
"enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
)
set_env_var("NCCL_USE_SNAP", "1")
set_env_var("NCCL_FASTRAK_USE_LLCM", "1")
set_env_var("NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY", "/dev/aperture_devices")
@@ -90,8 +87,12 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
"NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE",
"/var/lib/tcpxo/lib64/a3plus_guest_config.textproto",
)
if multi_node:
set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
set_env_var("NCCL_FASTRAK_CTRL_DEV", "enp0s12")
set_env_var(
"NCCL_FASTRAK_IFNAME",
"enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
)
set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")

if backend_supports_cuda(backend):
# Set CUDA device.
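For reference, a quick way to confirm the effect of this change is to inspect the environment after `init_distributed()` has run on an Augusta node. This is a minimal sketch, assuming the process is launched via torchrun (even for a single node) so the usual rendezvous variables are set; the NCCL variable names are taken from the diff above.

import os

from olmo_core.distributed.utils import init_distributed

# init_distributed() detects the Augusta hostname and applies the NCCL settings above,
# including the interface variables that are now set for single-node runs as well.
init_distributed()

for var in ("NCCL_FASTRAK_CTRL_DEV", "NCCL_FASTRAK_IFNAME", "NCCL_SOCKET_IFNAME"):
    print(var, "=", os.environ.get(var))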
3 changes: 2 additions & 1 deletion src/olmo_core/launch/beaker.py
@@ -339,7 +339,8 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
command=["bash", "/olmo-core/entrypoint.sh"],
replicas=self.num_nodes if self.num_nodes > 1 else None,
leader_selection=self.num_nodes > 1,
host_networking=self.num_nodes > 1,
host_networking=self.num_nodes > 1
or any(["augusta" in cluster for cluster in self.clusters]),
propagate_failure=True if self.num_nodes > 1 else None,
propagate_preemption=True if self.num_nodes > 1 else None,
synchronized_start_timeout="90m" if self.num_nodes > 1 else None,
32 changes: 28 additions & 4 deletions src/olmo_core/optim/config.py
@@ -10,6 +10,7 @@

from ..config import Config
from ..exceptions import OLMoConfigurationError
from ..utils import get_default_device, move_to_device

__all__ = [
"OptimConfig",
@@ -45,6 +46,21 @@ class OptimConfig(Config, Generic[Opt], metaclass=ABCMeta):
    Use this to pull out certain parameters into separate param groups with their own options.
    """

    compile: bool = False
    """
    Compile the optimizer step.

    .. warning::
        Optimizer step compilation is still in beta and may not work with some optimizers.
        You may also see unexpected behavior and very poor performance if you turn this feature
        on in the middle of a run that previously trained without compiling the optimizer,
        since the LR would be restored as a float instead of a tensor.
    """

    @property
    def device(self) -> torch.device:
        return get_default_device()

    def build_groups(self, model: nn.Module) -> Union[Iterable[torch.Tensor], List[Dict[str, Any]]]:
        """
        Build parameter groups.
@@ -108,20 +124,28 @@ def build(self, model: nn.Module) -> Opt:
"""
kwargs = self.as_dict()
kwargs.pop("group_overrides")
kwargs.pop("compile")

optim = self.optimizer()(self.build_groups(model), **kwargs)

# Set 'lr' and 'initial_lr' in each group if needed.
for group in optim.param_groups:
# Set 'initial_lr' in each group for schedulers if needed.
if "initial_lr" in group:
continue

lr: Optional[float] = None
if "lr" in group:
lr = group["lr"]
elif hasattr(self, "lr"):
lr = getattr(self, "lr")

if lr is not None:
if self.compile:
# 'lr' should be a tensor.
group["lr"] = move_to_device(torch.tensor(lr), self.device)
else:
group["lr"] = lr
group.setdefault("initial_lr", lr)

if self.compile:
log.info("Compiling optimizer step...")
optim.step = torch.compile(optim.step)

return optim
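As a usage sketch of the new flag: building an optimizer from a config with `compile=True` should leave each param group's "lr" as a tensor (so later LR updates do not force a recompile) and wrap `optim.step` with `torch.compile`. The snippet below assumes an `AdamWConfig` subclass of `OptimConfig` with an `lr` field; treat that class name and its fields as assumptions rather than something shown in this diff.

import torch
import torch.nn as nn

from olmo_core.optim import AdamWConfig  # assumed config class, not part of this diff

model = nn.Linear(16, 16)
optim = AdamWConfig(lr=1e-3, compile=True).build(model)

# With compile=True the LR is stored as a tensor in each param group,
# while 'initial_lr' keeps the original float value for schedulers.
for group in optim.param_groups:
    assert isinstance(group["lr"], torch.Tensor)
    assert group["initial_lr"] == 1e-3

# optim.step has also been wrapped with torch.compile() at this point.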
23 changes: 16 additions & 7 deletions src/olmo_core/optim/scheduler.py
@@ -1,7 +1,9 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from math import cos, pi
from typing import Optional
from typing import Optional, Union

import torch


@dataclass
@@ -14,7 +16,9 @@ class Scheduler(metaclass=ABCMeta):
    initial_lr_field: str = "initial_lr"

    @abstractmethod
    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], step: int, max_steps: int
    ) -> Union[float, torch.Tensor]:
        """
        Get the learning rate for a step given the initial/max learning rate and the maximum
        number of steps.
@@ -28,7 +32,9 @@ class ConstantScheduler(Scheduler):
    Constant learning rate schedule, basically a no-op.
    """

    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], step: int, max_steps: int
    ) -> Union[float, torch.Tensor]:
        del step, max_steps
        return initial_lr

@@ -44,7 +50,9 @@ class CosWithWarmup(Scheduler):
    t_max: Optional[int] = None
    warmup_min_lr: float = 0.0

    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
    def get_lr(
        self, initial_lr: Union[float, torch.Tensor], step: int, max_steps: int
    ) -> Union[float, torch.Tensor]:
        max_steps = max_steps if self.t_max is None else self.t_max
        eta_min = initial_lr * self.alpha_f
        if step < self.warmup_steps:
@@ -58,7 +66,8 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:


def _linear_warmup(
    initial_lr: float, step: int, warmup_steps: int, warmup_min_lr: float = 0.0
) -> float:
    assert 0 <= warmup_min_lr < initial_lr
    initial_lr: Union[float, torch.Tensor], step: int, warmup_steps: int, warmup_min_lr: float = 0.0
) -> Union[float, torch.Tensor]:
    if isinstance(initial_lr, float):  # not worth the potential host-device sync if it's a tensor
        assert 0 <= warmup_min_lr < initial_lr
    return warmup_min_lr + (initial_lr - warmup_min_lr) * min(step, warmup_steps) / warmup_steps
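To illustrate the widened scheduler API: `get_lr()` now accepts either a plain float or a tensor and returns the same kind of value, and the warmup sanity check is only performed for floats to avoid a potential host-device sync. A minimal sketch follows; the `CosWithWarmup` fields used here appear in the diff, but the exact constructor defaults are an assumption.

import torch

from olmo_core.optim.scheduler import CosWithWarmup

sched = CosWithWarmup(warmup_steps=100)

# Float LR: the pre-existing behavior.
lr_float = sched.get_lr(1e-3, step=50, max_steps=1000)

# Tensor LR: used when the optimizer step is compiled; the result remains a tensor,
# so it can be written back into the param group in place.
lr_tensor = sched.get_lr(torch.tensor(1e-3), step=50, max_steps=1000)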
13 changes: 12 additions & 1 deletion src/olmo_core/train/callbacks/scheduler.py
@@ -1,5 +1,7 @@
from dataclasses import dataclass, field

import torch

from olmo_core.optim.scheduler import ConstantScheduler, Scheduler

from .callback import Callback
@@ -23,11 +25,20 @@ def pre_optim_step(self):
f"'{initial_lr_field}' not found in optimizer param group"
)

# Ensure 'initial_lr' is set.
if group.get(self.scheduler.initial_lr_field) is None:
group[self.scheduler.initial_lr_field] = group["lr"]
group[self.scheduler.lr_field] = self.scheduler.get_lr(

# Set new LR.
new_lr = self.scheduler.get_lr(
group[self.scheduler.initial_lr_field], self.step, self.trainer.max_steps
)

if isinstance(current_lr := group.get(self.scheduler.lr_field), torch.Tensor):
current_lr.fill_(new_lr)
else:
group[self.scheduler.lr_field] = new_lr

self.trainer.record_metric(
f"optim/LR (group {group_idx})", group[self.scheduler.lr_field]
)

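The reason for the `fill_()` branch above is worth spelling out: when the optimizer step has been compiled, the group's LR entry is a tensor, so the callback writes the new value into that tensor in place rather than rebinding the key. Rebinding it to a float would change the type seen by the compiled step and could force a recompile. A standalone sketch of the pattern, with hypothetical values not tied to a real trainer:

import torch

group = {"lr": torch.tensor(1.0e-3), "initial_lr": 1.0e-3}
new_lr = 5.0e-4

if isinstance(current_lr := group.get("lr"), torch.Tensor):
    # In-place update keeps the LR a tensor, so the compiled optimizer step is unaffected.
    current_lr.fill_(new_lr)
else:
    group["lr"] = new_lr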