From a2610d8a8829cbd9d22f20ace5e8a8d4be569b03 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 21 Oct 2024 15:31:56 +0800 Subject: [PATCH 01/21] z3 coalesced fetch --- deepspeed/runtime/engine.py | 7 ++++++- deepspeed/runtime/zero/config.py | 9 +++++++++ deepspeed/runtime/zero/parameter_offload.py | 6 +++++- deepspeed/runtime/zero/stage3.py | 8 ++++++-- deepspeed/utils/z3_leaf_module.py | 22 +++++++++++---------- 5 files changed, 38 insertions(+), 14 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 05bb23e8ddd9..f33da1fc9370 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -811,6 +811,9 @@ def zero_max_reuse_distance(self): def zero_prefetch_bucket_size(self): return self._config.zero_config.prefetch_bucket_size + def zero_force_coalesced_fetch_layers(self): + return self._config.zero_config.force_coalesced_fetch_layers + def zero_param_persistence_threshold(self): return self._config.zero_config.param_persistence_threshold @@ -1611,6 +1614,7 @@ def _configure_zero_optimizer(self, optimizer): zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), + force_coalesced_fetch_layers=self.zero_force_coalesced_fetch_layers(), ) else: log_dist( @@ -1657,7 +1661,8 @@ def _configure_zero_optimizer(self, optimizer): zero_hpz_partition_size=self.zero_hpz_partition_size(), zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), - ) + zero_force_coalesced_fetch_layers=self.zero_force_coalesced_fetch_layers() + ) else: raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage)) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 1cfcd784e2ce..88c63d006883 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -21,6 +21,7 @@ "stage3_max_live_parameters" : 1000000000, "stage3_max_reuse_distance" : 1000000000, "stage3_use_all_reduce_for_fetch_params": [true|false], + "stage3_force_coalesced_fetch_layers": list[str], "allgather_partitions": [true|false], "use_multi_rank_bucket_allreduce": [true|false], "allgather_bucket_size": 500000000, @@ -245,6 +246,14 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ + + force_coalesced_fetch_layers: list[str] = Field(None,alias="stage3_force_coalesced_fetch_layers") + """ + Treat the layer as an integral unit (to avoid recursion) when fetching at stage3. + This will reduce the host overhead and separated allgather overhead in fetching + parameters introduced by hooks for fine-grained layers. + """ + use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params") """ Use all_reduce op when fetching module parameters at stage3. This improves performance by reducing diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 1ce2414a1e17..d1f1163363f6 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -6,7 +6,7 @@ import sys import torch from collections import OrderedDict -from deepspeed.utils import z3_leaf_module +from deepspeed.utils import z3_leaf_module, set_z3_leaf_modules from deepspeed.runtime.utils import see_memory_usage from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -96,6 +96,7 @@ def __init__( zero_param_parallel_group=None, zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, + zero_force_coalesced_fetch_layers=None, ): see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True) @@ -144,6 +145,9 @@ def __init__( self.forward_hooks = [] self.backward_hooks = [] + if zero_force_coalesced_fetch_layers is not None and len(zero_force_coalesced_fetch_layers)>0: + set_z3_leaf_modules(module, zero_force_coalesced_fetch_layers) + self.setup_zero_stage3_hooks() print_rank_0( f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}', diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index e2c273fd913f..6b068ee4fdf6 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -157,6 +157,7 @@ def __init__( zero_hpz_partition_size=1, zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, + zero_force_coalesced_fetch_layers=False, ): see_memory_usage("Stage 3 initialize beginning", force=True) @@ -227,7 +228,8 @@ def __init__( mpu=mpu, zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=zero_quantized_weights, - zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights) + zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, + zero_force_coalesced_fetch_layers=zero_force_coalesced_fetch_layers) self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config) @@ -458,6 +460,7 @@ def initialize_ds_offload( zero_param_parallel_group, zero_quantized_weights, zero_quantized_nontrainable_weights, + zero_force_coalesced_fetch_layers, ): return DeepSpeedZeRoOffload(module=module, timers=timers, @@ -473,7 +476,8 @@ def initialize_ds_offload( mpu=mpu, zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=zero_quantized_weights, - zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights) + zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, + zero_force_coalesced_fetch_layers=zero_force_coalesced_fetch_layers,) def _get_trainable_parameter_groups(self): param_groups = [] diff --git a/deepspeed/utils/z3_leaf_module.py b/deepspeed/utils/z3_leaf_module.py index 47d9ff698f1f..91fdd8474853 100644 --- a/deepspeed/utils/z3_leaf_module.py +++ b/deepspeed/utils/z3_leaf_module.py @@ -4,7 +4,7 @@ # DeepSpeed Team import torch -from typing import List, Type +from typing import List, Type, Union def z3_leaf_module(model: torch.nn.Module) -> bool: @@ -40,18 +40,20 @@ def get_z3_leaf_modules(model: torch.nn.Module) -> List[torch.nn.Module]: return [module for module in model.modules() if z3_leaf_module(module)] -def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], +def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]], flag: bool) -> List[torch.nn.Module]: - assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \ - f'leaf_module_classes must be a list of types, got {leaf_module_classes}' + assert all(isinstance(module_class, (type, str) ) for module_class in leaf_module_classes), \ + f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}' leaf_modules = [] def _set_z3_leaf_flag(model: torch.nn.Module): nonlocal leaf_modules - if model.__class__ in leaf_module_classes: - model._z3_leaf = flag - leaf_modules.append(model) + for module in leaf_module_classes: + if (isinstance(module, type) and model.__class__ == module) or \ + (isinstance(module, str) and model.__class__.__name__ == module): + model._z3_leaf = flag + leaf_modules.append(model) model.apply(_set_z3_leaf_flag) @@ -61,13 +63,13 @@ def _set_z3_leaf_flag(model: torch.nn.Module): return leaf_modules -def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]: +def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]]) -> List[torch.nn.Module]: """Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`. This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module. Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks. Args: model (torch.nn.Module): The model to which the leaf module flag will be applied. - leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules. + leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules. Returns: List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`. """ @@ -79,7 +81,7 @@ def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type See `set_z3_leaf_modules` for more details. Args: model (torch.nn.Module): The model to which the leaf module flag will be applied. - leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules. + leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules. Returns: List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`. """ From 4e8be08565f91651c0fe258cced538337fa37c3f Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 21 Oct 2024 09:02:49 +0000 Subject: [PATCH 02/21] fix format --- deepspeed/runtime/engine.py | 6 ++-- deepspeed/runtime/zero/config.py | 11 +++---- deepspeed/runtime/zero/parameter_offload.py | 6 ++-- deepspeed/runtime/zero/stage3.py | 34 +++++++++++---------- deepspeed/utils/z3_leaf_module.py | 3 +- 5 files changed, 31 insertions(+), 29 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index f33da1fc9370..3cdfb005efb7 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -813,7 +813,7 @@ def zero_prefetch_bucket_size(self): def zero_force_coalesced_fetch_layers(self): return self._config.zero_config.force_coalesced_fetch_layers - + def zero_param_persistence_threshold(self): return self._config.zero_config.param_persistence_threshold @@ -1661,8 +1661,8 @@ def _configure_zero_optimizer(self, optimizer): zero_hpz_partition_size=self.zero_hpz_partition_size(), zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), - zero_force_coalesced_fetch_layers=self.zero_force_coalesced_fetch_layers() - ) + zero_force_coalesced_fetch_layers=self.zero_force_coalesced_fetch_layers(), + ) else: raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage)) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 88c63d006883..7d6bd80f8358 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -21,7 +21,7 @@ "stage3_max_live_parameters" : 1000000000, "stage3_max_reuse_distance" : 1000000000, "stage3_use_all_reduce_for_fetch_params": [true|false], - "stage3_force_coalesced_fetch_layers": list[str], + "stage3_force_coalesced_fetch_layers": list[str], "allgather_partitions": [true|false], "use_multi_rank_bucket_allreduce": [true|false], "allgather_bucket_size": 500000000, @@ -246,14 +246,13 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ - - force_coalesced_fetch_layers: list[str] = Field(None,alias="stage3_force_coalesced_fetch_layers") - """ + force_coalesced_fetch_layers: list[str] = Field(None, alias="stage3_force_coalesced_fetch_layers") + """ Treat the layer as an integral unit (to avoid recursion) when fetching at stage3. - This will reduce the host overhead and separated allgather overhead in fetching + This will reduce the host overhead and separated allgather overhead in fetching parameters introduced by hooks for fine-grained layers. """ - + use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params") """ Use all_reduce op when fetching module parameters at stage3. This improves performance by reducing diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index d1f1163363f6..2e4e21b0fb92 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -145,9 +145,9 @@ def __init__( self.forward_hooks = [] self.backward_hooks = [] - if zero_force_coalesced_fetch_layers is not None and len(zero_force_coalesced_fetch_layers)>0: - set_z3_leaf_modules(module, zero_force_coalesced_fetch_layers) - + if zero_force_coalesced_fetch_layers is not None and len(zero_force_coalesced_fetch_layers) > 0: + set_z3_leaf_modules(module, zero_force_coalesced_fetch_layers) + self.setup_zero_stage3_hooks() print_rank_0( f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}', diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 6b068ee4fdf6..75b220bf5d3f 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -462,22 +462,24 @@ def initialize_ds_offload( zero_quantized_nontrainable_weights, zero_force_coalesced_fetch_layers, ): - return DeepSpeedZeRoOffload(module=module, - timers=timers, - ds_config=ds_config, - overlap_comm=overlap_comm, - prefetch_bucket_size=prefetch_bucket_size, - max_reuse_distance=max_reuse_distance, - max_live_parameters=max_live_parameters, - param_persistence_threshold=param_persistence_threshold, - model_persistence_threshold=model_persistence_threshold, - dp_process_group=dp_process_group, - offload_param_config=offload_param_config, - mpu=mpu, - zero_param_parallel_group=zero_param_parallel_group, - zero_quantized_weights=zero_quantized_weights, - zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, - zero_force_coalesced_fetch_layers=zero_force_coalesced_fetch_layers,) + return DeepSpeedZeRoOffload( + module=module, + timers=timers, + ds_config=ds_config, + overlap_comm=overlap_comm, + prefetch_bucket_size=prefetch_bucket_size, + max_reuse_distance=max_reuse_distance, + max_live_parameters=max_live_parameters, + param_persistence_threshold=param_persistence_threshold, + model_persistence_threshold=model_persistence_threshold, + dp_process_group=dp_process_group, + offload_param_config=offload_param_config, + mpu=mpu, + zero_param_parallel_group=zero_param_parallel_group, + zero_quantized_weights=zero_quantized_weights, + zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, + zero_force_coalesced_fetch_layers=zero_force_coalesced_fetch_layers, + ) def _get_trainable_parameter_groups(self): param_groups = [] diff --git a/deepspeed/utils/z3_leaf_module.py b/deepspeed/utils/z3_leaf_module.py index 91fdd8474853..6ecdfa07429e 100644 --- a/deepspeed/utils/z3_leaf_module.py +++ b/deepspeed/utils/z3_leaf_module.py @@ -63,7 +63,8 @@ def _set_z3_leaf_flag(model: torch.nn.Module): return leaf_modules -def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]]) -> List[torch.nn.Module]: +def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], + List[str]]) -> List[torch.nn.Module]: """Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`. This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module. Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks. From 7641994562f33b8aded57f71b4153195ccce6d59 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 21 Oct 2024 10:40:24 +0000 Subject: [PATCH 03/21] fix default value --- deepspeed/runtime/zero/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 7d6bd80f8358..9861a596fd6f 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -246,7 +246,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ - force_coalesced_fetch_layers: list[str] = Field(None, alias="stage3_force_coalesced_fetch_layers") + force_coalesced_fetch_layers: list[str] = Field([], alias="stage3_force_coalesced_fetch_layers") """ Treat the layer as an integral unit (to avoid recursion) when fetching at stage3. This will reduce the host overhead and separated allgather overhead in fetching From 805a82070730b5f202cc5ca773f722bd6fc6498f Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 21 Oct 2024 10:55:16 +0000 Subject: [PATCH 04/21] fix default --- deepspeed/runtime/zero/stage3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 75b220bf5d3f..7b1d07e9ba60 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -157,7 +157,7 @@ def __init__( zero_hpz_partition_size=1, zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, - zero_force_coalesced_fetch_layers=False, + zero_force_coalesced_fetch_layers=None, ): see_memory_usage("Stage 3 initialize beginning", force=True) From 810353bb0aa385654ee666d2666cb0b4ace6a5c4 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 23 Oct 2024 03:00:43 +0000 Subject: [PATCH 05/21] fix ut --- deepspeed/runtime/engine.py | 2 +- deepspeed/runtime/zero/config.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 3cdfb005efb7..ac730461b20d 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1614,7 +1614,7 @@ def _configure_zero_optimizer(self, optimizer): zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), - force_coalesced_fetch_layers=self.zero_force_coalesced_fetch_layers(), + zero_force_coalesced_fetch_layers=self.zero_force_coalesced_fetch_layers(), ) else: log_dist( diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 9861a596fd6f..0b03b501790d 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -4,7 +4,7 @@ # DeepSpeed Team import sys -from typing import Optional +from typing import Optional, List from enum import Enum from pydantic import Field, model_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel @@ -246,7 +246,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ - force_coalesced_fetch_layers: list[str] = Field([], alias="stage3_force_coalesced_fetch_layers") + force_coalesced_fetch_layers: List[str] = Field([], alias="stage3_force_coalesced_fetch_layers") """ Treat the layer as an integral unit (to avoid recursion) when fetching at stage3. This will reduce the host overhead and separated allgather overhead in fetching From 7b94377ba6e0026aa50a3bbf85c3d5a1fecfab40 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 4 Nov 2024 09:50:48 +0000 Subject: [PATCH 06/21] add ut(usage) --- .../runtime/zero/test_zero_leaf_module.py | 102 +++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index 1d3b88a04a4e..84e10aba7673 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -12,6 +12,8 @@ import deepspeed from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module from deepspeed.accelerator import get_accelerator +from torch import nn +import time class ChooseModuleByCounter(torch.nn.Module): @@ -44,7 +46,7 @@ def __init__(self, hidden_dim): torch.nn.Linear(hidden_dim, hidden_dim, bias=False)]) self.act = torch.nn.ReLU() self.cel = torch.nn.CrossEntropyLoss() - + def forward(self, x, y): # Each rank runs only one of the linear layers x = self.linears[dist.get_rank() % len(self.linears)](x) @@ -52,6 +54,40 @@ def forward(self, x, y): loss = self.cel(x, y) return x, loss +class MLPBlock(nn.Module): + def __init__(self, hidden_dim): + super(MLPBlock, self).__init__() + self.gate_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.up_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.act_fn=nn.GELU() + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +class FineGrainedBlock(nn.Module): + def __init__(self, hidden_dim, num_block): + super(FineGrainedBlock, self).__init__() + self.num_block=num_block + self.mlp_layers = torch.nn.ModuleList( + [MLPBlock(hidden_dim=hidden_dim) for _ in range(self.num_block)] + ) + def forward(self, x): + for i in range(self.num_block): + x=self.mlp_layers[i](x) + return x +class modelWithFineGrainedBlock(nn.Module): + def __init__(self, hidden_dim, num_block): + super(modelWithFineGrainedBlock, self).__init__() + self.coarse_grained_layer1=nn.Linear(hidden_dim,8*hidden_dim) + self.coarse_grained_layer2=nn.Linear(8*hidden_dim,hidden_dim) + self.finegrad_layer = FineGrainedBlock(hidden_dim,num_block) + self.cel = torch.nn.CrossEntropyLoss() + def forward(self, x, y): + x=self.coarse_grained_layer1(x) + x=self.coarse_grained_layer2(x) + x=self.finegrad_layer(x) + loss = self.cel(x, y) + return x, loss + def run_model(model, config_dict, hidden_dim, dtype, requires_grad): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) @@ -143,3 +179,67 @@ def test_set_no_match_class(self): raise AssertionError("Expected error that no module is set as a leaf module") except ValueError as e: pass + + +class TestZ3LeafOptimization(DistributedTest): + # Need multiple gpus to test possible hanging + world_size = 2 + reuse_dist_env = True + def test_FineGrained_optimization(self): + hidden_dim=32 + num_block=16 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": 3, + "stage3_prefetch_bucket_size": hidden_dim**2, + "stage3_param_persistence_threshold": 0, + "stage3_max_reuse_distance": 0, + "stage3_force_coalesced_fetch_layers":["FineGrainedBlock"] + } + } + if get_accelerator().is_fp16_supported(): + config_dict["fp16"] = {"enabled": True} + elif get_accelerator().is_bf16_supported(): + config_dict["bf16"] = {"enabled": True} + + def bench_loss_and_time(config): + warm_up_step=10 + model = modelWithFineGrainedBlock(hidden_dim,num_block) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) + dist.barrier() + loss_list=[] + + for i, batch in enumerate(data_loader): + if i ==warm_up_step: + get_accelerator().synchronize() + st=time.time() + batch[0].requires_grad = True + loss = model(batch[0], batch[1]) + loss = loss[1] + loss_list.append(loss) + model.backward(loss) + model.step() + get_accelerator().synchronize() + en=time.time() + duration = en-st + model.destroy() + return loss_list, duration + + opt_loss_list,opt_duration=bench_loss_and_time(config_dict) + del config_dict["zero_optimization"]["stage3_force_coalesced_fetch_layers"] + basic_loss_list, basic_duration=bench_loss_and_time(config_dict) + print(f"coalesced fetch time: {opt_duration}, basic duration time:{basic_duration}") + assert basic_loss_list==opt_loss_list \ No newline at end of file From cd31a0d147980a68d54fb4fe013aeca9e8fb046f Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 4 Nov 2024 13:06:00 +0000 Subject: [PATCH 07/21] use int type config --- deepspeed/runtime/engine.py | 8 +- deepspeed/runtime/zero/config.py | 6 +- deepspeed/runtime/zero/parameter_offload.py | 77 ++++++++++++++++++- deepspeed/runtime/zero/stage3.py | 8 +- deepspeed/utils/__init__.py | 2 +- deepspeed/utils/z3_leaf_module.py | 2 + .../runtime/zero/test_zero_leaf_module.py | 16 ++-- 7 files changed, 98 insertions(+), 21 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ac730461b20d..bd6bc3493ada 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -811,8 +811,8 @@ def zero_max_reuse_distance(self): def zero_prefetch_bucket_size(self): return self._config.zero_config.prefetch_bucket_size - def zero_force_coalesced_fetch_layers(self): - return self._config.zero_config.force_coalesced_fetch_layers + def zero_coalesced_fetch_threshold(self): + return self._config.zero_config.coalesced_fetch_threshold def zero_param_persistence_threshold(self): return self._config.zero_config.param_persistence_threshold @@ -1614,7 +1614,7 @@ def _configure_zero_optimizer(self, optimizer): zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), - zero_force_coalesced_fetch_layers=self.zero_force_coalesced_fetch_layers(), + zero_coalesced_fetch_threshold=self.zero_coalesced_fetch_threshold(), ) else: log_dist( @@ -1661,7 +1661,7 @@ def _configure_zero_optimizer(self, optimizer): zero_hpz_partition_size=self.zero_hpz_partition_size(), zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), - zero_force_coalesced_fetch_layers=self.zero_force_coalesced_fetch_layers(), + zero_coalesced_fetch_threshold=self.zero_coalesced_fetch_threshold(), ) else: diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 0b03b501790d..adc76f3ea215 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -4,7 +4,7 @@ # DeepSpeed Team import sys -from typing import Optional, List +from typing import Optional from enum import Enum from pydantic import Field, model_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel @@ -21,7 +21,7 @@ "stage3_max_live_parameters" : 1000000000, "stage3_max_reuse_distance" : 1000000000, "stage3_use_all_reduce_for_fetch_params": [true|false], - "stage3_force_coalesced_fetch_layers": list[str], + "stage3_coalesced_fetch_threshold": 0, "allgather_partitions": [true|false], "use_multi_rank_bucket_allreduce": [true|false], "allgather_bucket_size": 500000000, @@ -246,7 +246,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ - force_coalesced_fetch_layers: List[str] = Field([], alias="stage3_force_coalesced_fetch_layers") + coalesced_fetch_threshold: int = Field(pp_int(0), alias="stage3_coalesced_fetch_threshold") """ Treat the layer as an integral unit (to avoid recursion) when fetching at stage3. This will reduce the host overhead and separated allgather overhead in fetching diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 2e4e21b0fb92..d49049ee0956 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -6,7 +6,7 @@ import sys import torch from collections import OrderedDict -from deepspeed.utils import z3_leaf_module, set_z3_leaf_modules +from deepspeed.utils import z3_leaf_module, set_z3_leaf_modules, set_z3_leaf_module from deepspeed.runtime.utils import see_memory_usage from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -14,6 +14,7 @@ from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params from deepspeed.accelerator import get_accelerator +from deepspeed import utils FWD_MODULE_STACK = list() @@ -96,7 +97,7 @@ def __init__( zero_param_parallel_group=None, zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, - zero_force_coalesced_fetch_layers=None, + zero_coalesced_fetch_threshold=0, ): see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True) @@ -117,6 +118,7 @@ def __init__( self.offload_device = offload_param_config.device self.offload_param_pin_memory = offload_param_config.pin_memory + self._convert_to_zero_parameters(ds_config, module, mpu) for m in module.modules(): @@ -143,10 +145,23 @@ def __init__( module.ds_inflight_param_registry[False] = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry + if zero_coalesced_fetch_threshold >=0: + self.min_granularity_value=sys.maxsize + self.min_granularity_layer=None + self.z3_leaf_layers=[] + self.count_layers_and_parameters(module) + + if self.min_granularity_value<=zero_coalesced_fetch_threshold: + self.set_z3_leaf_by_threshold(module, zero_coalesced_fetch_threshold) + print_rank_0(f"z3_leaf_module was setted by stage3_coalesced_fetch_threshold", force=True) + for layer in self.z3_leaf_layers: + print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) + else: + utils.logger.warning(f"You have used min_granularity_value, but the smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}],\ + To make this variable effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}") self.forward_hooks = [] self.backward_hooks = [] - if zero_force_coalesced_fetch_layers is not None and len(zero_force_coalesced_fetch_layers) > 0: - set_z3_leaf_modules(module, zero_force_coalesced_fetch_layers) + self.setup_zero_stage3_hooks() print_rank_0( @@ -155,6 +170,8 @@ def __init__( see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True) + + @instrument_w_nvtx def partition_all_parameters(self): """Partitioning Parameters that were not partitioned usually if parameters @@ -486,3 +503,55 @@ def post_sub_module_backward_function(self, sub_module): see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", force=False) + + def count_layers_and_parameters(self, module): + + if not list(module.parameters()): + module.ds_model_granularity=sys.maxsize + return 0,0 + + num_layers = 0 + num_params = 0 + num_params += sum(p.ds_numel for p in module.parameters(recurse=False)) + + for child in module.children(): + layers_in_child, params_in_child = self.count_layers_and_parameters(child) + num_layers += layers_in_child + num_params += params_in_child + + num_layers += 1 + + # 将结果保存到模块的自定义属性中 TODO + module.ds_sub_layers = num_layers + module.ds_sub_params = num_params + + ds_model_granularity=num_params//num_layers + module.ds_model_granularity=ds_model_granularity + if self.min_granularity_value>ds_model_granularity: + self.min_granularity_value=ds_model_granularity + self.min_granularity_layer=module.__class__.__name__ + + return num_layers, num_params + + def set_z3_leaf_by_threshold(self, module,granularity_treshhold ): + num_params = sum(p.ds_numel for p in module.parameters()) + if num_params==0: + return + # Avoid setting as leaf for particularly large models, even if the granularity is very small + Z3_MAX_LEAF_SIZE=5e8 + + if module.ds_model_granularity List[torch.nn.Module]: """ return [module for module in model.modules() if z3_leaf_module(module)] +def set_z3_leaf_module(model: torch.nn.Module, flag: bool): + model._z3_leaf = flag def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]], flag: bool) -> List[torch.nn.Module]: diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index 84e10aba7673..7dc9704786cf 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -14,7 +14,7 @@ from deepspeed.accelerator import get_accelerator from torch import nn import time - +from math import ceil class ChooseModuleByCounter(torch.nn.Module): @@ -186,8 +186,9 @@ class TestZ3LeafOptimization(DistributedTest): world_size = 2 reuse_dist_env = True def test_FineGrained_optimization(self): - hidden_dim=32 + hidden_dim=128 num_block=16 + stage3_coalesced_fetch_threshold=12000 config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -202,7 +203,7 @@ def test_FineGrained_optimization(self): "stage3_prefetch_bucket_size": hidden_dim**2, "stage3_param_persistence_threshold": 0, "stage3_max_reuse_distance": 0, - "stage3_force_coalesced_fetch_layers":["FineGrainedBlock"] + "stage3_coalesced_fetch_threshold":stage3_coalesced_fetch_threshold } } if get_accelerator().is_fp16_supported(): @@ -221,7 +222,12 @@ def bench_loss_and_time(config): dtype=preferred_dtype()) dist.barrier() loss_list=[] - + + # for name, submodule in model.named_modules(): + # if hasattr(submodule,'ds_sub_layers'): + # tresh=submodule.ds_sub_params/submodule.ds_sub_layers + # if dist.get_rank()==0: + # print(f"module '{name}' layers: {submodule.ds_sub_layers}, params: {submodule.ds_sub_params}, tresh:{tresh}") for i, batch in enumerate(data_loader): if i ==warm_up_step: get_accelerator().synchronize() @@ -239,7 +245,7 @@ def bench_loss_and_time(config): return loss_list, duration opt_loss_list,opt_duration=bench_loss_and_time(config_dict) - del config_dict["zero_optimization"]["stage3_force_coalesced_fetch_layers"] + del config_dict["zero_optimization"]["stage3_coalesced_fetch_threshold"] basic_loss_list, basic_duration=bench_loss_and_time(config_dict) print(f"coalesced fetch time: {opt_duration}, basic duration time:{basic_duration}") assert basic_loss_list==opt_loss_list \ No newline at end of file From ea5096457f3c5cd4674df9a65cd2bf17c29420b4 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 4 Nov 2024 21:07:51 +0800 Subject: [PATCH 08/21] fix format --- deepspeed/runtime/zero/parameter_offload.py | 72 ++++++++--------- deepspeed/utils/z3_leaf_module.py | 2 + .../runtime/zero/test_zero_leaf_module.py | 79 +++++++++++-------- 3 files changed, 78 insertions(+), 75 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index d49049ee0956..9fea4f39cb17 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -118,7 +118,6 @@ def __init__( self.offload_device = offload_param_config.device self.offload_param_pin_memory = offload_param_config.pin_memory - self._convert_to_zero_parameters(ds_config, module, mpu) for m in module.modules(): @@ -145,24 +144,25 @@ def __init__( module.ds_inflight_param_registry[False] = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry - if zero_coalesced_fetch_threshold >=0: - self.min_granularity_value=sys.maxsize - self.min_granularity_layer=None - self.z3_leaf_layers=[] + if zero_coalesced_fetch_threshold >= 0: + self.min_granularity_value = sys.maxsize + self.min_granularity_layer = None + self.z3_leaf_layers = [] self.count_layers_and_parameters(module) - if self.min_granularity_value<=zero_coalesced_fetch_threshold: + if self.min_granularity_value <= zero_coalesced_fetch_threshold: self.set_z3_leaf_by_threshold(module, zero_coalesced_fetch_threshold) print_rank_0(f"z3_leaf_module was setted by stage3_coalesced_fetch_threshold", force=True) for layer in self.z3_leaf_layers: print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) else: - utils.logger.warning(f"You have used min_granularity_value, but the smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}],\ - To make this variable effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}") + utils.logger.warning( + f"You have used min_granularity_value, but the smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}],\ + To make this variable effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" + ) self.forward_hooks = [] self.backward_hooks = [] - self.setup_zero_stage3_hooks() print_rank_0( f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}', @@ -170,8 +170,6 @@ def __init__( see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True) - - @instrument_w_nvtx def partition_all_parameters(self): """Partitioning Parameters that were not partitioned usually if parameters @@ -503,55 +501,47 @@ def post_sub_module_backward_function(self, sub_module): see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", force=False) - + def count_layers_and_parameters(self, module): - - if not list(module.parameters()): - module.ds_model_granularity=sys.maxsize - return 0,0 - + + if not list(module.parameters()): + module.ds_model_granularity = sys.maxsize + return 0, 0 + num_layers = 0 num_params = 0 num_params += sum(p.ds_numel for p in module.parameters(recurse=False)) - + for child in module.children(): layers_in_child, params_in_child = self.count_layers_and_parameters(child) num_layers += layers_in_child num_params += params_in_child - + num_layers += 1 - + # 将结果保存到模块的自定义属性中 TODO module.ds_sub_layers = num_layers module.ds_sub_params = num_params - - ds_model_granularity=num_params//num_layers - module.ds_model_granularity=ds_model_granularity - if self.min_granularity_value>ds_model_granularity: - self.min_granularity_value=ds_model_granularity - self.min_granularity_layer=module.__class__.__name__ - + + ds_model_granularity = num_params // num_layers + module.ds_model_granularity = ds_model_granularity + if self.min_granularity_value > ds_model_granularity: + self.min_granularity_value = ds_model_granularity + self.min_granularity_layer = module.__class__.__name__ + return num_layers, num_params - def set_z3_leaf_by_threshold(self, module,granularity_treshhold ): + def set_z3_leaf_by_threshold(self, module, granularity_treshhold): num_params = sum(p.ds_numel for p in module.parameters()) - if num_params==0: - return + if num_params == 0: + return # Avoid setting as leaf for particularly large models, even if the granularity is very small - Z3_MAX_LEAF_SIZE=5e8 + Z3_MAX_LEAF_SIZE = 5e8 - if module.ds_model_granularity List[torch.nn.Module]: """ return [module for module in model.modules() if z3_leaf_module(module)] + def set_z3_leaf_module(model: torch.nn.Module, flag: bool): model._z3_leaf = flag + def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]], flag: bool) -> List[torch.nn.Module]: assert all(isinstance(module_class, (type, str) ) for module_class in leaf_module_classes), \ diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index 7dc9704786cf..2c8bc58b0cc5 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -16,6 +16,7 @@ import time from math import ceil + class ChooseModuleByCounter(torch.nn.Module): def __init__(self, hidden_dim): @@ -46,7 +47,7 @@ def __init__(self, hidden_dim): torch.nn.Linear(hidden_dim, hidden_dim, bias=False)]) self.act = torch.nn.ReLU() self.cel = torch.nn.CrossEntropyLoss() - + def forward(self, x, y): # Each rank runs only one of the linear layers x = self.linears[dist.get_rank() % len(self.linears)](x) @@ -54,40 +55,49 @@ def forward(self, x, y): loss = self.cel(x, y) return x, loss + class MLPBlock(nn.Module): + def __init__(self, hidden_dim): super(MLPBlock, self).__init__() self.gate_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) self.up_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) self.down_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) - self.act_fn=nn.GELU() + self.act_fn = nn.GELU() + def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + class FineGrainedBlock(nn.Module): + def __init__(self, hidden_dim, num_block): super(FineGrainedBlock, self).__init__() - self.num_block=num_block - self.mlp_layers = torch.nn.ModuleList( - [MLPBlock(hidden_dim=hidden_dim) for _ in range(self.num_block)] - ) + self.num_block = num_block + self.mlp_layers = torch.nn.ModuleList([MLPBlock(hidden_dim=hidden_dim) for _ in range(self.num_block)]) + def forward(self, x): for i in range(self.num_block): - x=self.mlp_layers[i](x) + x = self.mlp_layers[i](x) return x + + class modelWithFineGrainedBlock(nn.Module): + def __init__(self, hidden_dim, num_block): super(modelWithFineGrainedBlock, self).__init__() - self.coarse_grained_layer1=nn.Linear(hidden_dim,8*hidden_dim) - self.coarse_grained_layer2=nn.Linear(8*hidden_dim,hidden_dim) - self.finegrad_layer = FineGrainedBlock(hidden_dim,num_block) + self.coarse_grained_layer1 = nn.Linear(hidden_dim, 8 * hidden_dim) + self.coarse_grained_layer2 = nn.Linear(8 * hidden_dim, hidden_dim) + self.finegrad_layer = FineGrainedBlock(hidden_dim, num_block) self.cel = torch.nn.CrossEntropyLoss() + def forward(self, x, y): - x=self.coarse_grained_layer1(x) - x=self.coarse_grained_layer2(x) - x=self.finegrad_layer(x) + x = self.coarse_grained_layer1(x) + x = self.coarse_grained_layer2(x) + x = self.finegrad_layer(x) loss = self.cel(x, y) return x, loss - + def run_model(model, config_dict, hidden_dim, dtype, requires_grad): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) @@ -185,10 +195,11 @@ class TestZ3LeafOptimization(DistributedTest): # Need multiple gpus to test possible hanging world_size = 2 reuse_dist_env = True + def test_FineGrained_optimization(self): - hidden_dim=128 - num_block=16 - stage3_coalesced_fetch_threshold=12000 + hidden_dim = 128 + num_block = 16 + stage3_coalesced_fetch_threshold = 12000 config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -203,7 +214,7 @@ def test_FineGrained_optimization(self): "stage3_prefetch_bucket_size": hidden_dim**2, "stage3_param_persistence_threshold": 0, "stage3_max_reuse_distance": 0, - "stage3_coalesced_fetch_threshold":stage3_coalesced_fetch_threshold + "stage3_coalesced_fetch_threshold": stage3_coalesced_fetch_threshold } } if get_accelerator().is_fp16_supported(): @@ -212,26 +223,26 @@ def test_FineGrained_optimization(self): config_dict["bf16"] = {"enabled": True} def bench_loss_and_time(config): - warm_up_step=10 - model = modelWithFineGrainedBlock(hidden_dim,num_block) + warm_up_step = 10 + model = modelWithFineGrainedBlock(hidden_dim, num_block) model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) data_loader = random_dataloader(model=model, - total_samples=20, - hidden_dim=hidden_dim, - device=model.device, - dtype=preferred_dtype()) + total_samples=20, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) dist.barrier() - loss_list=[] + loss_list = [] # for name, submodule in model.named_modules(): # if hasattr(submodule,'ds_sub_layers'): - # tresh=submodule.ds_sub_params/submodule.ds_sub_layers + # tresh=submodule.ds_sub_params/submodule.ds_sub_layers # if dist.get_rank()==0: # print(f"module '{name}' layers: {submodule.ds_sub_layers}, params: {submodule.ds_sub_params}, tresh:{tresh}") for i, batch in enumerate(data_loader): - if i ==warm_up_step: + if i == warm_up_step: get_accelerator().synchronize() - st=time.time() + st = time.time() batch[0].requires_grad = True loss = model(batch[0], batch[1]) loss = loss[1] @@ -239,13 +250,13 @@ def bench_loss_and_time(config): model.backward(loss) model.step() get_accelerator().synchronize() - en=time.time() - duration = en-st + en = time.time() + duration = en - st model.destroy() return loss_list, duration - opt_loss_list,opt_duration=bench_loss_and_time(config_dict) + opt_loss_list, opt_duration = bench_loss_and_time(config_dict) del config_dict["zero_optimization"]["stage3_coalesced_fetch_threshold"] - basic_loss_list, basic_duration=bench_loss_and_time(config_dict) + basic_loss_list, basic_duration = bench_loss_and_time(config_dict) print(f"coalesced fetch time: {opt_duration}, basic duration time:{basic_duration}") - assert basic_loss_list==opt_loss_list \ No newline at end of file + assert basic_loss_list == opt_loss_list From 600d9c749ca996053c1a84fde229030da0475840 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 4 Nov 2024 13:12:21 +0000 Subject: [PATCH 09/21] fix note --- deepspeed/runtime/zero/parameter_offload.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index c3115e3c5bfc..5cdc2ddc925b 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -524,7 +524,6 @@ def count_layers_and_parameters(self, module): num_layers += 1 - # 将结果保存到模块的自定义属性中 TODO module.ds_sub_layers = num_layers module.ds_sub_params = num_params From c2c434bd13e036b7ca36914f5c908660497b426f Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 5 Nov 2024 08:56:45 +0000 Subject: [PATCH 10/21] refine code --- deepspeed/runtime/zero/config.py | 4 +- deepspeed/runtime/zero/parameter_offload.py | 51 +++++++++++-------- deepspeed/runtime/zero/stage3.py | 34 ++++++------- .../runtime/zero/test_zero_leaf_module.py | 12 ++--- 4 files changed, 54 insertions(+), 47 deletions(-) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index adc76f3ea215..c4cf6f1a7419 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -248,7 +248,9 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): coalesced_fetch_threshold: int = Field(pp_int(0), alias="stage3_coalesced_fetch_threshold") """ - Treat the layer as an integral unit (to avoid recursion) when fetching at stage3. + The ratio of a module's number of parameters/required forward passes (layers) + measures model granularity. Modules with values below this threshold will be + treated as an integral unit (to avoid recursion) when fetching at stage3. This will reduce the host overhead and separated allgather overhead in fetching parameters introduced by hooks for fine-grained layers. """ diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 5cdc2ddc925b..bdf7de827303 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -6,7 +6,7 @@ import sys import torch from collections import OrderedDict -from deepspeed.utils import z3_leaf_module, set_z3_leaf_modules, set_z3_leaf_module +from deepspeed.utils import z3_leaf_module, set_z3_leaf_module from deepspeed.runtime.utils import see_memory_usage from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -153,18 +153,8 @@ def __init__( self.min_granularity_value = sys.maxsize self.min_granularity_layer = None self.z3_leaf_layers = [] - self.count_layers_and_parameters(module) + self._set_z3_leaf_modules_by_threshold(module, zero_coalesced_fetch_threshold) - if self.min_granularity_value <= zero_coalesced_fetch_threshold: - self.set_z3_leaf_by_threshold(module, zero_coalesced_fetch_threshold) - print_rank_0(f"z3_leaf_module was setted by stage3_coalesced_fetch_threshold", force=True) - for layer in self.z3_leaf_layers: - print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) - else: - utils.logger.warning( - f"You have used min_granularity_value, but the smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}],\ - To make this variable effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" - ) self.forward_hooks = [] self.backward_hooks = [] @@ -507,9 +497,24 @@ def post_sub_module_backward_function(self, sub_module): f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", force=False) - def count_layers_and_parameters(self, module): + def _set_z3_leaf_modules_by_threshold(self, module, zero_coalesced_fetch_threshold): + self._get_granularity_recursively(module) + if self.min_granularity_value <= zero_coalesced_fetch_threshold: + self._set_leaf_by_threshold_preorder(module, zero_coalesced_fetch_threshold) + print_rank_0(f"z3_leaf_module was set by stage3_coalesced_fetch_threshold", force=True) + for layer in self.z3_leaf_layers: + print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) + else: + utils.logger.warning( + f"You have used zero_coalesced_fetch_threshold, but the smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}],\ + To make this variable effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" + ) + + def _get_granularity_recursively(self, module): + """This function is used to recursively obtain the granularity of each module.""" if not list(module.parameters()): + # skip Modules without parameters, such as GELU, etc. module.ds_model_granularity = sys.maxsize return 0, 0 @@ -518,14 +523,14 @@ def count_layers_and_parameters(self, module): num_params += sum(p.ds_numel for p in module.parameters(recurse=False)) for child in module.children(): - layers_in_child, params_in_child = self.count_layers_and_parameters(child) + layers_in_child, params_in_child = self._get_granularity_recursively(child) num_layers += layers_in_child num_params += params_in_child num_layers += 1 - - module.ds_sub_layers = num_layers - module.ds_sub_params = num_params + #for debug + # module.ds_sub_layers = num_layers + # module.ds_sub_params = num_params ds_model_granularity = num_params // num_layers module.ds_model_granularity = ds_model_granularity @@ -535,12 +540,16 @@ def count_layers_and_parameters(self, module): return num_layers, num_params - def set_z3_leaf_by_threshold(self, module, granularity_treshhold): + def _set_leaf_by_threshold_preorder(self, module, granularity_treshhold): + '''Set modules as leaf modules based on the threshold, prioritizing parent nodes.''' + + # Avoid setting as leaf for particularly large models, even if the granularity is very small + Z3_MAX_LEAF_SIZE = 5e8 + num_params = sum(p.ds_numel for p in module.parameters()) if num_params == 0: + # skip Modules without parameters, such as GELU, etc. return - # Avoid setting as leaf for particularly large models, even if the granularity is very small - Z3_MAX_LEAF_SIZE = 5e8 if module.ds_model_granularity < granularity_treshhold and num_params < Z3_MAX_LEAF_SIZE: set_z3_leaf_module(module, True) @@ -548,4 +557,4 @@ def set_z3_leaf_by_threshold(self, module, granularity_treshhold): return for sub_module in module.children(): - self.set_z3_leaf_by_threshold(sub_module, granularity_treshhold) + self._set_leaf_by_threshold_preorder(sub_module, granularity_treshhold) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index a7b84e8e9dc4..9a0dbb8e629d 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -462,24 +462,22 @@ def initialize_ds_offload( zero_quantized_nontrainable_weights, zero_coalesced_fetch_threshold, ): - return DeepSpeedZeRoOffload( - module=module, - timers=timers, - ds_config=ds_config, - overlap_comm=overlap_comm, - prefetch_bucket_size=prefetch_bucket_size, - max_reuse_distance=max_reuse_distance, - max_live_parameters=max_live_parameters, - param_persistence_threshold=param_persistence_threshold, - model_persistence_threshold=model_persistence_threshold, - dp_process_group=dp_process_group, - offload_param_config=offload_param_config, - mpu=mpu, - zero_param_parallel_group=zero_param_parallel_group, - zero_quantized_weights=zero_quantized_weights, - zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, - zero_coalesced_fetch_threshold=zero_coalesced_fetch_threshold, - ) + return DeepSpeedZeRoOffload(module=module, + timers=timers, + ds_config=ds_config, + overlap_comm=overlap_comm, + prefetch_bucket_size=prefetch_bucket_size, + max_reuse_distance=max_reuse_distance, + max_live_parameters=max_live_parameters, + param_persistence_threshold=param_persistence_threshold, + model_persistence_threshold=model_persistence_threshold, + dp_process_group=dp_process_group, + offload_param_config=offload_param_config, + mpu=mpu, + zero_param_parallel_group=zero_param_parallel_group, + zero_quantized_weights=zero_quantized_weights, + zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, + zero_coalesced_fetch_threshold=zero_coalesced_fetch_threshold) def _get_trainable_parameter_groups(self): param_groups = [] diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index 2c8bc58b0cc5..e06143d3be66 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -14,7 +14,6 @@ from deepspeed.accelerator import get_accelerator from torch import nn import time -from math import ceil class ChooseModuleByCounter(torch.nn.Module): @@ -192,7 +191,6 @@ def test_set_no_match_class(self): class TestZ3LeafOptimization(DistributedTest): - # Need multiple gpus to test possible hanging world_size = 2 reuse_dist_env = True @@ -233,7 +231,6 @@ def bench_loss_and_time(config): dtype=preferred_dtype()) dist.barrier() loss_list = [] - # for name, submodule in model.named_modules(): # if hasattr(submodule,'ds_sub_layers'): # tresh=submodule.ds_sub_params/submodule.ds_sub_layers @@ -241,8 +238,9 @@ def bench_loss_and_time(config): # print(f"module '{name}' layers: {submodule.ds_sub_layers}, params: {submodule.ds_sub_params}, tresh:{tresh}") for i, batch in enumerate(data_loader): if i == warm_up_step: + dist.barrier() get_accelerator().synchronize() - st = time.time() + start_time = time.time() batch[0].requires_grad = True loss = model(batch[0], batch[1]) loss = loss[1] @@ -250,13 +248,13 @@ def bench_loss_and_time(config): model.backward(loss) model.step() get_accelerator().synchronize() - en = time.time() - duration = en - st + end_time = time.time() + duration = end_time - start_time model.destroy() return loss_list, duration opt_loss_list, opt_duration = bench_loss_and_time(config_dict) del config_dict["zero_optimization"]["stage3_coalesced_fetch_threshold"] basic_loss_list, basic_duration = bench_loss_and_time(config_dict) - print(f"coalesced fetch time: {opt_duration}, basic duration time:{basic_duration}") + print(f"coalesced fetch time: {opt_duration}, basic time:{basic_duration}") assert basic_loss_list == opt_loss_list From e5f9430c2372163da3f5902a2cc070cdcbb30f8a Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 5 Nov 2024 08:59:19 +0000 Subject: [PATCH 11/21] remove debug code --- deepspeed/runtime/zero/parameter_offload.py | 5 +---- tests/unit/runtime/zero/test_zero_leaf_module.py | 6 +----- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index bdf7de827303..5ed89a2dd76f 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -506,7 +506,7 @@ def _set_z3_leaf_modules_by_threshold(self, module, zero_coalesced_fetch_thresho print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) else: utils.logger.warning( - f"You have used zero_coalesced_fetch_threshold, but the smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}],\ + f"You have used stage3_coalesced_fetch_threshold, but the smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}],\ To make this variable effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" ) @@ -528,9 +528,6 @@ def _get_granularity_recursively(self, module): num_params += params_in_child num_layers += 1 - #for debug - # module.ds_sub_layers = num_layers - # module.ds_sub_params = num_params ds_model_granularity = num_params // num_layers module.ds_model_granularity = ds_model_granularity diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index e06143d3be66..edd42d94c3cc 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -231,11 +231,7 @@ def bench_loss_and_time(config): dtype=preferred_dtype()) dist.barrier() loss_list = [] - # for name, submodule in model.named_modules(): - # if hasattr(submodule,'ds_sub_layers'): - # tresh=submodule.ds_sub_params/submodule.ds_sub_layers - # if dist.get_rank()==0: - # print(f"module '{name}' layers: {submodule.ds_sub_layers}, params: {submodule.ds_sub_params}, tresh:{tresh}") + for i, batch in enumerate(data_loader): if i == warm_up_step: dist.barrier() From c2b020a4a7fc5949950bd8da6db3e9ca4d5268ef Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 5 Nov 2024 10:55:05 +0000 Subject: [PATCH 12/21] update --- deepspeed/runtime/zero/parameter_offload.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 5ed89a2dd76f..62901bb6ae67 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -506,13 +506,16 @@ def _set_z3_leaf_modules_by_threshold(self, module, zero_coalesced_fetch_thresho print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) else: utils.logger.warning( - f"You have used stage3_coalesced_fetch_threshold, but the smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}],\ - To make this variable effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" + f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\ + f"To make stage3_coalesced_fetch_threshold effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" ) def _get_granularity_recursively(self, module): """This function is used to recursively obtain the granularity of each module.""" + # Avoid setting as leaf for particularly large models, even if the granularity is very small + Z3_MAX_LEAF_SIZE = 1e9 + if not list(module.parameters()): # skip Modules without parameters, such as GELU, etc. module.ds_model_granularity = sys.maxsize @@ -521,6 +524,10 @@ def _get_granularity_recursively(self, module): num_layers = 0 num_params = 0 num_params += sum(p.ds_numel for p in module.parameters(recurse=False)) + if not any(module.children()): + # torch leaf module + module.ds_model_granularity = sys.maxsize + return 1, num_params for child in module.children(): layers_in_child, params_in_child = self._get_granularity_recursively(child) @@ -529,8 +536,10 @@ def _get_granularity_recursively(self, module): num_layers += 1 - ds_model_granularity = num_params // num_layers + ds_model_granularity = (num_params // num_layers) if num_params <= Z3_MAX_LEAF_SIZE else sys.maxsize module.ds_model_granularity = ds_model_granularity + # module.ds_model_num_layers = num_layers + # module.ds_model_num_params = num_params if self.min_granularity_value > ds_model_granularity: self.min_granularity_value = ds_model_granularity self.min_granularity_layer = module.__class__.__name__ @@ -540,15 +549,12 @@ def _get_granularity_recursively(self, module): def _set_leaf_by_threshold_preorder(self, module, granularity_treshhold): '''Set modules as leaf modules based on the threshold, prioritizing parent nodes.''' - # Avoid setting as leaf for particularly large models, even if the granularity is very small - Z3_MAX_LEAF_SIZE = 5e8 - num_params = sum(p.ds_numel for p in module.parameters()) if num_params == 0: # skip Modules without parameters, such as GELU, etc. return - if module.ds_model_granularity < granularity_treshhold and num_params < Z3_MAX_LEAF_SIZE: + if module.ds_model_granularity <= granularity_treshhold: set_z3_leaf_module(module, True) self.z3_leaf_layers.append(module) return From 36801098e84390d512bf53c78f333d693699556e Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 5 Nov 2024 13:02:45 +0000 Subject: [PATCH 13/21] don't set leaf for container module --- deepspeed/runtime/zero/parameter_offload.py | 11 +++++++---- tests/unit/runtime/zero/test_zero_leaf_module.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 62901bb6ae67..6f9c88ecc051 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -553,11 +553,14 @@ def _set_leaf_by_threshold_preorder(self, module, granularity_treshhold): if num_params == 0: # skip Modules without parameters, such as GELU, etc. return - if module.ds_model_granularity <= granularity_treshhold: - set_z3_leaf_module(module, True) - self.z3_leaf_layers.append(module) - return + if module.__class__.__name__ not in torch.nn.modules.container.__all__: + # Do not set container modules like ModuleList as leaf modules + # as this will prevent hooks from being set on their children + # and they may do not invoke the forward method + set_z3_leaf_module(module, True) + self.z3_leaf_layers.append(module) + return for sub_module in module.children(): self._set_leaf_by_threshold_preorder(sub_module, granularity_treshhold) diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index edd42d94c3cc..3b0106b70050 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -194,7 +194,7 @@ class TestZ3LeafOptimization(DistributedTest): world_size = 2 reuse_dist_env = True - def test_FineGrained_optimization(self): + def test_finegrained_optimization(self): hidden_dim = 128 num_block = 16 stage3_coalesced_fetch_threshold = 12000 From 22c0f81c61703543a6cb5c4938e94f0582b64e50 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 6 Nov 2024 08:28:28 +0000 Subject: [PATCH 14/21] update ut --- deepspeed/runtime/zero/parameter_offload.py | 18 ++++++--- .../runtime/zero/test_zero_leaf_module.py | 38 ++++++++++++------- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 6f9c88ecc051..ac80205ecae4 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -152,6 +152,7 @@ def __init__( if zero_coalesced_fetch_threshold >= 0: self.min_granularity_value = sys.maxsize self.min_granularity_layer = None + self.granularity_info=set() self.z3_leaf_layers = [] self._set_z3_leaf_modules_by_threshold(module, zero_coalesced_fetch_threshold) @@ -501,7 +502,7 @@ def _set_z3_leaf_modules_by_threshold(self, module, zero_coalesced_fetch_thresho self._get_granularity_recursively(module) if self.min_granularity_value <= zero_coalesced_fetch_threshold: self._set_leaf_by_threshold_preorder(module, zero_coalesced_fetch_threshold) - print_rank_0(f"z3_leaf_module was set by stage3_coalesced_fetch_threshold", force=True) + print_rank_0(f"z3_leaf_module was set by stage3_coalesced_fetch_threshold:{zero_coalesced_fetch_threshold}", force=True) for layer in self.z3_leaf_layers: print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) else: @@ -509,6 +510,8 @@ def _set_z3_leaf_modules_by_threshold(self, module, zero_coalesced_fetch_thresho f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\ f"To make stage3_coalesced_fetch_threshold effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" ) + for granularity in self.granularity_info: + print_rank_0(granularity) def _get_granularity_recursively(self, module): """This function is used to recursively obtain the granularity of each module.""" @@ -534,8 +537,14 @@ def _get_granularity_recursively(self, module): num_layers += layers_in_child num_params += params_in_child + if module.__class__.__name__ in torch.nn.modules.container.__all__: + # Do not set container modules like ModuleList as leaf modules + # as this will prevent hooks from being set on their children + # and they may do not invoke the forward method + module.ds_model_granularity=sys.maxsize + return num_layers, num_params + num_layers += 1 - ds_model_granularity = (num_params // num_layers) if num_params <= Z3_MAX_LEAF_SIZE else sys.maxsize module.ds_model_granularity = ds_model_granularity # module.ds_model_num_layers = num_layers @@ -543,6 +552,7 @@ def _get_granularity_recursively(self, module): if self.min_granularity_value > ds_model_granularity: self.min_granularity_value = ds_model_granularity self.min_granularity_layer = module.__class__.__name__ + self.granularity_info.add(f"{module.__class__.__name__}:{ds_model_granularity}" ) return num_layers, num_params @@ -554,10 +564,6 @@ def _set_leaf_by_threshold_preorder(self, module, granularity_treshhold): # skip Modules without parameters, such as GELU, etc. return if module.ds_model_granularity <= granularity_treshhold: - if module.__class__.__name__ not in torch.nn.modules.container.__all__: - # Do not set container modules like ModuleList as leaf modules - # as this will prevent hooks from being set on their children - # and they may do not invoke the forward method set_z3_leaf_module(module, True) self.z3_leaf_layers.append(module) return diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index 3b0106b70050..b7ccf28f1975 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -87,13 +87,13 @@ def __init__(self, hidden_dim, num_block): super(modelWithFineGrainedBlock, self).__init__() self.coarse_grained_layer1 = nn.Linear(hidden_dim, 8 * hidden_dim) self.coarse_grained_layer2 = nn.Linear(8 * hidden_dim, hidden_dim) - self.finegrad_layer = FineGrainedBlock(hidden_dim, num_block) + self.fine_grained_layer = FineGrainedBlock(hidden_dim, num_block) self.cel = torch.nn.CrossEntropyLoss() def forward(self, x, y): x = self.coarse_grained_layer1(x) x = self.coarse_grained_layer2(x) - x = self.finegrad_layer(x) + x = self.fine_grained_layer(x) loss = self.cel(x, y) return x, loss @@ -142,9 +142,9 @@ def _test_set_z3_leaf_modules(self, cls, requires_grad): "stage3_max_reuse_distance": 0, } } - if get_accelerator().is_fp16_supported(): + if preferred_dtype() is torch.float16: config_dict["fp16"] = {"enabled": True} - elif get_accelerator().is_bf16_supported(): + elif preferred_dtype() is torch.bfloat16: config_dict["bf16"] = {"enabled": True} model = cls(hidden_dim) @@ -197,7 +197,7 @@ class TestZ3LeafOptimization(DistributedTest): def test_finegrained_optimization(self): hidden_dim = 128 num_block = 16 - stage3_coalesced_fetch_threshold = 12000 + stage3_coalesced_fetch_threshold_list = [0,100,12000,10000000] config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -212,13 +212,13 @@ def test_finegrained_optimization(self): "stage3_prefetch_bucket_size": hidden_dim**2, "stage3_param_persistence_threshold": 0, "stage3_max_reuse_distance": 0, - "stage3_coalesced_fetch_threshold": stage3_coalesced_fetch_threshold } } - if get_accelerator().is_fp16_supported(): + if preferred_dtype() is torch.float16: config_dict["fp16"] = {"enabled": True} - elif get_accelerator().is_bf16_supported(): + elif preferred_dtype() is torch.bfloat16: config_dict["bf16"] = {"enabled": True} + def bench_loss_and_time(config): warm_up_step = 10 @@ -248,9 +248,21 @@ def bench_loss_and_time(config): duration = end_time - start_time model.destroy() return loss_list, duration + result_loss_list=[] + result_duration=[] + + baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict) + + for threshold in stage3_coalesced_fetch_threshold_list: + config_dict["zero_optimization"]["stage3_coalesced_fetch_threshold"]=threshold + loss_list, duration = bench_loss_and_time(config_dict) + result_duration.append(duration) + result_loss_list.append(loss_list) + if dist.get_rank()==0: + print(f"baseline exec time:",baseline_exec_time) + for idx,threshold in enumerate(stage3_coalesced_fetch_threshold_list): + if dist.get_rank()==0: + print(f"finegrained optimziation exec time: {result_duration[idx]}, threshold:{threshold} " ) + assert baseline_loss_list == result_loss_list[idx], f"incorrect loss value with threshold:{threshold}" + - opt_loss_list, opt_duration = bench_loss_and_time(config_dict) - del config_dict["zero_optimization"]["stage3_coalesced_fetch_threshold"] - basic_loss_list, basic_duration = bench_loss_and_time(config_dict) - print(f"coalesced fetch time: {opt_duration}, basic time:{basic_duration}") - assert basic_loss_list == opt_loss_list From f77325802be9f006c529571766aa2fdde3770250 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 6 Nov 2024 09:01:29 +0000 Subject: [PATCH 15/21] udpate --- deepspeed/runtime/zero/parameter_offload.py | 5 +++-- tests/unit/runtime/zero/test_zero_leaf_module.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index ac80205ecae4..31bcf7c86043 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -510,8 +510,9 @@ def _set_z3_leaf_modules_by_threshold(self, module, zero_coalesced_fetch_thresho f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\ f"To make stage3_coalesced_fetch_threshold effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" ) + print_rank_0(f"{'MODULE NAME'.ljust(30)}|{'GRANULARITY VALUE'.rjust(20)}", force=True) for granularity in self.granularity_info: - print_rank_0(granularity) + print_rank_0(granularity, force=True) def _get_granularity_recursively(self, module): """This function is used to recursively obtain the granularity of each module.""" @@ -552,7 +553,7 @@ def _get_granularity_recursively(self, module): if self.min_granularity_value > ds_model_granularity: self.min_granularity_value = ds_model_granularity self.min_granularity_layer = module.__class__.__name__ - self.granularity_info.add(f"{module.__class__.__name__}:{ds_model_granularity}" ) + self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(10)}" ) return num_layers, num_params diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index b7ccf28f1975..b863a94ba3fe 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -197,7 +197,7 @@ class TestZ3LeafOptimization(DistributedTest): def test_finegrained_optimization(self): hidden_dim = 128 num_block = 16 - stage3_coalesced_fetch_threshold_list = [0,100,12000,10000000] + stage3_coalesced_fetch_threshold_list = [100] config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -251,18 +251,18 @@ def bench_loss_and_time(config): result_loss_list=[] result_duration=[] - baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict) + # baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict) for threshold in stage3_coalesced_fetch_threshold_list: config_dict["zero_optimization"]["stage3_coalesced_fetch_threshold"]=threshold loss_list, duration = bench_loss_and_time(config_dict) result_duration.append(duration) result_loss_list.append(loss_list) - if dist.get_rank()==0: - print(f"baseline exec time:",baseline_exec_time) - for idx,threshold in enumerate(stage3_coalesced_fetch_threshold_list): - if dist.get_rank()==0: - print(f"finegrained optimziation exec time: {result_duration[idx]}, threshold:{threshold} " ) - assert baseline_loss_list == result_loss_list[idx], f"incorrect loss value with threshold:{threshold}" + # if dist.get_rank()==0: + # print(f"baseline exec time:",baseline_exec_time) + # for idx,threshold in enumerate(stage3_coalesced_fetch_threshold_list): + # if dist.get_rank()==0: + # print(f"finegrained optimziation exec time: {result_duration[idx]}, threshold:{threshold} " ) + # assert baseline_loss_list == result_loss_list[idx], f"incorrect loss value with threshold:{threshold}" From c31ad0233fd4c4f11c3320df09ef5add7cb742c2 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 6 Nov 2024 10:13:38 +0000 Subject: [PATCH 16/21] change config name, refine doc --- deepspeed/runtime/engine.py | 8 ++--- deepspeed/runtime/zero/config.py | 13 ++++--- deepspeed/runtime/zero/parameter_offload.py | 35 ++++++++++--------- deepspeed/runtime/zero/stage3.py | 8 ++--- docs/_pages/config-json.md | 5 +++ .../runtime/zero/test_zero_leaf_module.py | 30 ++++++++-------- 6 files changed, 52 insertions(+), 47 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index bd6bc3493ada..8fe503c9f422 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -811,8 +811,8 @@ def zero_max_reuse_distance(self): def zero_prefetch_bucket_size(self): return self._config.zero_config.prefetch_bucket_size - def zero_coalesced_fetch_threshold(self): - return self._config.zero_config.coalesced_fetch_threshold + def zero_module_granularity_threshold(self): + return self._config.zero_config.module_granularity_threshold def zero_param_persistence_threshold(self): return self._config.zero_config.param_persistence_threshold @@ -1614,7 +1614,7 @@ def _configure_zero_optimizer(self, optimizer): zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), - zero_coalesced_fetch_threshold=self.zero_coalesced_fetch_threshold(), + zero_module_granularity_threshold=self.zero_module_granularity_threshold(), ) else: log_dist( @@ -1661,7 +1661,7 @@ def _configure_zero_optimizer(self, optimizer): zero_hpz_partition_size=self.zero_hpz_partition_size(), zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), - zero_coalesced_fetch_threshold=self.zero_coalesced_fetch_threshold(), + zero_module_granularity_threshold=self.zero_module_granularity_threshold(), ) else: diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index c4cf6f1a7419..7cac7e3c1ce7 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -21,7 +21,7 @@ "stage3_max_live_parameters" : 1000000000, "stage3_max_reuse_distance" : 1000000000, "stage3_use_all_reduce_for_fetch_params": [true|false], - "stage3_coalesced_fetch_threshold": 0, + "stage3_module_granularity_threshold": 0, "allgather_partitions": [true|false], "use_multi_rank_bucket_allreduce": [true|false], "allgather_bucket_size": 500000000, @@ -246,13 +246,12 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ - coalesced_fetch_threshold: int = Field(pp_int(0), alias="stage3_coalesced_fetch_threshold") + module_granularity_threshold: int = Field(pp_int(0), alias="stage3_module_granularity_threshold") """ - The ratio of a module's number of parameters/required forward passes (layers) - measures model granularity. Modules with values below this threshold will be - treated as an integral unit (to avoid recursion) when fetching at stage3. - This will reduce the host overhead and separated allgather overhead in fetching - parameters introduced by hooks for fine-grained layers. + The granularity of a module is determined by the ratio of "parameter_count / (1 + descendant count)". + ZeRO3 classifies modules with a granularity below the threshold as fine-grained, + which are treated as integral units during parameter fetching. This reduces host overhead + and the separate allgather overhead introduced by hooks for fine-grained layers when fetching parameters. """ use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params") diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 31bcf7c86043..4dd8a71f9c23 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -102,7 +102,7 @@ def __init__( zero_param_parallel_group=None, zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, - zero_coalesced_fetch_threshold=0, + zero_module_granularity_threshold=0, ): see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True) @@ -149,12 +149,12 @@ def __init__( module.ds_inflight_param_registry[False] = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry - if zero_coalesced_fetch_threshold >= 0: + if zero_module_granularity_threshold >= 0: self.min_granularity_value = sys.maxsize self.min_granularity_layer = None - self.granularity_info=set() + self.granularity_info = set() self.z3_leaf_layers = [] - self._set_z3_leaf_modules_by_threshold(module, zero_coalesced_fetch_threshold) + self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold) self.forward_hooks = [] self.backward_hooks = [] @@ -498,17 +498,19 @@ def post_sub_module_backward_function(self, sub_module): f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", force=False) - def _set_z3_leaf_modules_by_threshold(self, module, zero_coalesced_fetch_threshold): + def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold): self._get_granularity_recursively(module) - if self.min_granularity_value <= zero_coalesced_fetch_threshold: - self._set_leaf_by_threshold_preorder(module, zero_coalesced_fetch_threshold) - print_rank_0(f"z3_leaf_module was set by stage3_coalesced_fetch_threshold:{zero_coalesced_fetch_threshold}", force=True) + if self.min_granularity_value <= zero_module_granularity_threshold: + self._set_leaf_by_threshold_preorder(module, zero_module_granularity_threshold) + print_rank_0( + f"z3_leaf_module was set by stage3_module_granularity_threshold:{zero_module_granularity_threshold}", + force=True) for layer in self.z3_leaf_layers: print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) else: utils.logger.warning( f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\ - f"To make stage3_coalesced_fetch_threshold effective, you need to set stage3_coalesced_fetch_threshold >= {self.min_granularity_value}" + f"To make stage3_module_granularity_threshold effective, you need to set stage3_module_granularity_threshold >= {self.min_granularity_value}" ) print_rank_0(f"{'MODULE NAME'.ljust(30)}|{'GRANULARITY VALUE'.rjust(20)}", force=True) for granularity in self.granularity_info: @@ -517,7 +519,8 @@ def _set_z3_leaf_modules_by_threshold(self, module, zero_coalesced_fetch_thresho def _get_granularity_recursively(self, module): """This function is used to recursively obtain the granularity of each module.""" - # Avoid setting as leaf for particularly large models, even if the granularity is very small + # avoid setting as leaf for particularly large models, even if the granularity is very small + # an oversized leaf module increases the number of live parameters, introducing memory overhead Z3_MAX_LEAF_SIZE = 1e9 if not list(module.parameters()): @@ -542,9 +545,9 @@ def _get_granularity_recursively(self, module): # Do not set container modules like ModuleList as leaf modules # as this will prevent hooks from being set on their children # and they may do not invoke the forward method - module.ds_model_granularity=sys.maxsize + module.ds_model_granularity = sys.maxsize return num_layers, num_params - + num_layers += 1 ds_model_granularity = (num_params // num_layers) if num_params <= Z3_MAX_LEAF_SIZE else sys.maxsize module.ds_model_granularity = ds_model_granularity @@ -553,7 +556,7 @@ def _get_granularity_recursively(self, module): if self.min_granularity_value > ds_model_granularity: self.min_granularity_value = ds_model_granularity self.min_granularity_layer = module.__class__.__name__ - self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(10)}" ) + self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(10)}") return num_layers, num_params @@ -565,9 +568,9 @@ def _set_leaf_by_threshold_preorder(self, module, granularity_treshhold): # skip Modules without parameters, such as GELU, etc. return if module.ds_model_granularity <= granularity_treshhold: - set_z3_leaf_module(module, True) - self.z3_leaf_layers.append(module) - return + set_z3_leaf_module(module, True) + self.z3_leaf_layers.append(module) + return for sub_module in module.children(): self._set_leaf_by_threshold_preorder(sub_module, granularity_treshhold) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 84fe3f983c38..76b9a2b4ee8a 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -157,7 +157,7 @@ def __init__( zero_hpz_partition_size=1, zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, - zero_coalesced_fetch_threshold=0, + zero_module_granularity_threshold=0, ): see_memory_usage("Stage 3 initialize beginning", force=True) @@ -229,7 +229,7 @@ def __init__( zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=zero_quantized_weights, zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, - zero_coalesced_fetch_threshold=zero_coalesced_fetch_threshold) + zero_module_granularity_threshold=zero_module_granularity_threshold) self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config) @@ -460,7 +460,7 @@ def initialize_ds_offload( zero_param_parallel_group, zero_quantized_weights, zero_quantized_nontrainable_weights, - zero_coalesced_fetch_threshold, + zero_module_granularity_threshold, ): return DeepSpeedZeRoOffload(module=module, timers=timers, @@ -477,7 +477,7 @@ def initialize_ds_offload( zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=zero_quantized_weights, zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, - zero_coalesced_fetch_threshold=zero_coalesced_fetch_threshold) + zero_module_granularity_threshold=zero_module_granularity_threshold) def _get_trainable_parameter_groups(self): param_groups = [] diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index adb2f1679ea0..51e3bbd6eaaa 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -489,6 +489,11 @@ Enabling and configuring ZeRO memory optimizations |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- | | Consolidate the weights before saving the model by `save_16bit_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` | +***stage3_module_granularity_threshold***: [integer] +| Description | Default | +|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- | +| The granularity of a module is determined by the ratio of `parameter_count` / `(1 + descendant_count)`. ZeRO3 classifies modules with a granularity below the threshold as fine-grained, treating them as integral units during parameter fetching. This reduces host and communication overhead from separate hooks. | `0` | + ***zero_hpz_partition_size***: [integer] | Description | Default | diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index b863a94ba3fe..a1eacabc44ba 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -197,7 +197,7 @@ class TestZ3LeafOptimization(DistributedTest): def test_finegrained_optimization(self): hidden_dim = 128 num_block = 16 - stage3_coalesced_fetch_threshold_list = [100] + stage3_module_granularity_threshold_list = [0, 100, 12100, 10000000] config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -218,7 +218,6 @@ def test_finegrained_optimization(self): config_dict["fp16"] = {"enabled": True} elif preferred_dtype() is torch.bfloat16: config_dict["bf16"] = {"enabled": True} - def bench_loss_and_time(config): warm_up_step = 10 @@ -248,21 +247,20 @@ def bench_loss_and_time(config): duration = end_time - start_time model.destroy() return loss_list, duration - result_loss_list=[] - result_duration=[] - # baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict) - - for threshold in stage3_coalesced_fetch_threshold_list: - config_dict["zero_optimization"]["stage3_coalesced_fetch_threshold"]=threshold + result_loss_list = [] + result_duration = [] + + baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict) + + for threshold in stage3_module_granularity_threshold_list: + config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = threshold loss_list, duration = bench_loss_and_time(config_dict) result_duration.append(duration) result_loss_list.append(loss_list) - # if dist.get_rank()==0: - # print(f"baseline exec time:",baseline_exec_time) - # for idx,threshold in enumerate(stage3_coalesced_fetch_threshold_list): - # if dist.get_rank()==0: - # print(f"finegrained optimziation exec time: {result_duration[idx]}, threshold:{threshold} " ) - # assert baseline_loss_list == result_loss_list[idx], f"incorrect loss value with threshold:{threshold}" - - + if dist.get_rank() == 0: + print(f"baseline exec time:", baseline_exec_time) + for idx, threshold in enumerate(stage3_module_granularity_threshold_list): + if dist.get_rank() == 0: + print(f"finegrained optimziation exec time: {result_duration[idx]},granularity threshold:{threshold} ") + assert baseline_loss_list == result_loss_list[idx], f"incorrect loss value with threshold:{threshold}" From 40ceeac45c6b2f9565d3f53202b9aa878da2c425 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 6 Nov 2024 10:27:41 +0000 Subject: [PATCH 17/21] fix rjust size --- deepspeed/runtime/zero/parameter_offload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 4dd8a71f9c23..c9f9d2af6a86 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -556,7 +556,7 @@ def _get_granularity_recursively(self, module): if self.min_granularity_value > ds_model_granularity: self.min_granularity_value = ds_model_granularity self.min_granularity_layer = module.__class__.__name__ - self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(10)}") + self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(20)}") return num_layers, num_params From c31c8d218a8d425508aaabb6f6de074bc651f7d8 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 6 Nov 2024 18:36:59 +0800 Subject: [PATCH 18/21] format --- deepspeed/runtime/zero/parameter_offload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index eed414309667..a6be66c3c3b6 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -144,7 +144,7 @@ def __init__( if not hasattr(module, "ds_inflight_param_registry"): module.ds_inflight_param_registry = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry - + self.param_coordinator = PartitionedParameterCoordinator( prefetch_bucket_sz=self._prefetch_bucket_sz, max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, @@ -156,7 +156,7 @@ def __init__( zero_quantized_weights=self.zero_quantized_weights, zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, ) - + if zero_module_granularity_threshold >= 0: self.min_granularity_value = sys.maxsize self.min_granularity_layer = None From 619cbe6879bbf3a532146dc91a924c7a631d13b2 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 7 Nov 2024 02:57:34 +0000 Subject: [PATCH 19/21] always print info if the config is enabled --- deepspeed/runtime/zero/parameter_offload.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index a6be66c3c3b6..42982cf38c22 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -494,22 +494,24 @@ def post_sub_module_backward_function(self, sub_module): force=False) def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold): + self._get_granularity_recursively(module) + print_rank_0(f"{'MODULE NAME'.ljust(30)}|{'GRANULARITY VALUE'.rjust(20)}", force=True) + for granularity in self.granularity_info: + print_rank_0(granularity, force=True) + if self.min_granularity_value <= zero_module_granularity_threshold: self._set_leaf_by_threshold_preorder(module, zero_module_granularity_threshold) - print_rank_0( - f"z3_leaf_module was set by stage3_module_granularity_threshold:{zero_module_granularity_threshold}", - force=True) + utils.logger.info( + f"z3_leaf_module was set by stage3_module_granularity_threshold:{zero_module_granularity_threshold}") for layer in self.z3_leaf_layers: print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True) else: utils.logger.warning( f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\ - f"To make stage3_module_granularity_threshold effective, you need to set stage3_module_granularity_threshold >= {self.min_granularity_value}" + f"To make stage3_module_granularity_threshold effective, you need to set stage3_module_granularity_threshold >= {self.min_granularity_value}. "\ + f"Current Value:{zero_module_granularity_threshold}" ) - print_rank_0(f"{'MODULE NAME'.ljust(30)}|{'GRANULARITY VALUE'.rjust(20)}", force=True) - for granularity in self.granularity_info: - print_rank_0(granularity, force=True) def _get_granularity_recursively(self, module): """This function is used to recursively obtain the granularity of each module.""" From a6e5a399d6ebe6370c7a5b8f17577ea46cfd551c Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 7 Nov 2024 04:10:23 +0000 Subject: [PATCH 20/21] update --- deepspeed/runtime/zero/parameter_offload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 42982cf38c22..082d7e874e4d 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -157,7 +157,7 @@ def __init__( zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, ) - if zero_module_granularity_threshold >= 0: + if zero_module_granularity_threshold > 0: self.min_granularity_value = sys.maxsize self.min_granularity_layer = None self.granularity_info = set() From 25df9623cdc8ad138fb5221f6a4d779a2ab2ffae Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 11 Nov 2024 02:57:03 +0000 Subject: [PATCH 21/21] use mark parametrize for test --- .../runtime/zero/test_zero_leaf_module.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index a1eacabc44ba..74c709883645 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -3,6 +3,7 @@ # DeepSpeed Team +import pytest import deepspeed.comm as dist import torch @@ -190,14 +191,14 @@ def test_set_no_match_class(self): pass +@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000]) class TestZ3LeafOptimization(DistributedTest): world_size = 2 reuse_dist_env = True - def test_finegrained_optimization(self): + def test_finegrained_optimization(self, module_granularity_threshold: int): hidden_dim = 128 num_block = 16 - stage3_module_granularity_threshold_list = [0, 100, 12100, 10000000] config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -248,19 +249,14 @@ def bench_loss_and_time(config): model.destroy() return loss_list, duration - result_loss_list = [] - result_duration = [] - baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict) - for threshold in stage3_module_granularity_threshold_list: - config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = threshold - loss_list, duration = bench_loss_and_time(config_dict) - result_duration.append(duration) - result_loss_list.append(loss_list) + config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = module_granularity_threshold + loss, duration = bench_loss_and_time(config_dict) + if dist.get_rank() == 0: print(f"baseline exec time:", baseline_exec_time) - for idx, threshold in enumerate(stage3_module_granularity_threshold_list): - if dist.get_rank() == 0: - print(f"finegrained optimziation exec time: {result_duration[idx]},granularity threshold:{threshold} ") - assert baseline_loss_list == result_loss_list[idx], f"incorrect loss value with threshold:{threshold}" + print( + f"finegrained optimziation exec time: {duration},granularity threshold:{module_granularity_threshold} " + ) + assert baseline_loss_list == loss, f"incorrect loss value with threshold:{module_granularity_threshold}"