Enhance testing: Skip fused_optimizer tests if not supported. (#5159)

Added a condition check to skip fused_optimizer tests when FusedAdam and
FusedLamb are not supported by the accelerator. This ensures the tests are
skipped on hardware configurations that do not provide these fused
optimizers, preventing failures on unsupported platforms.
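
As an illustration, here is a minimal sketch of the per-test guard this commit
applies (the test name and body are hypothetical placeholders; the
compatibility lookups and skip reasons mirror the decorators added in the
diff below):

import pytest
import deepspeed
from deepspeed.ops.op_builder import FusedAdamBuilder, FusedLambBuilder

# Skip an individual test when the accelerator cannot build the fused kernels.
@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME],
                    reason="FusedAdam is not compatible")
@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
                    reason="lamb is not compatible")
def test_fused_optimizer_step():
    # Hypothetical placeholder; the real tests build a DeepSpeed engine
    # configured with FusedAdam or FusedLamb and run training steps.
    ...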

Details:
- Introduced a compatibility check against deepspeed.ops.__compatible_ops__
to determine support for FusedAdam and FusedLamb.
- If the ops are not supported, the fused_optimizer tests are skipped,
keeping the suite reliable (a module-level variant of the skip is sketched
after this list).
- Improved compatibility and stability across different hardware
configurations.
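
A module-wide requirement can be expressed the same way with a module-level
skip, as in the fp16 check added to test_fp16.py in this commit (a minimal
sketch; placed at import time so the whole file is skipped):

import pytest
import torch
from deepspeed.accelerator import get_accelerator

# Skip every test in this module if the accelerator does not support fp16.
if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}",
                allow_module_level=True)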

---------

Co-authored-by: Logan Adams <[email protected]>
vshekhawat-hlab and loadams authored May 16, 2024
1 parent 23173fa commit 7f55b20
Showing 6 changed files with 25 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tests/unit/elasticity/test_elastic.py
@@ -150,6 +150,7 @@ def test_proper_mbsz(ds_config):
class TestNonElasticBatchParams(DistributedTest):
world_size = 2

@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
def test(self):
config_dict = {
"train_batch_size": 2,
@@ -182,6 +183,7 @@ def test(self):
class TestNonElasticBatchParamsWithOverride(DistributedTest):
world_size = 2

@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
def test(self):
if not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
@@ -215,6 +217,7 @@ def test(self):
class TestElasticConfigChanged(DistributedTest):
world_size = 2

@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
def test(self):
config_dict = {
"train_batch_size": 2,
4 changes: 3 additions & 1 deletion tests/unit/ops/adam/test_cpu_adam.py
@@ -11,7 +11,7 @@
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.op_builder import CPUAdamBuilder
from deepspeed.ops.op_builder import CPUAdamBuilder, FusedAdamBuilder
from unit.common import DistributedTest

if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
@@ -62,6 +62,8 @@ class TestCPUAdam(DistributedTest):
set_dist_env = False

@pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.")
@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME],
reason="FusedAdam is not compatible")
def test_fused_adam_equal(self, dtype, model_size):
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
4 changes: 3 additions & 1 deletion tests/unit/ops/adam/test_hybrid_adam.py
@@ -12,7 +12,7 @@
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
from deepspeed.ops.op_builder import CPUAdamBuilder
from deepspeed.ops.op_builder import CPUAdamBuilder, FusedAdamBuilder
from unit.common import DistributedTest

if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
@@ -43,6 +43,8 @@ class TestHybridAdam(DistributedTest):
set_dist_env = False

@pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.")
@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME],
reason="FusedAdam is not compatible")
def test_hybrid_adam_equal(self, dtype, model_size):
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
2 changes: 2 additions & 0 deletions tests/unit/runtime/half_precision/test_dynamic_loss_scale.py
@@ -10,6 +10,7 @@
import numpy as np
from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from deepspeed.ops.op_builder import FusedLambBuilder


def run_model_step(model, gradient_list):
@@ -152,6 +153,7 @@ def test_some_overflow(self):
assert optim.cur_iter == expected_iteration


@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
class TestUnfused(DistributedTest):
world_size = 1

9 changes: 8 additions & 1 deletion tests/unit/runtime/half_precision/test_fp16.py
@@ -12,7 +12,7 @@
from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader, SimpleMoEModel, sequence_dataloader
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import CPUAdamBuilder
from deepspeed.ops.op_builder import CPUAdamBuilder, FusedLambBuilder
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

try:
@@ -22,7 +22,11 @@
_amp_available = False
amp_available = pytest.mark.skipif(not _amp_available, reason="apex/amp is not installed")

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)


@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
class TestLambFP32GradClip(DistributedTest):
world_size = 2

@@ -55,6 +59,7 @@ def test(self):
model.step()


@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
class TestLambFP16(DistributedTest):
world_size = 2

@@ -231,6 +236,7 @@ def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True)
engine.backward(loss)
engine.step()

@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
@pytest.mark.parametrize("fused_lamb_legacy", [(False), (True)])
def test_lamb_gradnorm(self, monkeypatch, fused_lamb_legacy: bool):
if not get_accelerator().is_fp16_supported():
@@ -495,6 +501,7 @@ def test_adam_basic(self):
model.backward(loss)
model.step()

@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible")
def test_lamb_basic(self):
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
6 changes: 6 additions & 0 deletions tests/unit/runtime/test_ds_initialize.py
@@ -20,6 +20,7 @@
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import FusedAdamBuilder


@pytest.mark.parametrize('zero_stage', [0, 3])
@@ -67,6 +68,9 @@ def test(self, optimizer_type):
def _optimizer_callable(params) -> Optimizer:
return AdamW(params=params)

if (optimizer_type is None) and (not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]):
pytest.skip("FusedAdam is not compatible")

hidden_dim = 10
model = SimpleModel(hidden_dim)

@@ -95,6 +99,8 @@ def _optimizer_callable(params) -> Optimizer:
class TestConfigOptimizer(DistributedTest):
world_size = 1

@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME],
reason="FusedAdam is not compatible")
def test(self, client_parameters):
ds_config = {"train_batch_size": 1, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}}

