diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 05bb23e8ddd9..8fe503c9f422 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -811,6 +811,9 @@ def zero_max_reuse_distance(self):
     def zero_prefetch_bucket_size(self):
         return self._config.zero_config.prefetch_bucket_size
 
+    def zero_module_granularity_threshold(self):
+        return self._config.zero_config.module_granularity_threshold
+
     def zero_param_persistence_threshold(self):
         return self._config.zero_config.param_persistence_threshold
 
@@ -1611,6 +1614,7 @@ def _configure_zero_optimizer(self, optimizer):
                     zero_param_parallel_group=zero_param_parallel_group,
                     zero_quantized_weights=self.zero_quantized_weights(),
                     zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
+                    zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
                 )
             else:
                 log_dist(
@@ -1657,6 +1661,7 @@ def _configure_zero_optimizer(self, optimizer):
                     zero_hpz_partition_size=self.zero_hpz_partition_size(),
                     zero_quantized_weights=self.zero_quantized_weights(),
                     zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
+                    zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
                 )
 
         else:
diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py
index 1cfcd784e2ce..7cac7e3c1ce7 100644
--- a/deepspeed/runtime/zero/config.py
+++ b/deepspeed/runtime/zero/config.py
@@ -21,6 +21,7 @@
     "stage3_max_live_parameters" : 1000000000,
     "stage3_max_reuse_distance" : 1000000000,
     "stage3_use_all_reduce_for_fetch_params": [true|false],
+    "stage3_module_granularity_threshold": 0,
     "allgather_partitions": [true|false],
     "use_multi_rank_bucket_allreduce": [true|false],
     "allgather_bucket_size": 500000000,
@@ -245,6 +246,14 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     this option is enabled and then saves the fp16 model weights.
     """
 
+    module_granularity_threshold: int = Field(pp_int(0), alias="stage3_module_granularity_threshold")
+    """
+    The granularity of a module is determined by the ratio `parameter_count / (1 + descendant_count)`.
+    ZeRO3 classifies modules whose granularity falls below this threshold as fine-grained and treats
+    each of them as an integral unit when fetching parameters. This reduces the host-side overhead
+    and the separate allgather overhead that hooks introduce for fine-grained layers.
+    """
+
     use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params")
     """
     Use all_reduce op when fetching module parameters at stage3. This improves performance by reducing
    the overhead of concatenation and slicing on host.
    """
diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index 4b0ddb7679a9..082d7e874e4d 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -6,7 +6,7 @@
 import sys
 import torch
 from collections import OrderedDict
-from deepspeed.utils import z3_leaf_module
+from deepspeed.utils import z3_leaf_module, set_z3_leaf_module
 from deepspeed.runtime.utils import see_memory_usage
 from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -14,6 +14,7 @@
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params
 from deepspeed.accelerator import get_accelerator
+from deepspeed import utils
 
 FWD_MODULE_STACK = list()
 
@@ -101,6 +102,7 @@ def __init__(
         zero_param_parallel_group=None,
         zero_quantized_weights=False,
         zero_quantized_nontrainable_weights=False,
+        zero_module_granularity_threshold=0,
     ):
 
         see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
@@ -155,8 +157,16 @@ def __init__(
             zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
         )
 
+        if zero_module_granularity_threshold > 0:
+            self.min_granularity_value = sys.maxsize
+            self.min_granularity_layer = None
+            self.granularity_info = set()
+            self.z3_leaf_layers = []
+            self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
+
         self.forward_hooks = []
         self.backward_hooks = []
+
         self.setup_zero_stage3_hooks()
         print_rank_0(
             f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
@@ -482,3 +492,82 @@ def post_sub_module_backward_function(self, sub_module):
         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
             force=False)
+
+    def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):
+
+        self._get_granularity_recursively(module)
+        print_rank_0(f"{'MODULE NAME'.ljust(30)}|{'GRANULARITY VALUE'.rjust(20)}", force=True)
+        for granularity in self.granularity_info:
+            print_rank_0(granularity, force=True)
+
+        if self.min_granularity_value <= zero_module_granularity_threshold:
+            self._set_leaf_by_threshold_preorder(module, zero_module_granularity_threshold)
+            utils.logger.info(
+                f"z3_leaf_module was set by stage3_module_granularity_threshold: {zero_module_granularity_threshold}")
+            for layer in self.z3_leaf_layers:
+                print_rank_0(f"{layer.__class__.__name__}: {layer.ds_model_granularity}", force=True)
+        else:
+            utils.logger.warning(
+                f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\
+                f"To make stage3_module_granularity_threshold effective, you need to set stage3_module_granularity_threshold >= {self.min_granularity_value}. "\
+                f"Current value: {zero_module_granularity_threshold}")
+
+    def _get_granularity_recursively(self, module):
+        """Recursively compute and record the granularity of each module."""
+
+        # Avoid marking particularly large modules as leaves, even if their granularity is small:
+        # an oversized leaf module increases the number of live parameters, adding memory overhead.
+        Z3_MAX_LEAF_SIZE = 1e9
+
+        if not list(module.parameters()):
+            # skip modules without parameters, such as GELU, etc.
+            module.ds_model_granularity = sys.maxsize
+            return 0, 0
+
+        num_layers = 0
+        num_params = 0
+        num_params += sum(p.ds_numel for p in module.parameters(recurse=False))
+        if not any(module.children()):
+            # torch leaf module
+            module.ds_model_granularity = sys.maxsize
+            return 1, num_params
+
+        for child in module.children():
+            layers_in_child, params_in_child = self._get_granularity_recursively(child)
+            num_layers += layers_in_child
+            num_params += params_in_child
+
+        if module.__class__.__name__ in torch.nn.modules.container.__all__:
+            # Do not set container modules like ModuleList as leaf modules:
+            # doing so would prevent hooks from being set on their children,
+            # and containers may never invoke a forward method themselves.
+            module.ds_model_granularity = sys.maxsize
+            return num_layers, num_params
+
+        num_layers += 1
+        ds_model_granularity = (num_params // num_layers) if num_params <= Z3_MAX_LEAF_SIZE else sys.maxsize
+        module.ds_model_granularity = ds_model_granularity
+        # module.ds_model_num_layers = num_layers
+        # module.ds_model_num_params = num_params
+        if self.min_granularity_value > ds_model_granularity:
+            self.min_granularity_value = ds_model_granularity
+            self.min_granularity_layer = module.__class__.__name__
+        self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(20)}")
+
+        return num_layers, num_params
+
+    def _set_leaf_by_threshold_preorder(self, module, granularity_threshold):
+        '''Set modules as leaf modules based on the threshold, prioritizing parent nodes.'''
+
+        num_params = sum(p.ds_numel for p in module.parameters())
+        if num_params == 0:
+            # skip modules without parameters, such as GELU, etc.
+            return
+        if module.ds_model_granularity <= granularity_threshold:
+            set_z3_leaf_module(module, True)
+            self.z3_leaf_layers.append(module)
+            return
+
+        for sub_module in module.children():
+            self._set_leaf_by_threshold_preorder(sub_module, granularity_threshold)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 2c0c9d498d13..04d52319ae8c 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -157,6 +157,7 @@ def __init__(
         zero_hpz_partition_size=1,
         zero_quantized_weights=False,
         zero_quantized_nontrainable_weights=False,
+        zero_module_granularity_threshold=0,
     ):
         see_memory_usage("Stage 3 initialize beginning", force=True)
 
@@ -227,7 +228,8 @@ def __init__(
             mpu=mpu,
             zero_param_parallel_group=zero_param_parallel_group,
             zero_quantized_weights=zero_quantized_weights,
-            zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights)
+            zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
+            zero_module_granularity_threshold=zero_module_granularity_threshold)
 
         self.persistent_parameters = self.parameter_offload.persistent_parameters
         self._configure_offloading(offload_optimizer_config, offload_param_config)
@@ -458,6 +460,7 @@ def initialize_ds_offload(
         zero_param_parallel_group,
         zero_quantized_weights,
         zero_quantized_nontrainable_weights,
+        zero_module_granularity_threshold,
     ):
         return DeepSpeedZeRoOffload(module=module,
                                     timers=timers,
@@ -473,7 +476,8 @@ def initialize_ds_offload(
                                     mpu=mpu,
                                     zero_param_parallel_group=zero_param_parallel_group,
                                     zero_quantized_weights=zero_quantized_weights,
-                                    zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights)
+                                    zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
+                                    zero_module_granularity_threshold=zero_module_granularity_threshold)
 
     def _get_trainable_parameter_groups(self):
         param_groups = []
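For intuition about the recursion above, here is a minimal standalone sketch of the same granularity rule — parameter count divided by (1 + descendant layer count) — applied to a toy block. The names `tag_granularity` and `TinyMLP` are ours for illustration, not part of this PR:

```python
import sys
import torch.nn as nn


def tag_granularity(module: nn.Module):
    """Tag each module with params // (1 + descendant layer count),
    mirroring _get_granularity_recursively. Returns (layers, params)."""
    if not list(module.parameters()):
        module.granularity = sys.maxsize  # parameter-free modules (GELU, Dropout, ...)
        return 0, 0
    params = sum(p.numel() for p in module.parameters(recurse=False))
    if not any(module.children()):
        module.granularity = sys.maxsize  # torch leaf modules (Linear, ...) carry no hooks to save
        return 1, params
    layers = 0
    for child in module.children():
        child_layers, child_params = tag_granularity(child)
        layers += child_layers
        params += child_params
    if module.__class__.__name__ in nn.modules.container.__all__:
        module.granularity = sys.maxsize  # containers (ModuleList, ...) never become leaves
        return layers, params
    layers += 1
    module.granularity = params // layers
    return layers, params


class TinyMLP(nn.Module):

    def __init__(self, d=128):
        super().__init__()
        self.up = nn.Linear(d, d, bias=False)
        self.down = nn.Linear(d, d, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        return self.down(self.act(self.up(x)))


mlp = TinyMLP()
tag_granularity(mlp)
print(mlp.granularity)  # 32768 params / (1 + 2 leaf layers) -> 10922
```

A low value like 10922 signals many small parameter tensors per hook, which is exactly the case where fetching the block as one unit pays off.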
diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py
index c6a202d485eb..983e64642c69 100644
--- a/deepspeed/utils/__init__.py
+++ b/deepspeed/utils/__init__.py
@@ -16,7 +16,7 @@
 from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state, safe_set_full_grad
 from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
 from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
-from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
+from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module
 from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
 from deepspeed.runtime.dataloader import RepeatingLoader
 from .numa import get_numactl_cmd
diff --git a/deepspeed/utils/z3_leaf_module.py b/deepspeed/utils/z3_leaf_module.py
index 47d9ff698f1f..14e8ae2d2823 100644
--- a/deepspeed/utils/z3_leaf_module.py
+++ b/deepspeed/utils/z3_leaf_module.py
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 
 import torch
-from typing import List, Type
+from typing import List, Type, Union
 
 
 def z3_leaf_module(model: torch.nn.Module) -> bool:
@@ -40,18 +40,25 @@ def get_z3_leaf_modules(model: torch.nn.Module) -> List[torch.nn.Module]:
     return [module for module in model.modules() if z3_leaf_module(module)]
 
 
-def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type],
+def set_z3_leaf_module(model: torch.nn.Module, flag: bool) -> None:
+    model._z3_leaf = flag
+
+
+def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]],
                             flag: bool) -> List[torch.nn.Module]:
-    assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \
-        f'leaf_module_classes must be a list of types, got {leaf_module_classes}'
+    assert all(isinstance(module_class, (type, str)) for module_class in leaf_module_classes), \
+        f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}'
 
     leaf_modules = []
 
     def _set_z3_leaf_flag(model: torch.nn.Module):
         nonlocal leaf_modules
-        if model.__class__ in leaf_module_classes:
-            model._z3_leaf = flag
-            leaf_modules.append(model)
+        for module_class in leaf_module_classes:
+            if (isinstance(module_class, type) and model.__class__ == module_class) or \
+               (isinstance(module_class, str) and model.__class__.__name__ == module_class):
+                model._z3_leaf = flag
+                leaf_modules.append(model)
+                break  # avoid appending the same module twice if both its type and its name are listed
 
     model.apply(_set_z3_leaf_flag)
 
@@ -61,13 +68,14 @@ def _set_z3_leaf_flag(model: torch.nn.Module):
     return leaf_modules
 
 
-def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]:
+def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type],
+                                                                           List[str]]) -> List[torch.nn.Module]:
     """Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively
     when it encounters a module class listed in `leaf_module_classes`.
       This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models,
     the computation order of experts varies across forward passes. This variability can disrupt
     ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch
     parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters
     for all child modules upon entering the module.
      Another scenario where this functionality is beneficial is in models with excessively
     fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
     Args:
         model (torch.nn.Module): The model to which the leaf module flag will be applied.
-        leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
+        leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
     Returns:
         List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
     """
@@ -79,7 +87,7 @@ def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]
     See `set_z3_leaf_modules` for more details.
     Args:
         model (torch.nn.Module): The model to which the leaf module flag will be applied.
-        leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
+        leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
     Returns:
         List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
     """
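With the `Union[List[Type], List[str]]` extension above, leaf modules can now be flagged by class name as well as by class object. A small usage sketch (assumes a DeepSpeed build containing this change):

```python
import torch.nn as nn
from deepspeed.utils import set_z3_leaf_modules, get_z3_leaf_modules

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# By class object (existing behavior) ...
set_z3_leaf_modules(model, [nn.Linear])
# ... or by class-name string, handy when the class lives in a library
# you would rather not import just to make this call.
set_z3_leaf_modules(model, ["Linear"])

assert len(get_z3_leaf_modules(model)) == 2
```

Both forms set the same `_z3_leaf` flag, so mixing them is harmless; `set_z3_leaf_modules` still raises `ValueError` when nothing matches.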
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index adb2f1679ea0..51e3bbd6eaaa 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -489,6 +489,11 @@ Enabling and configuring ZeRO memory optimizations
 |-------------|---------|
 | Consolidate the weights before saving the model by `save_16bit_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` |
 
+***stage3_module_granularity_threshold***: [integer]
+| Description | Default |
+|-------------|---------|
+| The granularity of a module is determined by the ratio of `parameter_count` / `(1 + descendant_count)`. ZeRO3 classifies modules with a granularity below the threshold as fine-grained, treating them as integral units during parameter fetching. This reduces host and communication overhead from separate hooks. | `0` |
+
 ***zero_hpz_partition_size***: [integer]
 
 | Description | Default |
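Putting the documented knob into practice, a minimal ZeRO-3 config might look like the following; the threshold value here is illustrative, and in practice it should be chosen from the per-module granularity table that the runtime prints at startup:

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        # 0 disables the feature; a positive value lets ZeRO-3 mark any module
        # whose granularity falls below it as a single fetch unit.
        "stage3_module_granularity_threshold": 12100,
    },
    "bf16": {
        "enabled": True
    },
}

# engine, _, _, _ = deepspeed.initialize(model=model,
#                                        model_parameters=model.parameters(),
#                                        config=ds_config)
```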
diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py
index 1d3b88a04a4e..74c709883645 100644
--- a/tests/unit/runtime/zero/test_zero_leaf_module.py
+++ b/tests/unit/runtime/zero/test_zero_leaf_module.py
@@ -3,6 +3,7 @@
 
 # DeepSpeed Team
 
+import pytest
 import deepspeed.comm as dist
 import torch
 
@@ -12,6 +13,8 @@
 import deepspeed
 from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module
 from deepspeed.accelerator import get_accelerator
+from torch import nn
+import time
 
 
 class ChooseModuleByCounter(torch.nn.Module):
@@ -53,6 +56,49 @@ def forward(self, x, y):
         return x, loss
 
 
+class MLPBlock(nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(MLPBlock, self).__init__()
+        self.gate_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.up_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class FineGrainedBlock(nn.Module):
+
+    def __init__(self, hidden_dim, num_block):
+        super(FineGrainedBlock, self).__init__()
+        self.num_block = num_block
+        self.mlp_layers = torch.nn.ModuleList([MLPBlock(hidden_dim=hidden_dim) for _ in range(self.num_block)])
+
+    def forward(self, x):
+        for i in range(self.num_block):
+            x = self.mlp_layers[i](x)
+        return x
+
+
+class modelWithFineGrainedBlock(nn.Module):
+
+    def __init__(self, hidden_dim, num_block):
+        super(modelWithFineGrainedBlock, self).__init__()
+        self.coarse_grained_layer1 = nn.Linear(hidden_dim, 8 * hidden_dim)
+        self.coarse_grained_layer2 = nn.Linear(8 * hidden_dim, hidden_dim)
+        self.fine_grained_layer = FineGrainedBlock(hidden_dim, num_block)
+        self.cel = torch.nn.CrossEntropyLoss()
+
+    def forward(self, x, y):
+        x = self.coarse_grained_layer1(x)
+        x = self.coarse_grained_layer2(x)
+        x = self.fine_grained_layer(x)
+        loss = self.cel(x, y)
+        return x, loss
+
+
 def run_model(model, config_dict, hidden_dim, dtype, requires_grad):
     model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
     data_loader = random_dataloader(model=model,
@@ -97,9 +143,9 @@ def _test_set_z3_leaf_modules(self, cls, requires_grad):
                 "stage3_max_reuse_distance": 0,
             }
         }
-        if get_accelerator().is_fp16_supported():
+        if preferred_dtype() is torch.float16:
             config_dict["fp16"] = {"enabled": True}
-        elif get_accelerator().is_bf16_supported():
+        elif preferred_dtype() is torch.bfloat16:
             config_dict["bf16"] = {"enabled": True}
 
         model = cls(hidden_dim)
@@ -143,3 +189,74 @@ def test_set_no_match_class(self):
             raise AssertionError("Expected error that no module is set as a leaf module")
         except ValueError as e:
             pass
+
+
+@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000])
+class TestZ3LeafOptimization(DistributedTest):
+    world_size = 2
+    reuse_dist_env = True
+
+    def test_finegrained_optimization(self, module_granularity_threshold: int):
+        hidden_dim = 128
+        num_block = 16
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "stage3_prefetch_bucket_size": hidden_dim**2,
+                "stage3_param_persistence_threshold": 0,
+                "stage3_max_reuse_distance": 0,
+            }
+        }
+        if preferred_dtype() is torch.float16:
+            config_dict["fp16"] = {"enabled": True}
+        elif preferred_dtype() is torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        def bench_loss_and_time(config):
+            warm_up_step = 10
+            model = modelWithFineGrainedBlock(hidden_dim, num_block)
+            model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)
+            data_loader = random_dataloader(model=model,
+                                            total_samples=20,
+                                            hidden_dim=hidden_dim,
+                                            device=model.device,
+                                            dtype=preferred_dtype())
+            dist.barrier()
+            loss_list = []
+
+            for i, batch in enumerate(data_loader):
+                if i == warm_up_step:
+                    dist.barrier()
+                    get_accelerator().synchronize()
+                    start_time = time.time()
+                batch[0].requires_grad = True
+                loss = model(batch[0], batch[1])
+                loss = loss[1]
+                loss_list.append(loss)
+                model.backward(loss)
+                model.step()
+            get_accelerator().synchronize()
+            end_time = time.time()
+            duration = end_time - start_time
+            model.destroy()
+            return loss_list, duration
+
+        baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict)
+
+        config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = module_granularity_threshold
+        loss_list, duration = bench_loss_and_time(config_dict)
+
+        if dist.get_rank() == 0:
+            print(f"baseline exec time: {baseline_exec_time}")
+            print(f"fine-grained optimization exec time: {duration}, granularity threshold: {module_granularity_threshold}")
+        assert baseline_loss_list == loss_list, f"incorrect loss value with threshold: {module_granularity_threshold}"
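As a closing sanity check, the `12100` value in the parametrization above can be reproduced by hand from the test model's shape (our arithmetic, not part of the test):

```python
hidden_dim, num_block = 128, 16

mlp_params = 3 * hidden_dim * hidden_dim  # three bias-free Linears -> 49152
mlp_layers = 3 + 1                        # three torch leaves + MLPBlock itself (GELU counts as 0)

fine_params = num_block * mlp_params      # 786432
fine_layers = num_block * mlp_layers + 1  # ModuleList is transparent; +1 for FineGrainedBlock

print(fine_params // fine_layers)         # 786432 // 65 = 12098
```

So thresholds `0` (feature disabled) and `100` (below the smallest granularity, 12098) leave the hook behavior unchanged, while `12100` turns `FineGrainedBlock` into a single leaf unit and `10000000` is large enough to absorb the whole test model.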