Commit

[feature] add API Reference & Sample to optimizer Readme; add state check for bert exam;
duanjunwen committed Apr 10, 2024
1 parent efac2a1 commit 40a5528
Showing 4 changed files with 36 additions and 9 deletions.
26 changes: 25 additions & 1 deletion colossalai/nn/optimizer/README.md
@@ -91,7 +91,31 @@ A series of optimizers have been optimized and integrated.

Distributed Adafactor is an optimizer that supports hybrid optimization, including 1D tensor parallelism as well as ZeRO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces the memory pressure on each individual device. It has a wide range of applications and currently supports a range of Transformer-based models; see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details.

### Distributed Adafactor API
### API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}

### Sample: Init with booster

```python
# Imports required to run this sample (module paths follow the API reference above).
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor

# ==============================
# Model Init
# ==============================
tp_model = TPModel()  # TPModel: a user-defined tensor-parallel model (placeholder)

# ==============================
# Optimizer Init
# ==============================
dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()])

# ==============================
# Booster Init
# ==============================
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
criterion = lambda x: x.mean()
tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)
```
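
A minimal sketch of the training step that follows, assuming the boosted objects above and a placeholder input batch (`x`, its shape, and the `.cuda()` call are illustrative assumptions); the `booster.backward` and `step`/`zero_grad` calls mirror the usage shown in `docs/source/en/features/distributed_adafactor.md`.

```python
import torch

# One training step with the boosted model and optimizer (input is a placeholder).
x = torch.randn(2, 4, 4).cuda()      # hypothetical batch; shape is illustrative only
out = tp_model(x)                    # forward through the boosted TP model
loss = criterion(out)                # scalar loss from the lambda above
booster.backward(loss, dist_optim)   # backward through the booster
dist_optim.step()                    # distributed Adafactor update
dist_optim.zero_grad()
```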

### Performance
| Version | iter | Float Precision | Device Nums | weight shape | Avg runtime (ms) | Avg Speed Up Rate | Best Speed Up Rate |
9 changes: 3 additions & 6 deletions colossalai/nn/optimizer/distributed_adafactor.py
@@ -27,7 +27,7 @@ def __init__(
relative_step=True,
warmup_init=False,
):
lr=None
lr = None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
@@ -162,17 +162,14 @@ def step(self, closure=None):
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")

state = self.state[p]
self.grad_shape = grad.shape # 1 dim shape

# print(f"self.shard_to_param {self.shard_to_param}")

param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p)))

if param_is_dtensor:
self.grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim)

self.factored, self.use_first_moment = self._get_options(group, self.grad_shape)

if len(state) == 0:
state["step"] = 0
if self.use_first_moment:
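
For context on the hunk above, which branches on `self.factored` and `self.use_first_moment` before initializing `state`, here is a small sketch of the usual Adafactor state layout. It follows the common Adafactor convention (factor the second moment only for tensors with two or more dimensions, keep a first moment only when beta1 is set) and is an assumption for illustration, not a copy of the ColossalAI implementation.

```python
import torch

def get_options(beta1, grad_shape):
    factored = len(grad_shape) >= 2          # factor only matrices and higher-rank tensors
    use_first_moment = beta1 is not None     # first moment only when beta1 is set
    return factored, use_first_moment

def init_state(grad, factored, use_first_moment):
    state = {"step": 0}
    if use_first_moment:
        state["exp_avg"] = torch.zeros_like(grad)
    if factored:
        # Row/column statistics replace the full second-moment tensor.
        state["exp_avg_sq_row"] = torch.zeros(grad.shape[:-1]).to(grad)
        state["exp_avg_sq_col"] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
    else:
        state["exp_avg_sq"] = torch.zeros_like(grad)
    return state
```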
2 changes: 1 addition & 1 deletion docs/source/en/features/distributed_adafactor.md
@@ -135,4 +135,4 @@ else:
dist_optim.step()
dist_optim.zero_grad()
```

<!-- doc-test-command: echo -->
8 changes: 7 additions & 1 deletion tests/test_optimizer/test_dist_adafactor.py
@@ -31,14 +31,15 @@
from colossalai.utils import set_seed
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
from tests.test_optimizer._utils import run_bert_test
from tests.test_optimizer._utils import run_bert_test, check_optim_states
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)


HEIGHT = 4
WIDTH = 4
_TP_SPEC = DimSpec([0])
@@ -679,6 +680,11 @@ def exam_bert_test(test_config):
atol, rtol = 5e-4, 5e-4
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)

# check optim states
check_optim_states(org_optimizer, sharded_optimizer.optim)


clear_layout_converter()
torch.cuda.empty_cache()

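
The new `check_optim_states(org_optimizer, sharded_optimizer.optim)` call above compares optimizer states between the original and sharded optimizers. As a hypothetical sketch of the kind of comparison such a helper can perform (the real `tests.test_optimizer._utils.check_optim_states` may gather sharded states first and use different tolerances):

```python
import torch

def naive_check_optim_states(org_optim, sharded_optim, atol=5e-4, rtol=5e-4):
    # Walk matching param groups and compare state tensors of equal shape.
    for org_group, shard_group in zip(org_optim.param_groups, sharded_optim.param_groups):
        for org_p, shard_p in zip(org_group["params"], shard_group["params"]):
            org_state = org_optim.state[org_p]
            shard_state = sharded_optim.state[shard_p]
            for key, org_val in org_state.items():
                if isinstance(org_val, torch.Tensor) and org_val.shape == shard_state[key].shape:
                    # Only unsharded states can be compared element-wise in this naive form.
                    torch.testing.assert_close(org_val, shard_state[key], atol=atol, rtol=rtol)
```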
