[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 9, 2024
1 parent 020ed54 commit dc3f8dd
Showing 15 changed files with 82 additions and 93 deletions.
2 changes: 1 addition & 1 deletion colossalai/nn/optimizer/README.md
@@ -89,7 +89,7 @@ A series of optimizers have been optimized and integrated.

### Distributed Adafactor

Distributed Adafactor supports tensor parallelism and ZerO optimization.
Distributed Adafactor supports tensor parallelism and ZerO optimization.

### Performance
| Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate |
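The README line above says what the distributed optimizer is for. As a rough illustration of how such an optimizer is typically wired up, here is a minimal sketch: the class name and import path are assumptions (only the file name distributed_adafactor.py and the setup_distributed parameters are visible in this commit), and the model, process groups, and shard mapping are placeholders you would obtain from your own parallel setup.

import torch.distributed as dist
# Assumed import path and class name, taken from the file name in this diff; not a verified public API.
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor

optim = DistributedAdaFactor(model.parameters())   # model: your sharded module (placeholder)
optim.setup_distributed(
    tensor_parallel_group=tp_group,    # placeholder TP process group
    data_parallel_group=dp_group,      # placeholder DP process group, used when ZeRO shards optimizer states
    shard_to_param=shard_to_param,     # placeholder mapping from sharded param id to the original DTensor
    use_zero=True,
)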
2 changes: 1 addition & 1 deletion colossalai/nn/optimizer/adafactor.py
@@ -36,7 +36,7 @@ def __init__(
relative_step=True,
warmup_init=False,
):
lr=None
lr = None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
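For context on why the hunk above pins lr = None: with relative_step=True, Adafactor derives its step size from the current step count and the parameter's RMS rather than from a user-supplied learning rate, which is why combining a manual lr with relative_step is rejected. A hedged sketch of that schedule, assuming it mirrors the upstream Hugging Face implementation this file appears to be derived from (the constants are that implementation's defaults, not taken from this diff):

import math

def relative_step_lr(step: int, warmup_init: bool = False,
                     scale_parameter: bool = True, param_rms: float = 1.0) -> float:
    # Slow linear warmup if requested, otherwise cap the step size at 1e-2.
    min_step = 1e-6 * step if warmup_init else 1e-2
    # The relative step decays roughly as 1/sqrt(t).
    rel_step = min(min_step, 1.0 / math.sqrt(step))
    # Optionally scale by the parameter's RMS, floored at 1e-3.
    param_scale = max(1e-3, param_rms) if scale_parameter else 1.0
    return param_scale * rel_step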
39 changes: 18 additions & 21 deletions colossalai/nn/optimizer/distributed_adafactor.py
@@ -3,9 +3,9 @@

import torch
import torch.distributed as dist

# from torch.optim import Optimizer
from colossalai.interface.optimizer import DistributedOptim

from colossalai.shardformer.layer._operation import _gather, _split
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor

@@ -27,7 +27,7 @@ def __init__(
relative_step=True,
warmup_init=False,
):
lr=None
lr = None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
@@ -55,7 +55,6 @@ def __init__(
self.use_first_moment = None # bool
self.use_zero = True
super().__init__(params, defaults)


def setup_distributed(
self,
@@ -82,9 +81,8 @@ def setup_distributed(
if self.data_parallel_group is not None:
self.data_parallel_size = dist.get_world_size(self.data_parallel_group)
self.use_zero = use_zero

self.shard_to_param = shard_to_param if shard_to_param is not None else {}


@staticmethod
def _get_lr(param_group, param_state):
@@ -164,9 +162,9 @@ def step(self, closure=None):
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
self.grad_shape = grad.shape # 1 dim shape

# print(f"self.shard_to_param {self.shard_to_param}")

param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p)))

if param_is_dtensor:
@@ -192,13 +190,13 @@ def step(self, closure=None):
# Row Residual situation
if self.grad_shape[0] % self.data_parallel_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
self.grad_shape[0], device=p.device, dtype=p.dtype
self.grad_shape[0], device=p.device, dtype=p.dtype
) # [H/dp/Tp]
else:
state["exp_avg_sq_row"] = torch.zeros(
self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp/Tp]

state["exp_avg_sq_col"] = torch.zeros(
self.grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
@@ -248,16 +246,12 @@ def step(self, closure=None):
# Row Residual situation
if self.grad_shape[0] % self.data_parallel_size != 0:
# gather update[flatten] along dp group then reshape to [H/tp, W]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, self.grad_shape[1])

# gather grad[flatten] along dp group then reshape to [H/tp, W]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
grad_reshape = grad.view(-1, self.grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
@@ -269,7 +263,9 @@ def step(self, closure=None):
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group)
update = _split(
input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group
)
else:
update = update_reshape
else:
@@ -287,7 +283,9 @@ def step(self, closure=None):
input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group
)
sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
update_reshape = self._approx_sq_grad_row_parallel(
exp_avg_sq_row, exp_avg_sq_col, sq_row_meam
)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = update_reshape.view(-1)
@@ -308,8 +306,7 @@ def step(self, closure=None):

if group["weight_decay"] != 0:
p.add_(p, alpha=(-group["weight_decay"] * lr))

p.add_(-update)

p.add_(-update)

return loss
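The step() hunks above repeatedly call self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) and a row-parallel variant. For readers following the diff, here is a standalone sketch of what that factored second-moment approximation computes, based on the standard Adafactor formulation rather than copied from the ColossalAI source; the distributed variant touched by this commit additionally gathers the row statistics across the tensor-parallel group before taking the mean.

import torch

def approx_sq_grad(exp_avg_sq_row: torch.Tensor, exp_avg_sq_col: torch.Tensor) -> torch.Tensor:
    # Adafactor keeps per-row [H] and per-column [W] EMAs of grad**2 instead of a full [H, W] second moment.
    # Normalize the rows by their mean so the rank-1 reconstruction has the right overall scale.
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)  # [H, 1]
    c_factor = exp_avg_sq_col.rsqrt().unsqueeze(-2)  # [1, W]
    # The result is multiplied elementwise with the gradient to form the update.
    return r_factor * c_factor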
1 change: 0 additions & 1 deletion colossalai/shardformer/modeling/gpt2.py
@@ -1084,7 +1084,6 @@ def forward(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)


if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)

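The hunk above feeds tensor-parallel-sharded logits straight into cross_entropy_1d instead of gathering them first (gathering only happens when parallel_output is off). As a conceptual illustration of how a vocab-parallel cross entropy of this kind works, here is a Megatron-style sketch; it is not ColossalAI's cross_entropy_1d implementation and omits details such as ignore_index handling.

import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy(logits_shard, labels, vocab_start, vocab_end, tp_group):
    # Each rank holds logits for its vocabulary slice, shape [N, V/tp]; labels are full-vocab ids.
    # Stabilize with the max over the *global* vocabulary.
    logits_max = logits_shard.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=tp_group)
    shifted = logits_shard - logits_max.unsqueeze(-1)

    # Softmax denominator: sum of exp over all shards.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)

    # Numerator: the target logit lives on exactly one shard; the others contribute zero.
    in_shard = (labels >= vocab_start) & (labels < vocab_end)
    local_idx = (labels - vocab_start).clamp(min=0, max=logits_shard.size(-1) - 1)
    target = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target = torch.where(in_shard, target, torch.zeros_like(target))
    dist.all_reduce(target, op=dist.ReduceOp.SUM, group=tp_group)

    # Negative log-likelihood, averaged over tokens.
    return (sum_exp.log() - target).mean()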
8 changes: 6 additions & 2 deletions colossalai/shardformer/policies/gpt2.py
@@ -269,13 +269,17 @@ def module_policy(self):
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
)
],
)
}
if self.shard_config.parallel_output:
addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
addon_module[GPT2LMHeadModel].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
module_policy.update(addon_module)

if self.pipeline_stage_manager is not None:
10 changes: 8 additions & 2 deletions colossalai/shardformer/policies/llama.py
@@ -255,12 +255,18 @@ def module_policy(self):
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output})
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
)
],
)
}
if self.shard_config.parallel_output:
new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
policy.update(new_item)

if self.pipeline_stage_manager:
4 changes: 1 addition & 3 deletions examples/images/vit/vit_benchmark.py
@@ -119,9 +119,7 @@ def criterion(outputs, inputs):
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
# run pipeline forward backward
batch = iter([batch])
outputs = booster.execute_pipeline(
batch, model, criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(batch, model, criterion, optimizer, return_loss=True)
else:
outputs = model(**batch)
loss = criterion(outputs, None)
4 changes: 1 addition & 3 deletions examples/language/llama2/finetune.py
@@ -270,9 +270,7 @@ def main():
) as pbar:
for step in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(
dataloader_iter, model, _criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
loss = outputs["loss"]
else:
batch = next(dataloader_iter)
4 changes: 1 addition & 3 deletions examples/language/llama2/pretrain.py
@@ -285,9 +285,7 @@ def main():
) as pbar:
for step in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(
dataloader_iter, model, _criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
loss = outputs["loss"]
else:
batch = next(dataloader_iter)
4 changes: 1 addition & 3 deletions examples/language/opt/opt_train_demo.py
@@ -41,9 +41,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b
# Forward pass
for _ in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(
dataloader, model, _criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(dataloader, model, _criterion, optimizer, return_loss=True)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
@@ -74,9 +74,7 @@ def _preprocess_data(data):
data = data_gen_fn()
model.train()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(
_preprocess_data(data), model, _criterion, optimizer, return_loss=True
)
booster.execute_pipeline(_preprocess_data(data), model, _criterion, optimizer, return_loss=True)
else:
output = model(**_preprocess_data(data))
loss = criterion(output)
@@ -108,9 +106,7 @@ def _preprocess_data(data):
data_for_shard = data_gen_fn()
data_for_origin = data_gen_fn()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True
)
booster.execute_pipeline(_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True)
booster.execute_pipeline(
_preprocess_data(data_for_origin),
new_model,