Add no_sync context manager (#6675)
Fix #1902

---------

Co-authored-by: Logan Adams <[email protected]>
tjruwase and loadams authored Nov 14, 2024
1 parent 9a2c209 commit 9439058
Showing 4 changed files with 229 additions and 14 deletions.
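For orientation, a minimal usage sketch of the API this commit adds, assuming a DeepSpeed engine built with ZeRO stage 0 or 1; the names engine and data_loader are illustrative, not part of the commit:

    # Backward passes inside no_sync() skip gradient allreduce across ranks.
    with engine.no_sync():
        for x, y in data_loader:
            loss = engine(x, y)
            engine.backward(loss)
    # Reduction resumes once the context exits; engine.step() must be
    # called outside the context.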
@@ -14,7 +14,7 @@

# Currently have dependency loops for the type hints.
InferenceModel = Type["InferenceModel"]
LayerContainer = Type["LayerContainer"]
LayerContainer = Type["LayerContainer"] # noqa: F811

MAPPING_KEY = "PARAM_MAPPING"
PLIST_HELPERS = "_ds_plist_strip_vals"
@@ -161,7 +161,7 @@ def __call__(cls, *args, **kwargs):
return instance


class LayerContainer(metaclass=LayerMetaclass):
class LayerContainer(metaclass=LayerMetaclass): # noqa: F811
"""
Abstract base class for containing model parameters.
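For context, a hedged sketch of the pattern flake8 flags as F811 (redefinition of an unused name), which is why the diff above adds the noqa comments: LayerContainer is first bound as a forward-reference type alias and later rebound as the real class.

    from typing import Type

    # Forward-reference alias so type hints can name the class before it exists.
    LayerContainer = Type["LayerContainer"]  # noqa: F811


    # The real definition rebinds the same name, which flake8 would otherwise report as F811.
    class LayerContainer:  # noqa: F811
        pass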
37 changes: 30 additions & 7 deletions deepspeed/runtime/engine.py
@@ -17,6 +17,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from contextlib import contextmanager

from typing import Callable, Dict, Union, Iterable, Container

@@ -216,6 +217,7 @@ def __init__(self,
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True
self.inside_no_sync_ctxt = False
self.progressive_layer_drop = None
self.eigenvalue = None
self.block_eigenvalue = None
@@ -1981,12 +1983,31 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
grads = None
self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)

@contextmanager
def no_sync(self):
r"""
Context manager to disable gradient reduction during backward pass.
This context manager has the following effects on other DeepSpeed features.
1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
2. It is illegal to call engine.step() within the context manager.
3. Tracking of gradient accumulation steps is disabled.
"""
assert not self.zero_optimization_partition_gradients(), \
f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

self.inside_no_sync_ctxt = True
try:
yield
finally:
self.inside_no_sync_ctxt = False

@instrument_w_nvtx
def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):
def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
r"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
allreduce_gradients: is deprecated, ignored, and will soon be removed'
retain_graph: bool, default: false
forward on user defined choice of retain_graph
"""
@@ -1996,11 +2017,10 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_gr
if self.scale_wrt_gas is not None:
scale_wrt_gas = self.scale_wrt_gas

if not allreduce_gradients:
logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed")
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt

# scale loss w.r.t. gradient accumulation if needed
if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
# scale loss w.r.t. gradient accumulation if reduction is not disabled
if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
loss = self._scale_loss_by_gas(loss.float())

# Log training loss
@@ -2049,7 +2069,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_gr

self._start_timers(self.engine_timers.backward_reduce_timers)

if allreduce_gradients and self.enable_backward_allreduce:
if do_gradient_reduction:
# Traditional code path that allreduces the module parameter grads
self.allreduce_gradients()

@@ -2185,6 +2205,9 @@ def step(self, lr_kwargs=None):
r"""Execute the weight update step after forward and backward propagation
on effective_train_batch.
"""
assert not self.inside_no_sync_ctxt, \
"It is illegal to call Engine.step() inside no_sync context manager"

see_memory_usage("Engine before step", force=self.memory_breakdown())

# Check early because self.global_steps is incremented at some point here.
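A hedged sketch of the manual gradient-accumulation idiom the engine.py changes above enable, assuming gradient_accumulation_steps stays at 1 in the config so the engine applies no scaling of its own; engine, micro_batches, and num_micro_batches are illustrative names:

    num_micro_batches = 4  # illustrative count of local accumulation steps

    with engine.no_sync():
        for x, y in micro_batches[:-1]:
            loss = engine(x, y)
            # Inside the context, backward() neither allreduces gradients nor
            # scales the loss, so average manually if that is the intent.
            engine.backward(loss / num_micro_batches)

    # The last micro-batch runs outside the context, so the usual allreduce
    # path in backward() synchronizes the accumulated gradients across ranks.
    x, y = micro_batches[-1]
    loss = engine(x, y)
    engine.backward(loss / num_micro_batches)
    engine.step()  # step() asserts that it is not inside no_sync()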
5 changes: 0 additions & 5 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -2297,11 +2297,6 @@ def load_state_dict(self,
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder)

@property
def param_groups(self):
"""Forward the wrapped optimizer's parameters."""
return self.optimizer.param_groups

def _load_global_state(self, sd):
self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale)
197 changes: 197 additions & 0 deletions tests/unit/runtime/test_no_sync_ctxt.py
@@ -0,0 +1,197 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest

from contextlib import nullcontext
import torch

from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest

import deepspeed
import deepspeed.comm as dist
from deepspeed.utils import safe_get_full_grad


class TestNoSyncCtxt(DistributedTest):
world_size = 2

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("zero_stage", [0, 1, 2, 3])
def test_zero_stage(self, zero_stage, dtype):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3
}
},
"zero_optimization": {
"stage": zero_stage,
},
}

invalid_cfg = zero_stage > 1
if dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
elif dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

hidden_dim = 64
total_samples = 32
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
total_samples=total_samples,
hidden_dim=hidden_dim,
device=model.device,
dtype=dtype)
dist.barrier()

with pytest.raises(AssertionError) if invalid_cfg else nullcontext() as assertinfo:
with model.no_sync():
for _, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
if invalid_cfg:
assert ("no_sync context manager is incompatible" in str(assertinfo))

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("zero_stage", [0, 1])
def test_engine_step(self, zero_stage, dtype):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3
}
},
"zero_optimization": {
"stage": zero_stage,
},
}

if dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
elif dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

hidden_dim = 64
total_samples = 32
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
total_samples=total_samples,
hidden_dim=hidden_dim,
device=model.device,
dtype=dtype)
dist.barrier()

with model.no_sync():
for _, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
with pytest.raises(AssertionError) as assertinfo:
model.step()
assert ("It is illegal to call Engine.step() inside no_sync context manager" in str(assertinfo))

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("zero_stage", [0, 1])
def test_multiple_ctxts(self, zero_stage, dtype):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3
}
},
"zero_optimization": {
"stage": zero_stage,
},
}

if dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
elif dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

hidden_dim = 64
total_samples = 32
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
total_samples=total_samples,
hidden_dim=hidden_dim,
device=model.device,
dtype=dtype)
dist.barrier()

param_list = list(model.parameters())
first_losses = []
first_grad_norms = []
with model.no_sync():
for _, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
first_losses.append(loss.item())
model.backward(loss)
grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list])
first_grad_norms.append(grad_norm.item())

second_losses = []
second_grad_norms = []

model.zero_grad()
with model.no_sync():
for _, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
second_losses.append(loss.item())
model.backward(loss)
grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list])
second_grad_norms.append(grad_norm.item())

assert len(first_losses) == len(second_losses)
for x, y in zip(first_losses, second_losses):
assert x == y

assert len(first_grad_norms) == len(second_grad_norms)
for x, y in zip(first_grad_norms, second_grad_norms):
assert x == y

def test_reentry(self):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3
}
},
"zero_optimization": {
"stage": 1,
},
}

hidden_dim = 64
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
dist.barrier()

with model.no_sync():
with pytest.raises(AssertionError) as assertinfo:
with model.no_sync():
pass
assert ("no_sync context manager reentry is unsupported" in str(assertinfo))
