Commit 4a1dd4c

Merge branch 'master' into onebit
tjruwase authored May 7, 2024
2 parents 3073f9a + 0fc19b6 commit 4a1dd4c
Showing 8 changed files with 19 additions and 15 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/nv-torch-latest-v100.yml
@@ -29,7 +29,7 @@ jobs:
       - name: Install pytorch
         run: |
-          pip install -U --cache-dir $TORCH_CACHE torch==2.2.2 torchvision --index-url https://download.pytorch.org/whl/cu118
+          pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118
           python -c "import torch; print('torch:', torch.__version__, torch)"
           python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
@@ -55,5 +55,5 @@ jobs:
         run: |
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.2" --cuda_ver="11.8"
-          pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.2" --cuda_ver="11.8"
+          pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.3" --cuda_ver="11.8"
+          pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.3" --cuda_ver="11.8"
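
Since the install step above no longer pins torch, the resolved wheel has to line up with the --torch_ver="2.3" flag the test step passes. A quick sanity check, as a hedged sketch rather than part of the workflow:

    # Print the resolved torch build and its CUDA runtime; with the cu118
    # index URL above, a 2.3.x build against CUDA 11.8 is expected.
    import torch
    print(torch.__version__, torch.version.cuda)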
2 changes: 1 addition & 1 deletion deepspeed/linear/optimized_linear.py
@@ -85,7 +85,7 @@ def __init__(self,
         self.bias = bias
         self.lora_config = lora_config
         self.quantization_config = quantization_config
-        device = get_accelerator().current_device() if device is None else device
+        device = get_accelerator().current_device_name() if device is None else device
         assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"

         self.zero_shards = self.lora_config.base_weight_sharding
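
This current_device() -> current_device_name() swap recurs across the commit. A minimal sketch of the distinction, assuming DeepSpeed's accelerator abstraction behaves as the names suggest (integer index vs. device string):

    import torch
    from deepspeed.accelerator import get_accelerator

    # current_device() returns an integer index, which torch interprets
    # as a CUDA ordinal; current_device_name() returns a string such as
    # "cuda:0" (or "xpu:0", "hpu:0", ...), which remains correct on
    # non-CUDA accelerators.
    device = get_accelerator().current_device_name()
    t = torch.zeros(4, device=device)
    print(t.device)  # e.g. cuda:0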
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/fused_optimizer.py
@@ -241,7 +241,7 @@ def _get_norm_mask_idx(self, group):
             group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
             grad_flat_st_idx = grad_flat_en_idx

-        return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device())
+        return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name())

     def step(self, closure=None):
         """
8 changes: 5 additions & 3 deletions deepspeed/runtime/utils.py
@@ -171,7 +171,7 @@ def get_norm_with_moe_layers_fast(all_groups_norm, group):
     # This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'.
     # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
     scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group))
-    scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float)
+    scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float)
     dist.all_reduce(scaled_norm_tensor, group=group)
     all_groups_norm = scaled_norm_tensor.item()
     #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
@@ -424,9 +424,11 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None):
             # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
             # # for mask_idx in grad_norm_mask[idx]:
             # #     mask_tensor_[mask_idx[0]:mask_idx[1]] = True
-            cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
+            cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(),
                                          dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
-            mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
+            mask_tensor = torch.zeros(p.shape[0] + 1,
+                                      device=get_accelerator().current_device_name(),
+                                      dtype=p.dtype)
             mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
                                                cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
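
The scatter/cumsum above vectorizes the commented-out Python loop it replaces. A standalone sketch of the trick with made-up shapes: put +1 at each range start and -1 at each range end, and a running sum is positive exactly inside each [start, end) pair.

    import torch

    pairs = torch.tensor([[1, 3], [5, 6]])  # two [start, end) index ranges
    numel = 8                               # flattened parameter length
    cum_sum_pairs = torch.tensor([1, -1]).repeat(pairs.shape[0], 1)
    mask = torch.zeros(numel + 1, dtype=torch.long)
    # +1 at starts, -1 at ends; the cumsum is > 0 only inside the ranges.
    mask = mask.scatter_(0, pairs.view(-1), cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
    print(mask)  # tensor([False,  True,  True, False, False,  True, False, False])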

2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -1409,7 +1409,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
         norm_is_nan = total_norm.isnan()
         inf_or_nan = norm_is_nan.logical_or(norm_is_inf)

-        err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
         total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

         return total_norm
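
The one-line fix keeps the -1.0 sentinel on the same device as the mask it multiplies; with CPU offload, self.device and the device total_norm actually lives on can differ. A minimal sketch of the masked blend for the finite-norm case:

    import torch

    total_norm = torch.tensor(2.5)  # stand-in gradient norm
    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
    # Allocate the sentinel where the mask lives, not on a configured
    # device, so the arithmetic below never mixes devices.
    err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
    total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
    print(total_norm)  # tensor(2.5000)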
6 changes: 3 additions & 3 deletions setup.py
@@ -219,9 +219,9 @@ def create_dir_symlink(src, dest):
 if sys.platform == "win32":
     # This creates a symbolic links on Windows.
     # It needs Administrator privilege to create symlinks on Windows.
-    create_dir_symlink('..\\..\\csrc', '.\\deepspeed\\ops\\csrc')
-    create_dir_symlink('..\\..\\op_builder', '.\\deepspeed\\ops\\op_builder')
-    create_dir_symlink('..\\accelerator', '.\\deepspeed\\accelerator')
+    create_dir_symlink('.\\deepspeed\\ops\\csrc', '..\\..\\csrc')
+    create_dir_symlink('.\\deepspeed\\ops\\op_builder', '..\\..\\op_builder')
+    create_dir_symlink('.\\deepspeed\\accelerator', '..\\accelerator')
     egg_info.manifest_maker.template = 'MANIFEST_win.in'

# Parse the DeepSpeed version string from version.txt.
2 changes: 1 addition & 1 deletion tests/unit/moe/test_moe.py
@@ -177,7 +177,7 @@ class TestTopk(DistributedTest):
     world_size = 2

     def test(self):
-        device = get_accelerator().current_device()
+        device = get_accelerator().current_device_name()
         if dist.get_rank() == 0:
             logits = torch.rand(2, 2, device=device)
         elif dist.get_rank() == 1:
6 changes: 4 additions & 2 deletions tests/unit/runtime/compile/test_compile_zero.py
@@ -12,7 +12,7 @@

 from unit.runtime.compile.util import compare_loss
 from unit.common import DistributedTest
-from unit.util import bf16_required_version_check
+from unit.util import bf16_required_version_check, skip_on_arch

 pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1),
                                 reason="Compile tests requires Pytorch version 2.1 or above")
@@ -26,9 +26,11 @@ class TestZeRO(DistributedTest):
     @pytest.mark.parametrize('zero_stage', [1, 2, 3])
     @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme])
     def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
+        if dtype == torch.bfloat16:
+            skip_on_arch(min_arch=8)
         if dtype == torch.bfloat16 and not bf16_required_version_check():
             pytest.skip(
-                " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
+                "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly"
             )
         if get_accelerator().device_name() == "cpu":
             pytest.skip("CPU does not support this test yet")
