Use one param coordinator for both train/inference scenarios #6662

Merged: 16 commits, Nov 6, 2024
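Summary of the change, drawn from the diff below: `DeepSpeedZeRoOffload` now creates a single `PartitionedParameterCoordinator` in `__init__` and shares it between training and inference, instead of lazily building one coordinator per `module.training` mode, so `get_param_coordinator()` drops its `training` argument. The per-step reset moves from an end-of-forward hook (and from the end of `backward()` in `stage3.py`) into a forward pre-hook, the coordinator's trace mode now starts as `INVALID`, trace recording is skipped under `torch.compile` via `is_compiling()`, and a new `TestZero3SwitchModes` test alternates train/eval loops under ZeRO-3. The sketch below is illustrative only; the stand-in classes are hypothetical, not DeepSpeed's, and merely show the shape of the API change.

```python
# Illustrative sketch only -- hypothetical stand-ins, not DeepSpeed's real classes.
# It contrasts the old per-mode lookup (keyed by module.training) with the single
# shared coordinator this PR switches to.

class SketchCoordinator:
    """Stand-in for PartitionedParameterCoordinator: just counts fetch calls."""

    def __init__(self):
        self.fetches = 0

    def fetch_sub_module(self, name):
        self.fetches += 1


class OldStyleOffload:
    """Old behavior: one coordinator per mode, so trace/prefetch state is split in two."""

    def __init__(self):
        self.param_coordinators = {}

    def get_param_coordinator(self, training):
        return self.param_coordinators.setdefault(training, SketchCoordinator())


class NewStyleOffload:
    """New behavior: a single coordinator shared by training and inference passes."""

    def __init__(self):
        self.param_coordinator = SketchCoordinator()

    def get_param_coordinator(self):
        return self.param_coordinator


if __name__ == "__main__":
    old, new = OldStyleOffload(), NewStyleOffload()
    for training in (True, False, True):  # train -> eval -> train
        old.get_param_coordinator(training).fetch_sub_module("layer0")
        new.get_param_coordinator().fetch_sub_module("layer0")
    print({mode: c.fetches for mode, c in old.param_coordinators.items()})  # {True: 2, False: 1}
    print(new.get_param_coordinator().fetches)  # 3: one coordinator sees every pass
```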
Changes from all commits
57 changes: 26 additions & 31 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -133,20 +133,28 @@ def __init__(
         self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold,
                                                                       self.model_persistence_threshold)
 
-        self.param_coordinators = {}
         self._prefetch_bucket_sz = int(prefetch_bucket_size)
         self._max_reuse_distance_in_numel = int(max_reuse_distance)
         self._max_available_parameters_in_numel = int(max_live_parameters)
         self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream(
         ) if overlap_comm else get_accelerator().default_stream()
 
         if not hasattr(module, "ds_inflight_param_registry"):
-            module.ds_inflight_param_registry = dict()
-            # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator
-            module.ds_inflight_param_registry[True] = InflightParamRegistry()
-            module.ds_inflight_param_registry[False] = InflightParamRegistry()
+            module.ds_inflight_param_registry = InflightParamRegistry()
         self.__inflight_param_registry = module.ds_inflight_param_registry
 
+        self.param_coordinator = PartitionedParameterCoordinator(
+            prefetch_bucket_sz=self._prefetch_bucket_sz,
+            max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
+            max_available_parameters_in_numel=self._max_available_parameters_in_numel,
+            allgather_stream=self.__allgather_stream,
+            inflight_param_registry=self.__inflight_param_registry,
+            prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
+            timers=self.timers,
+            zero_quantized_weights=self.zero_quantized_weights,
+            zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
+        )
+
         self.forward_hooks = []
         self.backward_hooks = []
         self.setup_zero_stage3_hooks()
@@ -161,26 +169,13 @@ def partition_all_parameters(self):
         """Partitioning Parameters that were not partitioned usually if parameters
        of modules whose input parameters do not require grad computation do not
        trigger post call and will therefore will remain unpartitioned"""
-        self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
+        self.get_param_coordinator().release_and_reset_all(self.module)
         for param in iter_params(self.module, recurse=True):
             if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                 raise RuntimeError(f"{param.ds_summary()} expected to be released")
 
-    def get_param_coordinator(self, training):
-        if not training in self.param_coordinators:
-            self.param_coordinators[training] = PartitionedParameterCoordinator(
-                prefetch_bucket_sz=self._prefetch_bucket_sz,
-                max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
-                max_available_parameters_in_numel=self._max_available_parameters_in_numel,
-                allgather_stream=self.__allgather_stream,
-                inflight_param_registry=self.__inflight_param_registry[training],
-                prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
-                timers=self.timers,
-                zero_quantized_weights=self.zero_quantized_weights,
-                zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
-            )
-
-        return self.param_coordinators[training]
+    def get_param_coordinator(self):
+        return self.param_coordinator
 
     def empty_partition_cache(self):
         self.partition_all_parameters()
@@ -228,14 +223,14 @@ def setup_zero_stage3_hooks(self):
 
         #reset step if in inference mode
         @instrument_w_nvtx
-        def _end_of_forward_hook(module, *args):
+        def _start_of_forward_hook(module, *args):
 
-            if not torch._C.is_grad_enabled():
-                self.get_param_coordinator(training=False).reset_step()
+            self.get_param_coordinator().reset_step()
+
+        self.module.register_forward_pre_hook(_start_of_forward_hook)
 
         #likely one of them should be enough but just to be safe
         self._register_hooks_recursively(self.module)
-        self.module.register_forward_hook(_end_of_forward_hook)
 
         # Add top module to stack trace
         global FWD_MODULE_STACK
@@ -447,7 +442,7 @@ def pre_sub_module_forward_function(self, sub_module):
         global FWD_MODULE_STACK
         FWD_MODULE_STACK.append(sub_module)
 
-        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
@@ -460,29 +455,29 @@ def post_sub_module_forward_function(self, sub_module):
         see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                          force=False)
 
-        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.release_sub_module(sub_module)
 
         see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
                          force=False)
 
     @torch.no_grad()
     def pre_sub_module_backward_function(self, sub_module):
-        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
-        param_coordinator = self.get_param_coordinator(training=True)
+        # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
         param_coordinator.fetch_sub_module(sub_module, forward=False)
 
     @torch.no_grad()
     def post_sub_module_backward_function(self, sub_module):
-        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
+        # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
             force=False)
 
-        self.get_param_coordinator(training=True).release_sub_module(sub_module)
+        self.get_param_coordinator().release_sub_module(sub_module)
 
         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
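For context on the hook change above: the per-step reset now runs in a forward pre-hook registered on the top-level module, so it fires at the start of every forward pass in both train() and eval() mode, rather than only at the end of a no-grad forward as the old _end_of_forward_hook did. A minimal, self-contained sketch of that pattern in plain PyTorch (not DeepSpeed code; the counter stands in for get_param_coordinator().reset_step()):

```python
import torch
import torch.nn as nn

state = {"step_resets": 0}

def _start_of_forward_hook(module, args):
    # Stands in for self.get_param_coordinator().reset_step() in the real hook.
    state["step_resets"] += 1

model = nn.Linear(4, 4)
model.register_forward_pre_hook(_start_of_forward_hook)

model.train()
model(torch.randn(2, 4))        # reset happens before the training forward
model.eval()
with torch.no_grad():
    model(torch.randn(2, 4))    # and before the inference forward as well
print(state["step_resets"])     # 2
```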
12 changes: 11 additions & 1 deletion deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -18,6 +18,7 @@
 from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
 from deepspeed.accelerator import get_accelerator
 import deepspeed.runtime.compiler as compiler
+from deepspeed.runtime.compiler import is_compiling
 
 import logging
 
@@ -92,7 +93,7 @@ def __init__(
         # keeps track of the number of submodules invoked so far.
         self.__step_id: int = 0
         # network tracing mode
-        self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD
+        self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.INVALID
         # sequence of submodules/parameters in forward pass + backward pass
         self.__submodule_order: Iterable[Module] = []
         self.__param_order: Iterable[__class__.__ParamInTrace] = []
@@ -188,13 +189,18 @@ def trace_prologue(self, sub_module: Module) -> None:
     @compiler.disable
     def record_module(self, sub_module: Module) -> None:
         """adds sub module to trace"""
+        if is_compiling():
+            return
+
         if not self.is_record_trace():
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
 
         self.__submodule_order.append(sub_module)
         self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)
 
     def record_parameters(self, sub_module: Module) -> None:
+        if is_compiling():
+            return
         """adds sub module to trace"""
         if not self.is_record_trace():
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
@@ -209,8 +215,12 @@ def construct_parameter_trace_from_module_trace(self):
         for sub_module in self.__submodule_order:
             self.record_parameters(sub_module)
 
+    @compiler.disable
     def reset_step(self) -> None:
         """indicate that we have completed one fwd+bwd for the model"""
+        if is_compiling():
+            return
+
         self._clean_inflight_param_registry()
 
         if not self.is_complete_trace():  # not self.trace_complete:
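The guards above keep the coordinator's trace bookkeeping (appending to the submodule order, clearing per-step state) out of torch.compile'd regions. A rough sketch of the pattern; is_compiling() here is a hypothetical stand-in for deepspeed.runtime.compiler.is_compiling (assumption: the real helper reports whether the torch compiler is currently tracing):

```python
def is_compiling() -> bool:
    # Stand-in for deepspeed.runtime.compiler.is_compiling; always False here.
    return False


class TraceRecorder:
    """Toy recorder showing the early-return guard added in this diff."""

    def __init__(self):
        self.submodule_order = []

    def record_module(self, name: str) -> None:
        if is_compiling():
            return  # keep Python-side bookkeeping out of compiled graphs
        self.submodule_order.append(name)

    def reset_step(self) -> None:
        if is_compiling():
            return
        self.submodule_order.clear()


recorder = TraceRecorder()
recorder.record_module("layer0")
recorder.reset_step()
print(recorder.submodule_order)  # []
```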
8 changes: 3 additions & 5 deletions deepspeed/runtime/zero/stage3.py
@@ -593,8 +593,8 @@ def defragment(tensors: List[Tensor]) -> Tensor:
 
         return device_buffer
 
-    def _get_param_coordinator(self, training):
-        return self.parameter_offload.get_param_coordinator(training)
+    def _get_param_coordinator(self):
+        return self.parameter_offload.get_param_coordinator()
 
     def _configure_offloading(self, offload_optimizer_config, offload_param_config):
         ###################### offload optimizer setup ##################################
@@ -1874,7 +1874,7 @@ def _pre_step(self):
         see_memory_usage(f"In step before checking overflow", force=False)
 
         print_rank_0("Finished Tracing at Beginning of Step")
-        self._get_param_coordinator(training=True).hierarchy = 0
+        self._get_param_coordinator().hierarchy = 0
 
         print_rank_0("Finished Tracing at Beginning of Step")
 
@@ -2258,8 +2258,6 @@ def backward(self, loss, retain_graph=False):
         else:
             self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
 
-        self._get_param_coordinator(training=True).reset_step()
-
         if self.swap_optimizer:
             self.optimizer_swapper.post_backward()
 
45 changes: 45 additions & 0 deletions tests/unit/runtime/zero/test_zero.py
@@ -1628,3 +1628,48 @@ def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_grou
             optimizer=optimizer,
             config=config_dict,
         )
+
+
+class TestZero3SwitchModes(DistributedTest):
+    world_size = 2
+
+    @pytest.mark.parametrize("prefetch_ratio", [0.0, 0.5, 1.0])
+    def test(self, prefetch_ratio, zero_stage=3):
+
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+
+        prefetch_bucket_size = int(sum([p.numel() for p in model.parameters(recurse=True)]) * prefetch_ratio)
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 2,
+            "gradient_accumulation_steps": 2,
+            "zero_optimization": {
+                "stage": zero_stage,
+                "stage3_prefetch_bucket_size": prefetch_bucket_size
+            },
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-3
+                }
+            },
+            "fp16": {
+                "enabled": True,
+                "initial_scale_power": 8
+            }
+        }
+
+        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)
+
+        for _ in range(3):
+            model.train()
+            for batch in data_loader:
+                loss = model(batch[0], batch[1])
+                model.backward(loss)
+                model.step()
+
+            model.eval()
+            with torch.no_grad():
+                for batch in data_loader:
+                    loss = model(batch[0], batch[1])