Allow running on Augusta cluster (#80)
epwalsh authored Nov 5, 2024
1 parent c7c3a5a commit bec0a3c
Showing 9 changed files with 226 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added `retries` field to `BeakerLaunchConfig`.
+- Allow running on Augusta cluster with existing train scripts.
 
 ## [v1.6.0](https://github.com/allenai/OLMo-core/releases/tag/v1.6.0) - 2024-11-01
 
1 change: 1 addition & 0 deletions src/olmo_core/data/numpy_dataset.py
@@ -486,6 +486,7 @@ def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor:
     def _get_file_size_and_length(
         self, path: PathOrStr, idx: int, dtype: Optional[NumpyUIntTypes] = None
     ) -> Tuple[int, int]:
+        del idx
         dtype = dtype or self.dtype
         item_size = dtype(0).itemsize
         file_size = get_file_size(path)
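The added `del idx` simply marks the index argument as intentionally unused (the file size and length are computed from the path alone), which quiets linters while keeping the method signature uniform across the dataset classes.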
33 changes: 25 additions & 8 deletions src/olmo_core/distributed/utils.py
@@ -38,10 +38,11 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
 
     # Set host-specific env var defaults.
     if _running_in_beaker():
+        multi_node = int(os.environ.get(OLMO_NUM_NODES_ENV_VAR, "1")) > 1
         # See https://beaker-docs.apps.allenai.org/experiments/distributed-training.html
         if "jupiter" in get_node_hostname():
             set_env_var("NCCL_IB_HCA", "^=mlx5_bond_0")
-            if int(os.environ.get(OLMO_NUM_NODES_ENV_VAR, "1")) > 1:
+            if multi_node:
                 # Only for multi-node
                 set_env_var("NCCL_SOCKET_IFNAME", "ib")
         elif "pluto" in get_node_hostname():
@@ -68,11 +69,25 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
             set_env_var(
                 "NCCL_FASTRAK_IFNAME",
                 "enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
             )
-            set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
             set_env_var("NCCL_USE_SNAP", "1")
             set_env_var("NCCL_FASTRAK_USE_LLCM", "1")
-
-    validate_env_vars()
+            set_env_var("NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY", "/dev/aperture_devices")
+            # NOTE: This path var must be set prior to launching Python
+            # set_env_var(
+            #     "LD_LIBRARY_PATH",
+            #     "/var/lib/tcpxo/lib64:" + os.environ.get("LD_LIBRARY_PATH", ""),
+            #     override=True,
+            # )
+            set_env_var("NCCL_TUNER_PLUGIN", "libnccl-tuner.so")
+            set_env_var(
+                "NCCL_TUNER_CONFIG_PATH", "/var/lib/tcpxo/lib64/a3plus_tuner_config.textproto"
+            )
+            set_env_var(
+                "NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE",
+                "/var/lib/tcpxo/lib64/a3plus_guest_config.textproto",
+            )
+            if multi_node:
+                set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
 
     if backend_supports_cuda(backend):
         # Set CUDA device.
@@ -83,6 +98,8 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
 
     dist.init_process_group(backend, timeout=timeout)
 
+    validate_env_vars()
+
 
 def validate_env_vars():
     """
@@ -94,12 +111,12 @@ def validate_env_vars():
     if OLMO_LOCAL_RANK_ENV_VAR not in os.environ:
         raise OLMoEnvironmentError(f"Missing env var '{OLMO_LOCAL_RANK_ENV_VAR}'")
 
-    if (
-        os.environ.get(OLMO_SHARED_FS_ENV_VAR) != "1"
-        and os.environ.get(OLMO_FS_LOCAL_RANK_ENV_VAR) is None
+    if os.environ.get(OLMO_SHARED_FS_ENV_VAR) != "1" and (
+        os.environ.get(OLMO_FS_LOCAL_RANK_ENV_VAR) is None
+        and os.environ.get(OLMO_LOCAL_RANK_ENV_VAR) is None
     ):
         raise OLMoEnvironmentError(
-            f"Missing env var '{OLMO_FS_LOCAL_RANK_ENV_VAR}' for non-shared filesystem. "
+            f"Missing env var '{OLMO_FS_LOCAL_RANK_ENV_VAR}'/'{OLMO_LOCAL_RANK_ENV_VAR}' for non-shared filesystem. "
             f"If this is a shared filesystem you can set '{OLMO_SHARED_FS_ENV_VAR}=1' instead."
         )
 
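These cluster branches lean on `set_env_var` behaving as a set-a-default helper: values already exported by the user or the launcher win unless `override=True` is passed, which is what the commented-out `LD_LIBRARY_PATH` call would have needed. Below is a minimal sketch of that behavior, assuming the signature implied by the commented-out call; the real helper lives in `olmo_core.distributed.utils` and may log or validate differently.

    import logging
    import os

    log = logging.getLogger(__name__)


    def set_env_var(name: str, value: str, override: bool = False) -> None:
        # Treat cluster defaults as fallbacks: never clobber a value that the
        # user or the launcher has already exported, unless override=True.
        if name in os.environ and not override:
            log.info(f"Leaving env var {name}='{os.environ[name]}' as-is")
            return
        os.environ[name] = value
        log.info(f"Set env var {name}='{value}'")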
6 changes: 4 additions & 2 deletions src/olmo_core/internal/experiment.py
@@ -123,6 +123,8 @@ def build_common_components(
     if "jupiter" in cluster:
         root_dir = "/weka/oe-training-default/ai2-llm"
         weka_buckets.append(BeakerWekaBucket("oe-training-default", "/weka/oe-training-default"))
+    elif "augusta" in cluster:
+        root_dir = "gs://ai2-llm"
 
     beaker_user = (Beaker.from_env().account.whoami().name).upper()
     cmd_to_launch = SubCmd.train
@@ -181,7 +183,7 @@ def build_common_components(
             name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False
         ),
         work_dir=(
-            None
+            "./dataset-cache"
             if is_url(root_dir)
             else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache"
         ),
@@ -205,7 +207,7 @@ def build_common_components(
         sequence_length=dataset_config.effective_sequence_length,
         tokenizer=tokenizer_config,
         work_dir=(
-            None
+            "./dataset-cache"
             if is_url(root_dir)
             else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache"
         ),
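Falling back to "./dataset-cache" is what makes a gs:// root dir workable: the dataset cache has to live on local disk, so a remote URL root routes it to a relative per-node directory instead. A standalone sketch of the selection logic, with a simplified `is_url` stand-in (the real helper comes from the repo's IO utilities and may handle more schemes):

    from urllib.parse import urlparse


    def is_url(path: str) -> bool:
        # Anything with a URL scheme (gs://, s3://, https://, ...) is remote.
        return urlparse(str(path)).scheme not in ("", "file")


    def dataset_work_dir(root_dir: str, beaker_user: str) -> str:
        if is_url(root_dir):
            # e.g. root_dir="gs://ai2-llm" on Augusta: cache on local disk instead.
            return "./dataset-cache"
        return f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache"


    assert dataset_work_dir("gs://ai2-llm", "EPWALSH") == "./dataset-cache"
    assert dataset_work_dir("/weka/oe-training-default/ai2-llm", "EPWALSH") == (
        "/weka/oe-training-default/ai2-llm/checkpoints/epwalsh/dataset-cache"
    )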
1 change: 1 addition & 0 deletions src/olmo_core/launch/beaker.py
@@ -316,6 +316,7 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
         entrypoint_script = [
             "#!/usr/bin/env bash",
             "set -exuo pipefail",
+            "[[ -d /var/lib/tcpxo/lib64 ]] && export LD_LIBRARY_PATH=/var/lib/tcpxo/lib64:$LD_LIBRARY_PATH",
             "mkdir -p /olmo-core-runtime",
             "cd /olmo-core-runtime",
             *self.setup_steps,
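This export lives in the bash entrypoint rather than in `init_distributed` because the dynamic loader only reads `LD_LIBRARY_PATH` at process startup; by the time Python code runs it is too late for NCCL to pick up the TCPXO plugin, which is exactly what the commented-out block in `distributed/utils.py` notes. If it ever had to happen from Python, the usual workaround is to patch the variable and re-exec, sketched here as a hypothetical launcher snippet (not part of this commit):

    import os
    import sys

    TCPXO_LIB_DIR = "/var/lib/tcpxo/lib64"

    # Hypothetical Python-side equivalent of the bash guard: prepend the TCPXO
    # libs and re-exec so the new LD_LIBRARY_PATH is seen at load time. The
    # membership check keeps this from re-exec'ing forever.
    if os.path.isdir(TCPXO_LIB_DIR) and TCPXO_LIB_DIR not in os.environ.get("LD_LIBRARY_PATH", ""):
        os.environ["LD_LIBRARY_PATH"] = f"{TCPXO_LIB_DIR}:{os.environ.get('LD_LIBRARY_PATH', '')}"
        os.execv(sys.executable, [sys.executable] + sys.argv)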
4 changes: 2 additions & 2 deletions src/olmo_core/train/callbacks/callback.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Dict
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional
 
 import torch
 
@@ -177,7 +177,7 @@ class CallbackConfig(Callback, Config):
     """
 
     @abstractmethod
-    def build(self, trainer: "Trainer") -> Callback:
+    def build(self, trainer: "Trainer") -> Optional[Callback]:
         """
         Build the actual :class:`Callback`.
         """
12 changes: 10 additions & 2 deletions src/olmo_core/train/callbacks/evaluator_callback.py
@@ -121,8 +121,12 @@ class LMEvaluatorCallbackConfig(CallbackConfig):
     eval_batch_size: Optional[int] = None
     eval_duration: Duration = field(default_factory=lambda: Duration.epochs(1))
     log_interval: int = 5
+    enabled: bool = True
 
-    def build(self, trainer: "Trainer") -> Callback:
+    def build(self, trainer: "Trainer") -> Optional[Callback]:
+        if not self.enabled:
+            return None
+
         eval_batch_size = (
             self.eval_batch_size
             if self.eval_batch_size is not None
@@ -227,8 +231,12 @@ class DownstreamEvaluatorCallbackConfig(CallbackConfig):
     eval_interval: int = 1000
     eval_duration: Duration = field(default_factory=lambda: Duration.epochs(1))
     log_interval: int = 5
+    enabled: bool = True
 
-    def build(self, trainer: "Trainer") -> Callback:
+    def build(self, trainer: "Trainer") -> Optional[Callback]:
+        if not self.enabled:
+            return None
+
         from olmo_eval import HFTokenizer
 
         global_eval_batch_size = (
3 changes: 2 additions & 1 deletion src/olmo_core/train/config.py
@@ -117,6 +117,7 @@ def build(
 
         for cb_name, cb_config in callback_configs.items():
             cb = cb_config.build(trainer)
-            trainer.add_callback(cb_name, cb)
+            if cb is not None:
+                trainer.add_callback(cb_name, cb)
 
         return trainer
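The three callback changes above form one pattern: a config may now decline to produce a callback (here, when `enabled=False`), and the trainer only registers non-None results. A self-contained sketch of the pattern with stand-in classes (not the real `olmo_core` types):

    from dataclasses import dataclass
    from typing import Dict, Optional


    class Callback:  # stand-in for olmo_core.train.callbacks.Callback
        pass


    @dataclass
    class EchoCallback(Callback):
        message: str


    @dataclass
    class EchoCallbackConfig:
        # Mirrors the new contract: build() may return None to opt out.
        enabled: bool = True
        message: str = "hello"

        def build(self) -> Optional[Callback]:
            if not self.enabled:
                return None
            return EchoCallback(self.message)


    callbacks: Dict[str, Callback] = {}
    for name, cfg in {"echo": EchoCallbackConfig(enabled=False)}.items():
        cb = cfg.build()
        if cb is not None:  # the None check added in train/config.py
            callbacks[name] = cb

    assert "echo" not in callbacks  # disabled callback was never registered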
180 changes: 180 additions & 0 deletions src/scripts/train/all_reduce_bench.py
@@ -0,0 +1,180 @@
"""
Run an all-reduce benchmark. Run this script without any arguments to see usage info.
"""

from __future__ import annotations

import logging
import os
import sys
from dataclasses import dataclass
from typing import List

import torch
import torch.distributed as dist

from olmo_core.config import Config, StrEnum
from olmo_core.distributed.utils import get_local_rank, get_world_size
from olmo_core.launch.beaker import BeakerLaunchConfig, OLMoCoreBeakerImage
from olmo_core.train import prepare_training_environment, teardown_training_environment
from olmo_core.utils import generate_uuid, prepare_cli_environment

log = logging.getLogger(__name__)


TRIALS = 5

# these emulate the payload, which will become an M * N * 4-byte tensor below
N = 500000
M = 2000


class SubCmd(StrEnum):
    launch = "launch"
    run = "run"
    dry_run = "dry_run"

    def prepare_environment(self):
        if self in (SubCmd.launch, SubCmd.dry_run):
            prepare_cli_environment()
        elif self == SubCmd.run:
            prepare_training_environment()
        else:
            raise NotImplementedError(self)

    def execute(self, config: BenchmarkConfig):
        log.info(config)
        if self == SubCmd.launch:
            config.launch.launch(follow=True)
        elif self == SubCmd.dry_run:
            pass
        elif self == SubCmd.run:
            try:
                # Show env vars for debugging.
                for var_name in sorted(os.environ.keys()):
                    var_val = os.environ[var_name]
                    log.info(f"Env var {var_name} set to '{var_val}'")

                mat = torch.rand(N, M, dtype=torch.float32).cuda(get_local_rank())

                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                # do a few warm up iterations
                for i in range(2):
                    timed_allreduce(mat, start_event, end_event)

                # real benchmark
                algbw_gather = []
                for i in range(TRIALS):
                    log.info(f"{i+1}")
                    algbw_gather += timed_allreduce(mat, start_event, end_event)

                algbw = torch.mean(torch.stack(algbw_gather))

                # the 2*(n-1)/n busbw correction factor specific to all-reduce is explained here:
                # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allreduce
                # busbw reflects how optimally the hardware is used
                n = dist.get_world_size()
                busbw = algbw * (2 * (n - 1) / n)

                log.info(
                    f"The average bandwidth of all_reduce with a {M*N*4/1e9}GB payload ({TRIALS} trials, {n} ranks):\n"
                    f"algbw: {algbw/1e9:.3f} GBps ({algbw*8/1e9:.1f} Gbps)\n"
                    f"busbw: {busbw/1e9:.3f} GBps ({busbw*8/1e9:.1f} Gbps)\n"
                )
            finally:
                teardown_training_environment()
        else:
            raise NotImplementedError(self)


@dataclass
class BenchmarkConfig(Config):
    launch: BeakerLaunchConfig


def build_config(script: str, run_name: str, cluster: str, overrides: List[str]) -> BenchmarkConfig:
    launch_config = BeakerLaunchConfig(
        name=f"{run_name}-{generate_uuid()[:8]}",
        budget="ai2/oe-training",
        cmd=[script, SubCmd.run, run_name, cluster, *overrides],
        task_name="benchmark",
        workspace="ai2/OLMo-core",
        clusters=[cluster],
        beaker_image=OLMoCoreBeakerImage.nightly,  # some features require nightly at the moment
        num_nodes=1,
        num_gpus=8,
        allow_dirty=False,
        setup_steps=[
            # Clone repo.
            'git clone "$REPO_URL" .',
            'git checkout "$GIT_REF"',
            "git submodule update --init --recursive",
            # Setup python environment.
            "conda shell.bash activate base",
            "pip install -e '.[all]'",
            "pip freeze",
        ],
    )

    return BenchmarkConfig(launch=launch_config).merge(overrides)


def timed_allreduce(mat, start_event, end_event):
    dist.barrier()
    start_event.record()
    dist.all_reduce(mat)
    end_event.record()

    torch.cuda.synchronize()
    duration = start_event.elapsed_time(end_event) / 1000

    size = M * N * 4  # 4 is 4 bytes in fp32
    # note that this is following the same math as NVIDIA/nccl-tests
    algbw = torch.tensor([size / duration]).cuda(get_local_rank())

    # calculate mean across all ranks
    dist.reduce(algbw, dst=0, op=dist.ReduceOp.SUM)
    algbw /= get_world_size()

    return algbw


def main():
    usage = f"""
[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]RUN_NAME CLUSTER[/] [i][OVERRIDES...][/]

[b]Subcommands[/]
[b magenta]launch:[/]  Launch the benchmark on Beaker with the [b magenta]run[/] subcommand.
[b magenta]run:[/]     Run the benchmark. You usually shouldn't invoke the script with this subcommand directly.
           Instead use [b magenta]launch[/] or run it with torchrun.
[b magenta]dry_run:[/] Pretty print the config and exit.

[b]Examples[/]
$ [i]python {sys.argv[0]} {SubCmd.launch} run01 ai2/pluto-cirrascale --launch.num_nodes=2[/]
""".strip()

    if len(sys.argv) < 4 or sys.argv[1] not in set(SubCmd):
        import rich

        rich.get_console().print(usage, highlight=False)
        sys.exit(1)

    script, cmd, run_name, cluster, *overrides = sys.argv

    cmd = SubCmd(cmd)
    cmd.prepare_environment()

    config = build_config(
        script,
        run_name,
        cluster,
        overrides,
    )

    cmd.execute(config)


if __name__ == "__main__":
    main()
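For intuition on the correction factor: in a ring all-reduce each byte effectively crosses the interconnect about 2*(n-1)/n times, so busbw scales algbw by that amount. A quick sanity check with an assumed measurement (the payload matches the script's M * N * 4 bytes; the timing and rank count are made up for illustration):

    # Sanity check of the 2*(n-1)/n all-reduce correction factor.
    size = 2000 * 500_000 * 4  # M * N * 4 bytes = 4e9 bytes, as in the script
    duration = 0.050           # seconds (hypothetical measurement)
    n = 8                      # ranks (hypothetical)

    algbw = size / duration            # 8.0e10 B/s = 80 GBps
    busbw = algbw * (2 * (n - 1) / n)  # 80 * 1.75 = 140 GBps

    print(f"algbw: {algbw/1e9:.1f} GBps, busbw: {busbw/1e9:.1f} GBps")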
