Revert "stage3: efficient compute of scaled_global_grad_norm (#5256)" (
Browse files Browse the repository at this point in the history
…#5461)

This reverts commit 54c0687, since #5256 causes bugs when the ZeRO3 + ZeRO Offload features are enabled together.

This bug was discovered through failures in the DS Chat CI workflow.
The failing tests are:
| Failing Test Name |
| --- |
| test_ds_chat[zero3--offload-] |
| test_ds_chat[zero3--offload-lora] |
| test_ds_chat[zero3-he-offload-] |
| test_ds_chat[zero3-he-offload-lora] |

Error message:
```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cpu!
```

It seems that `torch.stack()` or `torch.norm()` has issues when the offload
feature is enabled and tensors are split between CPU and GPU; however, this is
just an initial guess and would require more investigation.
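
For illustration, here is a minimal sketch of the suspected failure mode. The norm values and device placements below are hypothetical and are not taken from the actual ZeRO-3 offload code path; the sketch assumes a machine with at least one CUDA device:

```python
import torch

# Hypothetical per-group gradient norms: with offload enabled, some norms may
# end up on the GPU while others stay on the CPU (placements assumed here for
# illustration only, not taken from the real code path).
norm_groups = [
    torch.tensor(1.5, device="cuda:0"),  # norm living on the GPU
    torch.tensor(0.7, device="cpu"),     # norm kept on the CPU
]

# torch.stack() requires all inputs to be on the same device, so this raises:
# RuntimeError: Expected all tensors to be on the same device, ...
try:
    torch.norm(torch.stack(norm_groups))
except RuntimeError as e:
    print(e)
```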

@nelyahu Since you are the original author of the PR, if you have some
bandwidth, any help here is greatly appreciated!

After reverting this commit, all tests pass in the DS Chat CI workflow:

https://github.com/microsoft/DeepSpeed/actions/runs/8824064414/job/24225802763

@tjruwase for context.
lekurile authored Apr 25, 2024
1 parent fcc731f commit bc48371
Showing 1 changed file with 2 additions and 2 deletions.
deepspeed/runtime/zero/stage3.py (2 additions, 2 deletions)

```diff
@@ -15,7 +15,7 @@
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -2027,7 +2027,7 @@ def step(self, closure=None):
             return

         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = torch.norm(torch.stack(norm_groups))
+        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)

         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
```
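
For context, the reverted-to `get_global_norm(norm_list=...)` combines the per-group norms without stacking them into a single tensor, which is why it is unaffected by mixed CPU/GPU placement. Below is a rough sketch of that kind of reduction; it is an approximation based on the function's name and its usage here, not the library's exact code:

```python
import math
import torch

def global_norm_sketch(norm_list):
    """Approximate the reduction performed by get_global_norm:
    combine per-group norms as sqrt(sum(norm_i ** 2)).

    Each norm is converted to a Python float first, so it does not matter
    whether the individual norms live on the CPU or the GPU.
    """
    total = 0.0
    for norm in norm_list:
        total += float(norm) ** 2  # works for scalar tensors and plain floats
    return math.sqrt(total)

# Mixed-device norms (hypothetical values; requires a CUDA device) no longer
# trigger a device-mismatch error:
norms = [torch.tensor(1.5, device="cuda:0"), torch.tensor(0.7)]
print(global_norm_sketch(norms))  # ~1.655
```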
