From f77ffc2c6b7485f01193148e8758b77386e89a6b Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Mon, 8 Apr 2024 17:35:01 +0800
Subject: [PATCH 01/35] [feature] solve conflict; update optimizer readme;

---
 colossalai/nn/optimizer/README.md             |  44 ++
 colossalai/nn/optimizer/adafactor.py          | 208 ++++++
 .../nn/optimizer/distributed_adafactor.py     | 275 ++++++++
 tests/test_optimizer/test_dist_adafactor.py   | 594 ++++++++++++++++++
 4 files changed, 1121 insertions(+)
 create mode 100644 colossalai/nn/optimizer/adafactor.py
 create mode 100644 colossalai/nn/optimizer/distributed_adafactor.py
 create mode 100644 tests/test_optimizer/test_dist_adafactor.py

diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md
index d3f8badc7313..b846f1cf55cb 100644
--- a/colossalai/nn/optimizer/README.md
+++ b/colossalai/nn/optimizer/README.md
@@ -81,3 +81,47 @@ If you wish to add an optimizer for a specific application, please follow the st
 If your PR is accepted, we may invite you to put up a tutorial or blog in [ColossalAI Documentation](https://colossalai.org/).
+
+
+## Optimizer
+
+A series of optimizers have been optimized and integrated.
+
+### Distributed Adafactor
+
+Distributed Adafactor supports tensor parallelism and ZeRO optimization. Here is a brief flowchart of how Adafactor implements tensor parallelism:
+
+![Tensor Parallel Strategy in Distributed Adafactor](adafactor_strategy.png)
+
+### Performance
+| Version | iter | Float Precision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate |
+| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: |
+| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 |
+| DistAdaFactor(Row Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 |
+| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 |
+| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 |
+| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 |
+| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 |
+| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: |
+| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 |
+| DistAdaFactor(Row Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 |
+| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 |
+| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 |
+| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 |
+| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 |
+| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | 
:---------------: | +| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | +| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | +| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py new file mode 100644 index 000000000000..c71676bf5459 --- /dev/null +++ b/colossalai/nn/optimizer/adafactor.py @@ -0,0 +1,208 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from torch.optim import Optimizer + +__all__ = ["Adafactor"] + + +# Adafactor +class Adafactor(Optimizer): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a 
single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + """ + # grad shape is same as weigh / bias + """ + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + """ + p is weight + state + {'step', + 'exp_avg_sq_row', + 'exp_avg_sq_col', + 'RMS' + } + + p is bias + state + {'step', + 'exp_avg_sq', + 'RMS' + } + """ + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + # state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + # RMS + # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + p_data_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py new file mode 100644 index 000000000000..50792ba9af0d --- /dev/null +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -0,0 +1,275 @@ +import math +from typing import Dict + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from 
colossalai.shardformer.layer._operation import _gather +from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor + +# DistributedAdaFactor (with Tensor parallel and Zero stage 2) +__all__ = ["DistributedAdaFactor"] + + +class DistributedAdaFactor(Optimizer): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + self.tensor_parallel_size = 1 + self.tensor_parallel_group = None + self.data_parallel_size = 1 + self.data_parallel_group = None + self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor} + self.shard_spec = None + self.grad_shape = None + self.factored = None # bool + self.use_first_moment = None # bool + self.use_zero = True + + def setup_distributed( + self, + tensor_parallel_group: dist.ProcessGroup = None, + data_parallel_group: dist.ProcessGroup = None, + shard_to_param: Dict = None, + use_zero: bool = True, + ) -> None: + """ + Inject features to the Optimizer + + Args: + tensor_parallel_group: The devices group for tensor parallel; + data_parallel_group: The devices group for data parallel; + sharding_spec_dict: ShardingSpecs of Each params; + param_shape: Paramater Shape of Each params; + use_zero: Whether or not to use zero; + + """ + self.tensor_parallel_group = tensor_parallel_group # "Expected row process group" + self.data_parallel_group = data_parallel_group + if self.tensor_parallel_group is not None: + self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_group) + if self.data_parallel_group is not None: + self.data_parallel_size = dist.get_world_size(self.data_parallel_group) + self.shard_to_param = shard_to_param + self.use_zero = use_zero + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + """ + Determines whether the current param is factored + Args: + param_group : param group + param_shape : Original Shape of param + + """ + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + # approx_sq_grad for row parallel weight + @staticmethod + def 
_approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): + # row_meam = sq_row_meam + r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization steps + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + state = self.state[p] + self.grad_shape = grad.shape # 1 dim shape + + param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p))) + + if param_is_dtensor: + self.grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) + + self.factored, self.use_first_moment = self._get_options(group, self.grad_shape) + if len(state) == 0: + state["step"] = 0 + if self.use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + if self.factored: + self.shard_spec = get_sharding_spec(self.shard_to_param.get(id(p))) + if self.shard_spec.sharding_sequence[0] == "R": # Col Parallel + state["exp_avg_sq_row"] = torch.zeros( + self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp] + state["exp_avg_sq_col"] = torch.zeros( + self.grad_shape[1], device=p.device, dtype=p.dtype + ) # [W/TP] + + if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel + state["exp_avg_sq_row"] = torch.zeros( + self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp/Tp] + state["exp_avg_sq_col"] = torch.zeros( + self.grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] + else: + state["exp_avg_sq"] = torch.zeros_like(p) + state["RMS"] = 0 + else: + if self.use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if self.factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p.float() + # if p.dtype in {torch.float16, torch.bfloat16}: + # p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + + if self.factored: + # ============================== + # First Dim is R, Last Dim is S{} means split dim -1 ---> + # Coloum Parallel ---> sq_row need Do (col) Reduce + # ============================== + self.shard_spec = get_sharding_spec(self.shard_to_param.get(id(p))) + if self.shard_spec.sharding_sequence[0] == "R": + update_reshape = update.view(-1, self.grad_shape[1]) + grad_reshape = grad.view(-1, self.grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group) 
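+                        # NOTE: the all-reduce above sums the per-row second-moment statistics over the
+                        # tensor-parallel group (each rank only sees its own column slice of the weight);
+                        # dividing by the TP world size below turns that sum into a mean.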
+ exp_avg_sq_row.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) + update = update_reshape.view(-1) + # ============================== + # Last Dim is R, First Dim is S{} means split dim 0 ---> + # Row Parallel ---> sq_col need Do (row) Reduce + # ============================== + elif self.shard_spec.sharding_sequence[-1] == "R": + update_reshape = update.view(-1, self.grad_shape[1]) + grad_reshape = grad.view(-1, self.grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + # gather row + exp_avg_sq_row_gather = _gather( + input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group + ) + sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) + update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) + update_reshape.mul_(grad_reshape) + # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) + update = update_reshape.view(-1) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + # (Line No.8) RMS + # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + if self.use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update).flatten() + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + return loss diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py new file mode 100644 index 000000000000..f675f42301a8 --- /dev/null +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -0,0 +1,594 @@ +import copy +import os + +import pytest +import torch +import torch.distributed as dist +from torch import nn + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.cluster import ProcessGroupMesh +from colossalai.device.device_mesh import DeviceMesh +from colossalai.nn.optimizer.adafactor import Adafactor +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer._operation import _gather +from colossalai.tensor.d_tensor import ( + distribute_tensor, + get_layout, + get_sharding_spec, + is_distributed_tensor, + shard_colwise, + shard_rowwise, +) +from colossalai.tensor.d_tensor.sharding_spec import DimSpec +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed +from colossalai.zero import LowLevelZeroOptimizer + +HEIGHT = 4096 +WIDTH = 4096 +_TP_SPEC = DimSpec([0]) + + +def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is 
torch.float32: + rtol = 1e-05 + atol = 1e-05 + elif dtype is torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) + # assert_close(tensor1, tensor2, rtol=rtol, atol=atol) + + +# setup param groups; (For zero test optim) +def setup_param_groups_zero(model: nn.Module) -> list: + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": 0.1, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +# setup param groups; (For base optim) +def setup_param_groups(model: nn.Module) -> list: + optimizer_grouped_parameters = [p for n, p in model.named_parameters()] + return optimizer_grouped_parameters + + +# setup flatten param groups, sharding spec and shape; (For dist optim) +def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: + flatten_optimizer_grouped_parameters = [] + sharding_spec = {} # {id(flatten param): get_layout(p).global_shape} + param_shape = {} # {id(flatten param): get_sharding_spec(p)} + for n, p in model.named_parameters(): + # flatten_p = copy.deepcopy(p).flatten() + flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True)) + flatten_optimizer_grouped_parameters.append(flatten_p) + if is_distributed_tensor(p): + sharding_spec[id(flatten_p)] = get_sharding_spec(p) + param_shape[id(flatten_p)] = get_layout(p).global_shape + else: + sharding_spec[id(flatten_p)] = None + param_shape[id(flatten_p)] = p.shape + # print(f"sharding_spec {sharding_spec}") + # print(f"param_shape {param_shape}") + return flatten_optimizer_grouped_parameters, sharding_spec, param_shape + + +def set_dist_grad( + dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup +) -> None: + """ + Set split grads for Tensor Parallel or ZeRO DP. + We do not need a separate treatment for ZeRO, + as the wrapper takes care of reduce-scattering grads. 
+ """ + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()): + if torch_p.grad is None: + torch_p.grad = torch.zeros_like(torch_p) + + is_distributed = hasattr(p, "dist_layout") + if is_distributed: + sharding = p.dist_layout.sharding_spec.sharding_sequence + split_dim = sharding.index(_TP_SPEC) + shape = torch_p.split(world_size, dim=split_dim)[rank].shape + + indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1)) + # Generate grads only for the correctly split chunk + torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype)) + + else: + shape = torch_p.shape + torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype) + + # avoid inconsistent grad and param dtype error + orig_p = p.data + p.data = torch_p.grad.clone().to(g_dtype) + p.grad = p.data + p.data = orig_p + + +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(HEIGHT, WIDTH) + self.linear2 = nn.Linear(WIDTH, HEIGHT) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TPModel(nn.Module): + def __init__(self, linear1, linear2, tp_group=None): + super().__init__() + self.linear1 = Linear1D_Col.from_native_module( + linear1, process_group=tp_group, gather_output=False, overlap=True + ) + self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 +def exam_dist_adafactor_base(dtype: torch.dtype): + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + tensor_parallel_size = world_size + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Base Case + # ============================== + H, W = 4096, 4096 + model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight + weight, bias = model_col.weight, model_col.bias + device_mesh = DeviceMesh( + torch.Tensor([i for i in range(world_size)]), (1, tensor_parallel_size), init_process_group=True + ) + tp_group = device_mesh.get_process_group(axis=1) + # ============================== + # Col Parallel + # ============================== + weight_col_shard = shard_colwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_col_shard_layout = get_layout(weight_col_shard) # Layout info weight_col_shard_layout.global_shape + weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard) # Shard spec + weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True)) + bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) + col_params_shape = { + id(weight_col_shard_flatten): weight_col_shard_layout.global_shape, + id(bias_col_flatten): bias.shape, + } + col_sharding_spec_dict = {id(weight_col_shard_flatten): weight_col_shard_shard_spec, id(bias_col_flatten): None} + + # ============================== + # Row Parallel + # ============================== + weight_row_shard = shard_rowwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_row_shard_layout = get_layout(weight_row_shard) # Layout info weight_row_shard_layout.global_shape + weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard) # Shard spec + weight_row_shard_flatten = nn.Parameter( + 
weight_row_shard.clone().flatten().requires_grad_(True) + ) # flatten input(not dtensor) to optimizer + bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) + row_params_shape = { + id(weight_row_shard_flatten): weight_row_shard_layout.global_shape, + id(bias_row_flatten): bias.shape, + } + row_sharding_spec_dict = {id(weight_row_shard_flatten): weight_row_shard_shard_spec, id(bias_row_flatten): None} + + # ============================== + # Init Optimizer + # ============================== + + # base + optimizer_base = Adafactor([weight, bias]) + + # col parallel + optimizer_cp = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten]) + optimizer_cp.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=None, + sharding_spec_dict=col_sharding_spec_dict, + param_shape=col_params_shape, + ) + # row parallel + optimizer_rp = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten]) + optimizer_rp.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=None, + sharding_spec_dict=row_sharding_spec_dict, + param_shape=row_params_shape, + ) + + N_STEPS = 1 + for _ in range(N_STEPS): + # base step + optimizer_base.zero_grad() + weight.grad = torch.rand_like(weight) + bias.grad = torch.rand_like(bias) + optimizer_base.step() + + # col parallel step + optimizer_cp.zero_grad() + weight_col_shard_flatten.grad = ( + distribute_tensor(weight.grad, device_mesh, weight_col_shard_shard_spec).clone().flatten() + ) + bias_col_flatten.grad = bias.grad.clone().flatten() + optimizer_cp.step() + + # row parallel step + optimizer_rp.zero_grad() + weight_row_shard_flatten.grad = ( + distribute_tensor(weight.grad, device_mesh, weight_row_shard_shard_spec).clone().flatten() + ) + bias_row_flatten.grad = bias.grad.clone().flatten() + optimizer_rp.step() + + # gather result + weight_col_gather = _gather( + input_=weight_col_shard_flatten.data.view(-1, H // tensor_parallel_size), + dim=-1, + process_group=device_mesh.get_process_group(axis=1), + ) # gather + weight_row_gather = _gather( + input_=weight_row_shard_flatten.data, dim=-1, process_group=device_mesh.get_process_group(axis=1) + ).view( + -1, W + ) # gather + + # verify + col_correct = correctness_verify(weight.data, weight_col_gather.data, dtype) + row_correct = correctness_verify(weight.data, weight_row_gather.data, dtype) + + print(f"col corrness {col_correct} row correct {row_correct}") + + +@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 +def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype): + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + tensor_parallel_size = world_size + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + device_mesh = DeviceMesh( + torch.Tensor([i for i in range(world_size)]), (1, tensor_parallel_size), init_process_group=True + ) + base_model = MlpModel().to(local_rank) + tp_model = TPModel( + copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), device_mesh.get_process_group(axis=1) + ).to(local_rank) + tp_group = device_mesh.get_process_group(axis=1) + + base_param_group = setup_param_groups(base_model) + tp_param_group, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = 
DistributedAdaFactor(tp_param_group) + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=None, + sharding_spec_dict=tp_shard_spec, + param_shape=tp_param_shape, + ) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + loss_tp = tp_model(x).sum() + loss_tp.backward() + + loss = base_model(x).sum() + loss.backward() + + base_optim.zero_grad() + dist_optim.zero_grad() + + base_optim.step() + dist_optim.step() + + for p, tp_p in zip(base_param_group, tp_param_group): + if tp_shard_spec[id(tp_p)] is not None: + if len(tp_shard_spec[id(tp_p)].sharding_sequence) >= 2: + # print(f"device {local_rank} \n tp_p shard spec {tp_shard_spec[id(tp_p)]}\n len {len(tp_shard_spec[id(tp_p)].sharding_sequence)}") + # if tp_p tp_shard_spec is col tp --> view to (-1, H // tensor_parallel_size) then gather + if tp_shard_spec[id(tp_p)].sharding_sequence[0] == "R": + tp_p = _gather( + input_=tp_p.data.view(-1, HEIGHT // tensor_parallel_size), + dim=-1, + process_group=device_mesh.get_process_group(axis=1), + ) # gather + # if tp_p tp_shard_spec is row tp --> gather then view to (-1, H // tensor_parallel_size) + else: + tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)).view( + -1, WIDTH + ) # gather + else: + # bias parallel + tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)) + # print(f"device {local_rank} \n p {p}\n tp_p {tp_p}\n") + else: + # compare p and tp no need + pass + # print(f"device {local_rank} \n p {p}\n tp_p {tp_p}\n") + correctness_verify(p.data, tp_p.data, dtype) + # print(f"correct {correctness}") + + +@parameterize("dtype", [torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, 
+ shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + + else: + # No TP bias + pass + correctness = correctness_verify(p.data, tp_p.data, dtype) + print(f"Curr Param correct {correctness}") + # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") + + + + + +@parameterize("dtype", [torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + local_rank = dist.get_rank() + use_zero = True if zero_size > 1 else False + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + 
shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Booster Init + # ============================== + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + criterion = lambda x: x.mean() + + tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + else: + # No TP bias + pass + correctness = correctness_verify(p.data, tp_p.data, dtype) + print(f"Curr Param correct {correctness}") + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + # exam_dist_adafactor_base() + # exam_dist_adafactor_fwd_bwd() + exam_dist_adafactor_zero() + # exam_dist_adafactor_booster() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_adafactor(): + spawn(run_dist, nprocs=8) + + +if __name__ == "__main__": + test_dist_adafactor() From b75ac58272e649495761dcbae3bd6e8cc7a3083c Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 8 Apr 2024 17:37:26 +0800 Subject: [PATCH 02/35] [feature] update optimize readme; --- colossalai/nn/optimizer/README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index b846f1cf55cb..07c95143c74c 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -89,9 +89,7 @@ A series of optimizers have been optimized and integrated. ### Distributed Adafactor -Distributed Adafactor supports tensor parallelism and ZerO optimization. Here is a brief flowchart of how adafactor implements the tensor parallel: - -[[Tensor Parallel Strategy in Distributed Adafactor]](adafactor_strategy.png) +Distributed Adafactor supports tensor parallelism and ZerO optimization. 
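+
+### Example Usage
+
+A minimal usage sketch, adapted from `tests/test_optimizer/test_dist_adafactor.py`. It assumes `colossalai.launch` has already initialised the distributed backend and that `model` is a tensor-parallel sharded module; the group sizes are placeholders.
+
+```python
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
+from colossalai.zero import LowLevelZeroOptimizer
+
+# Build tensor-parallel and data-parallel (ZeRO) process groups.
+tp_size, zero_size = 2, 2
+mesh = ProcessGroupMesh(tp_size, zero_size)
+tp_group = mesh.get_group_along_axis(0)
+dp_group = mesh.get_group_along_axis(1)
+
+optim = DistributedAdaFactor(model.parameters())
+
+# Optional ZeRO wrapping; the master-to-working map tells the optimizer how each
+# flattened master shard corresponds to its (possibly distributed) working parameter.
+zero_optim = LowLevelZeroOptimizer(
+    optim, overlap_communication=True, partition_grad=True, dp_process_group=dp_group
+)
+shard_to_param = zero_optim._param_store.master_to_working_param
+zero_optim.optim.setup_distributed(
+    tensor_parallel_group=tp_group,
+    data_parallel_group=dp_group,
+    shard_to_param=shard_to_param,
+    use_zero=True,
+)
+```
+
+When ZeRO is not used, `setup_distributed` is called directly on the optimizer with `use_zero=False` and a plain `{id(p): p}` mapping (see `set_master_param_to_shard_param` in the test file).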
### Performance | Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | From d5f72fef4f2a0d4a2a4fbc34fbd09145b574bf7e Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 8 Apr 2024 19:28:11 +0800 Subject: [PATCH 03/35] [fix] fix testcase; --- .../nn/optimizer/distributed_adafactor.py | 17 +++++++--- tests/test_optimizer/test_dist_adafactor.py | 33 ++++++++++++------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 50792ba9af0d..89ed71d4d799 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -53,6 +53,7 @@ def __init__( self.factored = None # bool self.use_first_moment = None # bool self.use_zero = True + self.is_dist = {} def setup_distributed( self, @@ -78,8 +79,9 @@ def setup_distributed( self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_group) if self.data_parallel_group is not None: self.data_parallel_size = dist.get_world_size(self.data_parallel_group) - self.shard_to_param = shard_to_param self.use_zero = use_zero + + self.shard_to_param = shard_to_param if shard_to_param is not None else {} @staticmethod def _get_lr(param_group, param_state): @@ -227,7 +229,10 @@ def step(self, closure=None): update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) - update = update_reshape.view(-1) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape # ============================== # Last Dim is R, First Dim is S{} means split dim 0 ---> # Row Parallel ---> sq_col need Do (row) Reduce @@ -250,7 +255,10 @@ def step(self, closure=None): update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) update_reshape.mul_(grad_reshape) # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) - update = update_reshape.view(-1) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) @@ -267,7 +275,8 @@ def step(self, closure=None): if group["weight_decay"] != 0: p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - p_data_fp32.add_(-update).flatten() + p_data_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: p.copy_(p_data_fp32) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index f675f42301a8..253b4568f48c 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist from torch import nn +from torch.testing import assert_close import colossalai from colossalai.booster import Booster @@ -27,9 +28,10 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer +from tests.test_optimizer._utils import run_bert_test -HEIGHT = 4096 -WIDTH = 4096 +HEIGHT = 4 +WIDTH = 4 _TP_SPEC = DimSpec([0]) @@ -128,6 +130,11 @@ def set_dist_grad( p.data = orig_p +def set_master_param_to_shard_param(master_param_list) -> dict: + master_param_to_shard_param ={id(p):p for p in 
master_param_list} + return master_param_to_shard_param + + class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() @@ -350,8 +357,8 @@ def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype): # print(f"correct {correctness}") -@parameterize("dtype", [torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False @@ -406,6 +413,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): use_zero=use_zero, ) else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, @@ -454,15 +462,13 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) - print(f"Curr Param correct {correctness}") + # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") + # print(f"Curr Param correct {correctness}") # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") - - - -@parameterize("dtype", [torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size local_rank = dist.get_rank() @@ -516,10 +522,11 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int use_zero=use_zero, ) else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, - shard_to_param=shard_to_param, + shard_to_param={}, use_zero=use_zero, ) @@ -582,12 +589,14 @@ def run_dist(rank, world_size, port): # exam_dist_adafactor_fwd_bwd() exam_dist_adafactor_zero() # exam_dist_adafactor_booster() + # run_bert_test(optim_class=Adafactor, sharded_optim_class=DistributedAdaFactor) + @pytest.mark.dist @rerun_if_address_is_in_use() def test_dist_adafactor(): - spawn(run_dist, nprocs=8) + spawn(run_dist, nprocs=4) if __name__ == "__main__": From 020ed547c4c846739bc52b7e9a970ab30204af82 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 9 Apr 2024 17:13:20 +0800 Subject: [PATCH 04/35] [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel); --- colossalai/nn/optimizer/adafactor.py | 1 + .../nn/optimizer/distributed_adafactor.py | 111 ++++--- tests/test_optimizer/test_dist_adafactor.py | 312 ++++++++++++------ 3 files changed, 279 insertions(+), 145 deletions(-) diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index c71676bf5459..0cedbb2512be 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -36,6 
+36,7 @@ def __init__( relative_step=True, warmup_init=False, ): + lr=None if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 89ed71d4d799..b1d313678ecb 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -3,16 +3,17 @@ import torch import torch.distributed as dist -from torch.optim import Optimizer +# from torch.optim import Optimizer +from colossalai.interface.optimizer import DistributedOptim -from colossalai.shardformer.layer._operation import _gather +from colossalai.shardformer.layer._operation import _gather, _split from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor # DistributedAdaFactor (with Tensor parallel and Zero stage 2) __all__ = ["DistributedAdaFactor"] -class DistributedAdaFactor(Optimizer): +class DistributedAdaFactor(DistributedOptim): def __init__( self, params, @@ -26,6 +27,7 @@ def __init__( relative_step=True, warmup_init=False, ): + lr=None if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: @@ -42,7 +44,6 @@ def __init__( "relative_step": relative_step, "warmup_init": warmup_init, } - super().__init__(params, defaults) self.tensor_parallel_size = 1 self.tensor_parallel_group = None self.data_parallel_size = 1 @@ -53,13 +54,14 @@ def __init__( self.factored = None # bool self.use_first_moment = None # bool self.use_zero = True - self.is_dist = {} + super().__init__(params, defaults) + def setup_distributed( self, tensor_parallel_group: dist.ProcessGroup = None, data_parallel_group: dist.ProcessGroup = None, - shard_to_param: Dict = None, + shard_to_param: Dict = {}, use_zero: bool = True, ) -> None: """ @@ -82,6 +84,7 @@ def setup_distributed( self.use_zero = use_zero self.shard_to_param = shard_to_param if shard_to_param is not None else {} + @staticmethod def _get_lr(param_group, param_state): @@ -161,7 +164,9 @@ def step(self, closure=None): raise RuntimeError("Adafactor does not support sparse gradients.") state = self.state[p] self.grad_shape = grad.shape # 1 dim shape - + + # print(f"self.shard_to_param {self.shard_to_param}") + param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p))) if param_is_dtensor: @@ -184,9 +189,16 @@ def step(self, closure=None): ) # [W/TP] if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel - state["exp_avg_sq_row"] = torch.zeros( + # Row Residual situation + if self.grad_shape[0] % self.data_parallel_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + self.grad_shape[0], device=p.device, dtype=p.dtype + ) # [H/dp/Tp] + else: + state["exp_avg_sq_row"] = torch.zeros( self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp/Tp] + ) # [H/dp/Tp] + state["exp_avg_sq_col"] = torch.zeros( self.grad_shape[1], device=p.device, dtype=p.dtype ) # [W] @@ -202,10 +214,6 @@ def step(self, closure=None): else: state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - p_data_fp32 = p.float() - # if p.dtype in {torch.float16, torch.bfloat16}: - # p_data_fp32 = p_data_fp32.float() - state["step"] += 1 lr = self._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) @@ -228,7 +236,6 @@ def step(self, closure=None): 
exp_avg_sq_row.div_(self.tensor_parallel_size) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) - # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) if self.use_zero: update = update_reshape.view(-1) else: @@ -238,27 +245,54 @@ def step(self, closure=None): # Row Parallel ---> sq_col need Do (row) Reduce # ============================== elif self.shard_spec.sharding_sequence[-1] == "R": - update_reshape = update.view(-1, self.grad_shape[1]) - grad_reshape = grad.view(-1, self.grad_shape[1]) - exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] - exp_avg_sq_col = state["exp_avg_sq_col"] # [W] - exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) - # reduce col - dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) - exp_avg_sq_col.div_(self.tensor_parallel_size) - # gather row - exp_avg_sq_row_gather = _gather( - input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group - ) - sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) - update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) - update_reshape.mul_(grad_reshape) - # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) - if self.use_zero: - update = update_reshape.view(-1) + # Row Residual situation + if self.grad_shape[0] % self.data_parallel_size != 0: + # gather update[flatten] along dp group then reshape to [H/tp, W] + update = _gather( + input_=update, dim=-1, process_group=self.data_parallel_group + ) + # view update to origin[tp] shape + update_reshape = update.view(-1, self.grad_shape[1]) + + # gather grad[flatten] along dp group then reshape to [H/tp, W] + grad = _gather( + input_=grad, dim=-1, process_group=self.data_parallel_group + ) + grad_reshape = grad.view(-1, self.grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group) + else: + update = update_reshape else: - update = update_reshape + update_reshape = update.view(-1, self.grad_shape[1]) + grad_reshape = grad.view(-1, self.grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + # gather row + exp_avg_sq_row_gather = _gather( + input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group + ) + sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) + update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = 
update_reshape.view(-1) + else: + update = update_reshape else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) @@ -273,12 +307,9 @@ def step(self, closure=None): update = exp_avg if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - - p_data_fp32.add_(-update) - + p.add_(p, alpha=(-group["weight_decay"] * lr)) + + p.add_(-update) - if p.dtype in {torch.float16, torch.bfloat16}: - p.copy_(p_data_fp32) return loss diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 253b4568f48c..f5424ee1738a 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -9,7 +9,7 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin from colossalai.cluster import ProcessGroupMesh from colossalai.device.device_mesh import DeviceMesh from colossalai.nn.optimizer.adafactor import Adafactor @@ -20,15 +20,24 @@ distribute_tensor, get_layout, get_sharding_spec, + get_device_mesh, is_distributed_tensor, shard_colwise, shard_rowwise, ) +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.tensor.d_tensor.sharding_spec import DimSpec from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer +from tests.kit.model_zoo import model_zoo from tests.test_optimizer._utils import run_bert_test +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_weight, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) HEIGHT = 4 WIDTH = 4 @@ -39,8 +48,8 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc rtol = None atol = None if dtype is torch.float32: - rtol = 1e-05 - atol = 1e-05 + rtol = 5e-04 + atol = 5e-04 elif dtype is torch.float16: rtol = 5e-2 atol = 5e-4 @@ -48,8 +57,8 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc rtol = 4e-3 atol = 4e-3 - return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) - # assert_close(tensor1, tensor2, rtol=rtol, atol=atol) + # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) + assert_close(tensor1, tensor2, rtol=rtol, atol=atol) # setup param groups; (For zero test optim) @@ -161,12 +170,18 @@ def forward(self, x): return x + + @parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 -def exam_dist_adafactor_base(dtype: torch.dtype): - local_rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) +@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + local_rank = dist.get_rank() + use_zero = True if zero_size > 1 else False + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) - tensor_parallel_size = world_size torch.set_default_dtype(dtype) set_seed(42) @@ -176,64 +191,57 @@ def exam_dist_adafactor_base(dtype: torch.dtype): H, W = 4096, 4096 model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight weight, bias = model_col.weight, model_col.bias - device_mesh = DeviceMesh( - torch.Tensor([i for i in 
range(world_size)]), (1, tensor_parallel_size), init_process_group=True - ) - tp_group = device_mesh.get_process_group(axis=1) + # ============================== # Col Parallel # ============================== - weight_col_shard = shard_colwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_col_shard = shard_colwise(weight.clone(), tp_group) weight_col_shard_layout = get_layout(weight_col_shard) # Layout info weight_col_shard_layout.global_shape weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard) # Shard spec weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True)) bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) - col_params_shape = { - id(weight_col_shard_flatten): weight_col_shard_layout.global_shape, - id(bias_col_flatten): bias.shape, - } - col_sharding_spec_dict = {id(weight_col_shard_flatten): weight_col_shard_shard_spec, id(bias_col_flatten): None} # ============================== # Row Parallel # ============================== - weight_row_shard = shard_rowwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_row_shard = shard_rowwise(weight.clone(), tp_group) weight_row_shard_layout = get_layout(weight_row_shard) # Layout info weight_row_shard_layout.global_shape weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard) # Shard spec weight_row_shard_flatten = nn.Parameter( weight_row_shard.clone().flatten().requires_grad_(True) ) # flatten input(not dtensor) to optimizer bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) - row_params_shape = { - id(weight_row_shard_flatten): weight_row_shard_layout.global_shape, - id(bias_row_flatten): bias.shape, - } - row_sharding_spec_dict = {id(weight_row_shard_flatten): weight_row_shard_shard_spec, id(bias_row_flatten): None} + + base_param_group = setup_param_groups([weight, bias]) + cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten]) + rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten]) # ============================== # Init Optimizer # ============================== # base - optimizer_base = Adafactor([weight, bias]) - - # col parallel - optimizer_cp = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten]) - optimizer_cp.setup_distributed( + optimizer_base = Adafactor(base_param_group) + cp_dist_optim = DistributedAdaFactor(cp_param_group) + rp_dist_optim = DistributedAdaFactor(rp_param_group) + + shard_to_param_cp = set_master_param_to_shard_param(cp_dist_optim) + cp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, - data_parallel_group=None, - sharding_spec_dict=col_sharding_spec_dict, - param_shape=col_params_shape, + data_parallel_group=dp_group, + shard_to_param=shard_to_param_cp, + use_zero=use_zero, ) - # row parallel - optimizer_rp = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten]) - optimizer_rp.setup_distributed( + + shard_to_param_rp = set_master_param_to_shard_param(rp_dist_optim) + rp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, - data_parallel_group=None, - sharding_spec_dict=row_sharding_spec_dict, - param_shape=row_params_shape, + data_parallel_group=dp_group, + shard_to_param=shard_to_param_rp, + use_zero=use_zero, ) + N_STEPS = 1 for _ in range(N_STEPS): # base step @@ -243,29 +251,29 @@ def exam_dist_adafactor_base(dtype: torch.dtype): optimizer_base.step() # col parallel step - optimizer_cp.zero_grad() + cp_dist_optim.zero_grad() 
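+        # Re-shard the baseline's full gradient column-wise and flatten it so that it
+        # matches the layout of the flattened column-parallel shard held by the optimizer.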
weight_col_shard_flatten.grad = ( - distribute_tensor(weight.grad, device_mesh, weight_col_shard_shard_spec).clone().flatten() + distribute_tensor(weight.grad, get_device_mesh(weight_col_shard), weight_col_shard_shard_spec).clone().flatten() ) bias_col_flatten.grad = bias.grad.clone().flatten() - optimizer_cp.step() + cp_dist_optim.step() # row parallel step - optimizer_rp.zero_grad() + rp_dist_optim.zero_grad() weight_row_shard_flatten.grad = ( - distribute_tensor(weight.grad, device_mesh, weight_row_shard_shard_spec).clone().flatten() + distribute_tensor(weight.grad, get_device_mesh(weight_row_shard), weight_row_shard_shard_spec).clone().flatten() ) bias_row_flatten.grad = bias.grad.clone().flatten() - optimizer_rp.step() + rp_dist_optim.step() # gather result weight_col_gather = _gather( - input_=weight_col_shard_flatten.data.view(-1, H // tensor_parallel_size), + input_=weight_col_shard_flatten.data.view(-1, H // tp_size), dim=-1, - process_group=device_mesh.get_process_group(axis=1), + process_group=tp_group, ) # gather weight_row_gather = _gather( - input_=weight_row_shard_flatten.data, dim=-1, process_group=device_mesh.get_process_group(axis=1) + input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group ).view( -1, W ) # gather @@ -278,91 +286,96 @@ def exam_dist_adafactor_base(dtype: torch.dtype): @parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 -def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype): - local_rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - tensor_parallel_size = world_size +@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + clear_layout_converter() + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + torch.set_default_dtype(dtype) set_seed(42) # ============================== # Model Init # ============================== - device_mesh = DeviceMesh( - torch.Tensor([i for i in range(world_size)]), (1, tensor_parallel_size), init_process_group=True - ) base_model = MlpModel().to(local_rank) - tp_model = TPModel( - copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), device_mesh.get_process_group(axis=1) - ).to(local_rank) - tp_group = device_mesh.get_process_group(axis=1) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) base_param_group = setup_param_groups(base_model) - tp_param_group, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + tp_param_group = setup_param_groups(tp_model) # ============================== # Optimizer Init # ============================== base_optim = Adafactor(base_param_group) dist_optim = DistributedAdaFactor(tp_param_group) + + shard_to_param = set_master_param_to_shard_param(tp_param_group) dist_optim.setup_distributed( tensor_parallel_group=tp_group, - data_parallel_group=None, - sharding_spec_dict=tp_shard_spec, - param_shape=tp_param_shape, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, ) - + # ============================== # Correctness Verify # ============================== x = torch.randn(HEIGHT, WIDTH, device=local_rank) - loss_tp = tp_model(x).sum() - loss_tp.backward() - - loss = 
base_model(x).sum() - loss.backward() + out = base_model(x) + out_tp = tp_model(x) - base_optim.zero_grad() - dist_optim.zero_grad() + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() base_optim.step() dist_optim.step() + base_optim.zero_grad() + dist_optim.zero_grad() + for p, tp_p in zip(base_param_group, tp_param_group): - if tp_shard_spec[id(tp_p)] is not None: - if len(tp_shard_spec[id(tp_p)].sharding_sequence) >= 2: - # print(f"device {local_rank} \n tp_p shard spec {tp_shard_spec[id(tp_p)]}\n len {len(tp_shard_spec[id(tp_p)].sharding_sequence)}") - # if tp_p tp_shard_spec is col tp --> view to (-1, H // tensor_parallel_size) then gather - if tp_shard_spec[id(tp_p)].sharding_sequence[0] == "R": - tp_p = _gather( - input_=tp_p.data.view(-1, HEIGHT // tensor_parallel_size), - dim=-1, - process_group=device_mesh.get_process_group(axis=1), - ) # gather - # if tp_p tp_shard_spec is row tp --> gather then view to (-1, H // tensor_parallel_size) - else: - tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)).view( - -1, WIDTH - ) # gather + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather else: - # bias parallel - tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)) - # print(f"device {local_rank} \n p {p}\n tp_p {tp_p}\n") + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + else: - # compare p and tp no need + # No TP bias pass - # print(f"device {local_rank} \n p {p}\n tp_p {tp_p}\n") - correctness_verify(p.data, tp_p.data, dtype) - # print(f"correct {correctness}") + correctness = correctness_verify(p.data, tp_p.data, dtype) + # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") + # print(f"Curr Param correct {correctness}") + # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") -@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False local_rank = dist.get_rank() + + clear_layout_converter() proc_mesh = ProcessGroupMesh(tp_size, zero_size) tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) @@ -442,6 +455,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): base_optim.zero_grad() dist_optim.zero_grad() + print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): 
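+        # Gather each tensor-parallel shard back to the full parameter shape before
+        # comparing it against the corresponding single-device baseline parameter.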
param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: @@ -462,17 +476,20 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) - # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") + # print(f"Curr Param correct {correctness}") - # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") + # if not correctness: + # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") -@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size local_rank = dist.get_rank() use_zero = True if zero_size > 1 else False + + clear_layout_converter() proc_mesh = ProcessGroupMesh(tp_size, zero_size) tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) @@ -526,7 +543,7 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, - shard_to_param={}, + shard_to_param=shard_to_param, use_zero=use_zero, ) @@ -559,7 +576,7 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int base_optim.zero_grad() dist_optim.zero_grad() - + print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: @@ -575,11 +592,97 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int else: # TP bias tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + else: # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) - print(f"Curr Param correct {correctness}") + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 2, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 4, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 1, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "fp16", + }, + { + "tp_size": 2, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "fp16", + }, + { + "tp_size": 4, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "fp16", + }, + { + "tp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + } + ], +) +def exam_bert_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + test_config["use_lazy_init"] = False + test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel + test_config["initial_scale"] = 2**15 # avoid overflow + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + + if name == "transformers_bert": + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, 
test_config, Adafactor, DistributedAdaFactor + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] + + org_optimizer.step() + sharded_optimizer.step() + + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 1e-4 + else: + atol, rtol = 5e-4, 5e-4 + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + clear_layout_converter() + torch.cuda.empty_cache() + def run_dist(rank, world_size, port): @@ -588,9 +691,8 @@ def run_dist(rank, world_size, port): # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() exam_dist_adafactor_zero() - # exam_dist_adafactor_booster() - # run_bert_test(optim_class=Adafactor, sharded_optim_class=DistributedAdaFactor) - + exam_bert_test() + @pytest.mark.dist From ce58adfd276d505cda9b5da50ced0b8a7808b767 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 9 Apr 2024 18:01:59 +0800 Subject: [PATCH 05/35] [feature] Add transformers_bert model zoo in testcase; --- .../nn/optimizer/distributed_adafactor.py | 2 +- tests/test_optimizer/test_dist_adafactor.py | 45 +++++++++---------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index b1d313678ecb..794dda755a16 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -189,7 +189,7 @@ def step(self, closure=None): ) # [W/TP] if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel - # Row Residual situation + # Row indivisible shape situation if self.grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( self.grad_shape[0], device=p.device, dtype=p.dtype diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index f5424ee1738a..4b7cb6c161df 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -653,35 +653,34 @@ def exam_bert_test(test_config): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == "transformers_bert": - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( - model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor - ) + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor + ) - org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster - ) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) - stage_manager = booster.plugin.stage_manager - tp_group = 
booster.plugin.tp_group + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - bert = unwrap_model(org_model, "BertModel", "bert") - sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") - weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] - org_optimizer.step() - sharded_optimizer.step() + org_optimizer.step() + sharded_optimizer.step() - # check weights - if test_config["precision"] == "bf16": - atol, rtol = 5e-4, 1e-4 - else: - atol, rtol = 5e-4, 5e-4 - if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): - check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) - clear_layout_converter() - torch.cuda.empty_cache() + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 1e-4 + else: + atol, rtol = 5e-4, 5e-4 + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + clear_layout_converter() + torch.cuda.empty_cache() From efac2a1854f611453c4390ecfc1f03af7c498f97 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 14:30:11 +0800 Subject: [PATCH 06/35] [feature] add user documentation to docs/source/feature. --- colossalai/nn/optimizer/README.md | 4 +- .../nn/optimizer/distributed_adafactor.py | 1 - .../en/features/distributed_adafactor.md | 138 ++++++++++++++++++ 3 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/features/distributed_adafactor.md diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index 07c95143c74c..14de2e541508 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -89,7 +89,9 @@ A series of optimizers have been optimized and integrated. ### Distributed Adafactor -Distributed Adafactor supports tensor parallelism and ZerO optimization. +Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. 
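+
+A minimal usage sketch is shown below. It assumes a tensor-parallel model `tp_model`, its parameter list `tp_param_group`, an input batch `x`, the process groups `tp_group`/`dp_group` (e.g. created with `ProcessGroupMesh`) and the `use_zero` flag are prepared beforehand:
+
+```python
+from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
+
+# Wrap the tensor-parallel model parameters.
+dist_optim = DistributedAdaFactor(tp_param_group)
+
+# Tell the optimizer how each parameter maps back to its distributed (sharded) tensor.
+dist_optim.setup_distributed(
+    tensor_parallel_group=tp_group,
+    data_parallel_group=dp_group,
+    shard_to_param={id(p): p for p in tp_param_group},
+    use_zero=use_zero,
+)
+
+loss = tp_model(x).sum()
+loss.backward()  # when ZeRO is enabled, dist_optim.backward(loss) is used instead
+dist_optim.step()
+dist_optim.zero_grad()
+```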
+ +### Distributed Adafactor API ### Performance | Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 794dda755a16..c73ca46d8934 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -253,7 +253,6 @@ def step(self, closure=None): ) # view update to origin[tp] shape update_reshape = update.view(-1, self.grad_shape[1]) - # gather grad[flatten] along dp group then reshape to [H/tp, W] grad = _gather( input_=grad, dim=-1, process_group=self.data_parallel_group diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md new file mode 100644 index 000000000000..d15f03bece84 --- /dev/null +++ b/docs/source/en/features/distributed_adafactor.md @@ -0,0 +1,138 @@ +# Distributed Adafactor + +Author: + +**Related Paper** +- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) + +## Introduction + +Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. + + +## Performance + +| Parallel strategy | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | +| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | +| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | +| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | +| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | +| DistAdaFactor(Row 
Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | +| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | +| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | + + +## Hands-On Practice +We now demonstrate how to use Distributed Adafactor. +### step 1. Import libraries + +```python +import torch +from torch import nn +import torch.distributed as dist + +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor + +``` + +### step 2. Initialize Distributed Environment and Parallism Group +We then need to initialize distributed environment. For demo purpose, we uses `colossalai.launch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) +for other initialization methods. We use `ProcessGroupMesh` to create tensor parallelism group and data parallelism group. + +```python +# Distributed Enviroment +config = {} +colossalai.launch(config=config, rank=rank, world_size=world_size,host="localhost", port=port, backend="nccl") + +# Parallism Group +tp_size, zero_size = 2, 2 +use_zero = True if zero_size > 1 else False +proc_mesh = ProcessGroupMesh(tp_size, zero_size) +tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) +``` + +### step 3. Initialize Module +Build our model. We created an MLP using two Linear Layer. + +```python +# Init a Tensor Paralleled Module +class TPModel(nn.Module): + def __init__(self, linear1, linear2, tp_group=None): + super().__init__() + self.linear1 = Linear1D_Col.from_native_module( + linear1, process_group=tp_group, gather_output=False, overlap=True + ) + self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x +HEIGHT = 4096 +WIDTH = 4096 +tp_model = TPModel(copy.deepcopy(nn.Linear(HEIGHT, WIDTH)), copy.deepcopy(nn.Linear(HEIGHT, WIDTH)), tp_group).to(local_rank) + +# Get Module parameter +tp_param_group = [p for n, p in tp_model.named_parameters()] +``` + +### step 4. Initialize Optimizer +Then, We initialise the optimiser using the model parameter. Then, we set the distributed information for optimiser. + +```python +# Init a Optimizer +dist_optim = DistributedAdaFactor(tp_param_group) +shard_to_param = {id(p):p for p in tp_param_group} + +# Setup distributed information for Optimizer +dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, +) +``` + +### step 5. 
Perform a forward and backward propagation for model and step the gradient + +```python +# Random initialise dataset +x = torch.randn(HEIGHT, WIDTH, device=local_rank) + +# Fwd and Bwd +out_tp = tp_model(x) +if zero_size > 1: + dist_optim.backward(out_tp.sum()) +else: + out_tp.sum().backward() + +# perform step for param and grad +dist_optim.step() +dist_optim.zero_grad() +``` + From 40a55287f1aee0944ac80d89557d30a9716a9cb1 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 14:58:50 +0800 Subject: [PATCH 07/35] [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam; --- colossalai/nn/optimizer/README.md | 26 ++++++++++++++++++- .../nn/optimizer/distributed_adafactor.py | 9 +++---- .../en/features/distributed_adafactor.md | 2 +- tests/test_optimizer/test_dist_adafactor.py | 8 +++++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index 14de2e541508..fdf8f128cf0a 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -91,7 +91,31 @@ A series of optimizers have been optimized and integrated. Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. -### Distributed Adafactor API +### API Reference + +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} + +### Sample: Init with booster + +```python +# ============================== +# Model Init +# ============================== +tp_model = TPModel() + +# ============================== +# Optimizer Init +# ============================== +dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()]) + +# ============================== +# Booster Init +# ============================== +plugin = TorchDDPPlugin() +booster = Booster(plugin=plugin) +criterion = lambda x: x.mean() +tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) +``` ### Performance | Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index c73ca46d8934..8fa90c6cb66d 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -27,7 +27,7 @@ def __init__( relative_step=True, warmup_init=False, ): - lr=None + lr = None if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: @@ -162,17 +162,14 @@ def step(self, closure=None): grad = p.grad if grad.is_sparse: raise RuntimeError("Adafactor does not support sparse gradients.") + state = self.state[p] self.grad_shape = grad.shape # 1 dim shape - - # print(f"self.shard_to_param {self.shard_to_param}") - param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p))) - if param_is_dtensor: self.grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) - self.factored, 
self.use_first_moment = self._get_options(group, self.grad_shape) + if len(state) == 0: state["step"] = 0 if self.use_first_moment: diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index d15f03bece84..fbddf7d7d2be 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -135,4 +135,4 @@ else: dist_optim.step() dist_optim.zero_grad() ``` - + diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 4b7cb6c161df..0813c0594b7e 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -31,7 +31,7 @@ from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import run_bert_test +from tests.test_optimizer._utils import run_bert_test, check_optim_states from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, check_weight, @@ -39,6 +39,7 @@ unwrap_model, ) + HEIGHT = 4 WIDTH = 4 _TP_SPEC = DimSpec([0]) @@ -679,6 +680,11 @@ def exam_bert_test(test_config): atol, rtol = 5e-4, 5e-4 if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + + # check optim states + check_optim_states(org_optimizer, sharded_optimizer.optim) + + clear_layout_converter() torch.cuda.empty_cache() From 1c9bb930b59d0f363a1b749e592b2bed5d1c1d51 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 15:22:18 +0800 Subject: [PATCH 08/35] [feature] modify user documentation; --- colossalai/nn/optimizer/README.md | 68 ------------- .../en/features/distributed_adafactor.md | 95 ++++++++++++------- 2 files changed, 61 insertions(+), 102 deletions(-) diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index fdf8f128cf0a..d3f8badc7313 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -81,71 +81,3 @@ If you wish to add an optimizer for a specific application, please follow the st If your PR is accepted, we may invite you to put up a tutorial or blog in [ColossalAI Documentation](https://colossalai.org/). - - -## Optimizer - -A series of optimizers have been optimized and integrated. - -### Distributed Adafactor - -Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. 
- -### API Reference - -{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} - -### Sample: Init with booster - -```python -# ============================== -# Model Init -# ============================== -tp_model = TPModel() - -# ============================== -# Optimizer Init -# ============================== -dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()]) - -# ============================== -# Booster Init -# ============================== -plugin = TorchDDPPlugin() -booster = Booster(plugin=plugin) -criterion = lambda x: x.mean() -tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) -``` - -### Performance -| Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | -| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | -| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | -| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | -| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | -| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | -| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 
| diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index fbddf7d7d2be..b607e5d0aa9e 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -9,41 +9,9 @@ Author: Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. +### API Reference -## Performance - -| Parallel strategy | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | -| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | -| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | -| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | -| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | -| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | -| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | 
[4096 , 4096] | 0.54 | 1.31 | 27.28 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | - +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} ## Hands-On Practice We now demonstrate how to use Distributed Adafactor. @@ -135,4 +103,63 @@ else: dist_optim.step() dist_optim.zero_grad() ``` + +## Run with booster +We highly recommend users to use booster, a simple, easy to use, and efficient interface. The Code Below is the Distributed Adafactor launched with booster. + +```python +# ============================== +# Model Init +# ============================== +tp_model = TPModel() + +# ============================== +# Optimizer Init +# ============================== +dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()]) + +# ============================== +# Booster Init +# ============================== +plugin = TorchDDPPlugin() +booster = Booster(plugin=plugin) +criterion = lambda x: x.mean() +tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) +``` + +## Performance + +| Parallel strategy | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | +| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | +| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | +| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | +| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | +| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | +| 
DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | +| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | + + From 1039f3426aa3eaf8fc172c4cee35a3510e345cc7 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 15:25:47 +0800 Subject: [PATCH 09/35] [fix] fix readme format issue; --- docs/source/en/features/distributed_adafactor.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index b607e5d0aa9e..3ba609283def 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -140,7 +140,6 @@ tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, crit | AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | | DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | | DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | | AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | | DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | | DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | @@ -150,7 +149,6 @@ tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, crit | AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | | DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | | DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | | AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | | DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | | DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | From 2ffca49595b4b6eb75c8260f7b6da63ce2758211 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 20:05:07 +0800 Subject: [PATCH 10/35] [fix] add zero=0 in testcase; cached augment in dict; --- colossalai/nn/optimizer/adafactor.py | 13 +-- .../nn/optimizer/distributed_adafactor.py | 85 +++++++++------- tests/test_optimizer/test_dist_adafactor.py | 96 +++++++++++-------- 3 files changed, 109 insertions(+), 85 deletions(-) diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index 0cedbb2512be..0a9230d7c585 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -120,8 +120,6 @@ def step(self, closure=None): # grad shape is same as weigh / bias """ grad = p.grad - if grad.dtype in {torch.float16, torch.bfloat16}: - grad = grad.float() if grad.is_sparse: raise RuntimeError("Adafactor does not support sparse gradients.") @@ -168,10 +166,6 @@ def step(self, closure=None): else: state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - p_data_fp32 = 
p - if p.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() - state["step"] += 1 # state["RMS"] = self._rms(p_data_fp32) lr = self._get_lr(group, state) @@ -201,9 +195,8 @@ def step(self, closure=None): update = exp_avg if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - p_data_fp32.add_(-update) - if p.dtype in {torch.float16, torch.bfloat16}: - p.copy_(p_data_fp32) + p.add_(p, alpha=(-group["weight_decay"] * lr)) + p.add_(-update) + return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 8fa90c6cb66d..ba9caad29739 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -49,11 +49,13 @@ def __init__( self.data_parallel_size = 1 self.data_parallel_group = None self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor} - self.shard_spec = None - self.grad_shape = None - self.factored = None # bool - self.use_first_moment = None # bool self.use_zero = True + + self.param_is_dtensor_dict = {} # {id(p): True/False} + self.grad_shape_dict = {} # {id(p): master param shape} + self.factored_dict = {} # {id(p): True/False} + self.use_first_moment_dict = {} # {id(p): True/False} + self.shard_spec_dict = {} # {id(p): ShardSpec} super().__init__(params, defaults) @@ -84,8 +86,21 @@ def setup_distributed( self.use_zero = use_zero self.shard_to_param = shard_to_param if shard_to_param is not None else {} - - + # grad is None, cause we dont setup now + for group in self.param_groups: + for p in group["params"]: + self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p))) + if self.param_is_dtensor_dict[id(p)]: + self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape + else: + self.grad_shape_dict[id(p)] = p.shape + self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)]) + # if self.factored_dict[id(p)]: + if self.param_is_dtensor_dict[id(p)]: + self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) + else: + self.shard_spec_dict[id(p)] = None + @staticmethod def _get_lr(param_group, param_state): rel_step_sz = param_group["lr"] @@ -164,48 +179,47 @@ def step(self, closure=None): raise RuntimeError("Adafactor does not support sparse gradients.") state = self.state[p] - self.grad_shape = grad.shape # 1 dim shape - param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p))) + grad_shape = self.grad_shape_dict[id(p)] + param_is_dtensor = self.param_is_dtensor_dict[id(p)] if param_is_dtensor: - self.grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) - self.factored, self.use_first_moment = self._get_options(group, self.grad_shape) - + grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) + factored, use_first_moment = self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] + shard_spec = self.shard_spec_dict[id(p)] if len(state) == 0: state["step"] = 0 - if self.use_first_moment: + if use_first_moment: # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p) - if self.factored: - self.shard_spec = get_sharding_spec(self.shard_to_param.get(id(p))) - if self.shard_spec.sharding_sequence[0] == "R": # Col Parallel + if factored and param_is_dtensor: + if shard_spec.sharding_sequence[0] == "R": # Col Parallel state["exp_avg_sq_row"] = torch.zeros( - self.grad_shape[0] // 
self.data_parallel_size, device=p.device, dtype=p.dtype + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype ) # [H/dp] state["exp_avg_sq_col"] = torch.zeros( - self.grad_shape[1], device=p.device, dtype=p.dtype + grad_shape[1], device=p.device, dtype=p.dtype ) # [W/TP] - if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel + if shard_spec.sharding_sequence[-1] == "R": # Row Parallel # Row indivisible shape situation - if self.grad_shape[0] % self.data_parallel_size != 0: + if grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( - self.grad_shape[0], device=p.device, dtype=p.dtype + grad_shape[0], device=p.device, dtype=p.dtype ) # [H/dp/Tp] else: state["exp_avg_sq_row"] = torch.zeros( - self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype ) # [H/dp/Tp] state["exp_avg_sq_col"] = torch.zeros( - self.grad_shape[1], device=p.device, dtype=p.dtype + grad_shape[1], device=p.device, dtype=p.dtype ) # [W] else: state["exp_avg_sq"] = torch.zeros_like(p) state["RMS"] = 0 else: - if self.use_first_moment: + if use_first_moment: state["exp_avg"] = state["exp_avg"].to(grad) - if self.factored: + if factored: state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) else: @@ -216,15 +230,14 @@ def step(self, closure=None): beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) update = (grad**2) + group["eps"][0] - if self.factored: + if factored and param_is_dtensor: # ============================== # First Dim is R, Last Dim is S{} means split dim -1 ---> # Coloum Parallel ---> sq_row need Do (col) Reduce # ============================== - self.shard_spec = get_sharding_spec(self.shard_to_param.get(id(p))) - if self.shard_spec.sharding_sequence[0] == "R": - update_reshape = update.view(-1, self.grad_shape[1]) - grad_reshape = grad.view(-1, self.grad_shape[1]) + if shard_spec.sharding_sequence[0] == "R": + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -241,20 +254,20 @@ def step(self, closure=None): # Last Dim is R, First Dim is S{} means split dim 0 ---> # Row Parallel ---> sq_col need Do (row) Reduce # ============================== - elif self.shard_spec.sharding_sequence[-1] == "R": + elif shard_spec.sharding_sequence[-1] == "R": # Row Residual situation - if self.grad_shape[0] % self.data_parallel_size != 0: + if grad_shape[0] % self.data_parallel_size != 0: # gather update[flatten] along dp group then reshape to [H/tp, W] update = _gather( input_=update, dim=-1, process_group=self.data_parallel_group ) # view update to origin[tp] shape - update_reshape = update.view(-1, self.grad_shape[1]) + update_reshape = update.view(-1, grad_shape[1]) # gather grad[flatten] along dp group then reshape to [H/tp, W] grad = _gather( input_=grad, dim=-1, process_group=self.data_parallel_group ) - grad_reshape = grad.view(-1, self.grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -269,8 +282,8 @@ def step(self, closure=None): else: update = update_reshape else: - 
update_reshape = update.view(-1, self.grad_shape[1]) - grad_reshape = grad.view(-1, self.grad_shape[1]) + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -297,7 +310,7 @@ def step(self, closure=None): # (Line No.8) RMS # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) update.mul_(lr) - if self.use_first_moment: + if use_first_moment: exp_avg = state["exp_avg"] exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) update = exp_avg diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 0813c0594b7e..218eebd0d4da 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -38,6 +38,7 @@ run_forward_backward_with_hybrid_plugin, unwrap_model, ) +from colossalai.shardformer.layer.utils import Randomizer HEIGHT = 4 @@ -370,7 +371,7 @@ def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("tp_zero_size", [(2, 2), (4, 1),(1, 4)]) # (2, 2), (4, 1),(1, 4), def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False @@ -456,7 +457,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): base_optim.zero_grad() dist_optim.zero_grad() - print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") + # print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: @@ -478,10 +479,6 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): pass correctness = correctness_verify(p.data, tp_p.data, dtype) - # print(f"Curr Param correct {correctness}") - # if not correctness: - # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") - @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) @@ -643,7 +640,14 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int "num_microbatches": 4, "zero_stage": 1, "precision": "bf16", - } + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 4, + "zero_stage": 0, + "precision": "bf16", + }, ], ) def exam_bert_test(test_config): @@ -651,42 +655,56 @@ def exam_bert_test(test_config): test_config["use_lazy_init"] = False test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel test_config["initial_scale"] = 2**15 # avoid overflow + model_list = [ + "transformers_bert" + "transformers_bert_for_pretraining" + "transformers_bert_lm_head_model" + "transformers_bert_for_masked_lm" + "transformers_bert_for_sequence_classification" + # "transformers_bert_for_token_classification" + "transformers_bert_for_next_sentence" + "transformers_bert_for_mcq" + "transformers_bert_for_question_answering" + ] for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, 
_) in sub_model_zoo.items(): - - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( - model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor - ) - - org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster - ) - - - stage_manager = booster.plugin.stage_manager - tp_group = booster.plugin.tp_group - - bert = unwrap_model(org_model, "BertModel", "bert") - sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") - weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] - - org_optimizer.step() - sharded_optimizer.step() - - # check weights - if test_config["precision"] == "bf16": - atol, rtol = 5e-4, 1e-4 - else: - atol, rtol = 5e-4, 5e-4 - if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): - check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) - - # check optim states - check_optim_states(org_optimizer, sharded_optimizer.optim) - + + if name in model_list: + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] + + org_optimizer.step() + sharded_optimizer.step() + + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 5e-4 + else: + atol, rtol = 5e-4, 5e-4 + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + print(f"{name} check pass") + # check optim states + check_optim_states(org_optimizer, sharded_optimizer.optim) + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() + @@ -695,7 +713,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() - exam_dist_adafactor_zero() + # exam_dist_adafactor_zero() exam_bert_test() From 0fd62a022d380d55d757a5ef7cc69edfd4319224 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 11:13:43 +0800 Subject: [PATCH 11/35] [fix] fix percision issue; --- tests/test_optimizer/test_dist_adafactor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 218eebd0d4da..a98fca5c47d2 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -654,14 +654,14 @@ def exam_bert_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") test_config["use_lazy_init"] = False test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel - test_config["initial_scale"] = 2**15 # avoid overflow + 
test_config["initial_scale"] = 2**10 # avoid overflow model_list = [ "transformers_bert" "transformers_bert_for_pretraining" "transformers_bert_lm_head_model" "transformers_bert_for_masked_lm" "transformers_bert_for_sequence_classification" - # "transformers_bert_for_token_classification" + "transformers_bert_for_token_classification" "transformers_bert_for_next_sentence" "transformers_bert_for_mcq" "transformers_bert_for_question_answering" @@ -713,7 +713,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() - # exam_dist_adafactor_zero() + exam_dist_adafactor_zero() exam_bert_test() From 28c3a409df4cadafb347f2f50d839cbaf45fdc43 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 15:34:30 +0800 Subject: [PATCH 12/35] [feature] add distributed rms; --- colossalai/nn/optimizer/adafactor.py | 3 +- .../nn/optimizer/distributed_adafactor.py | 20 ++++++- tests/test_optimizer/test_dist_adafactor.py | 60 ++++++------------- 3 files changed, 36 insertions(+), 47 deletions(-) diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index 0a9230d7c585..3b7d998c9f28 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -186,7 +186,7 @@ def step(self, closure=None): exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) # RMS - # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) update.mul_(lr) if use_first_moment: @@ -198,5 +198,4 @@ def step(self, closure=None): p.add_(p, alpha=(-group["weight_decay"] * lr)) p.add_(-update) - return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index ba9caad29739..d8d28b89b152 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -306,9 +306,23 @@ def step(self, closure=None): exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) - + # (Line No.8) RMS - # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + # perform a sum on each device + update_sum = update.pow(2).sum() + num_of_element = update.numel() + # reduce sum on tp group if exist + if self.tensor_parallel_size > 1 and param_is_dtensor: + dist.all_reduce(update_sum, group=self.tensor_parallel_group) + num_of_element = num_of_element * self.tensor_parallel_size + # reduce sum on dp group if exist + if self.data_parallel_size > 1 and param_is_dtensor: + dist.all_reduce(update_sum, group=self.data_parallel_group) + num_of_element = num_of_element * self.data_parallel_size + # div num of element + rms = (update_sum / num_of_element).sqrt() + update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) if use_first_moment: exp_avg = state["exp_avg"] @@ -319,6 +333,6 @@ def step(self, closure=None): p.add_(p, alpha=(-group["weight_decay"] * lr)) p.add_(-update) - + return loss diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index a98fca5c47d2..c950047e9806 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -45,7 +45,6 @@ WIDTH = 4 _TP_SPEC = 
DimSpec([0]) - def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32): rtol = None atol = None @@ -62,7 +61,6 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) assert_close(tensor1, tensor2, rtol=rtol, atol=atol) - # setup param groups; (For zero test optim) def setup_param_groups_zero(model: nn.Module) -> list: no_decay = ["bias", "LayerNorm.weight"] @@ -78,13 +76,11 @@ def setup_param_groups_zero(model: nn.Module) -> list: ] return optimizer_grouped_parameters - # setup param groups; (For base optim) def setup_param_groups(model: nn.Module) -> list: optimizer_grouped_parameters = [p for n, p in model.named_parameters()] return optimizer_grouped_parameters - # setup flatten param groups, sharding spec and shape; (For dist optim) def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: flatten_optimizer_grouped_parameters = [] @@ -100,8 +96,6 @@ def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: else: sharding_spec[id(flatten_p)] = None param_shape[id(flatten_p)] = p.shape - # print(f"sharding_spec {sharding_spec}") - # print(f"param_shape {param_shape}") return flatten_optimizer_grouped_parameters, sharding_spec, param_shape @@ -140,12 +134,10 @@ def set_dist_grad( p.grad = p.data p.data = orig_p - def set_master_param_to_shard_param(master_param_list) -> dict: master_param_to_shard_param ={id(p):p for p in master_param_list} return master_param_to_shard_param - class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() @@ -157,7 +149,6 @@ def forward(self, x): x = self.linear2(x) return x - class TPModel(nn.Module): def __init__(self, linear1, linear2, tp_group=None): super().__init__() @@ -171,9 +162,6 @@ def forward(self, x): x = self.linear2(x) return x - - - @parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -190,7 +178,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # ============================== # Base Case # ============================== - H, W = 4096, 4096 + H, W = HEIGHT, WIDTH model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight weight, bias = model_col.weight, model_col.bias @@ -286,7 +274,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): print(f"col corrness {col_correct} row correct {row_correct}") - @parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -349,7 +336,6 @@ def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: shard_spec = get_sharding_spec(tp_p) - # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") if len(shard_spec.sharding_sequence) >= 2: # Col Parallel if shard_spec.sharding_sequence[0] == "R": @@ -365,13 +351,9 @@ def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) - # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") - # 
print(f"Curr Param correct {correctness}") - # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") - @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(2, 2), (4, 1),(1, 4)]) # (2, 2), (4, 1),(1, 4), +@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (2, 2), (4, 1), (1, 4) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False @@ -457,12 +439,10 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): base_optim.zero_grad() dist_optim.zero_grad() - # print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: shard_spec = get_sharding_spec(tp_p) - # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") if len(shard_spec.sharding_sequence) >= 2: # Col Parallel if shard_spec.sharding_sequence[0] == "R": @@ -478,8 +458,10 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() - @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -574,12 +556,10 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int base_optim.zero_grad() dist_optim.zero_grad() - print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: shard_spec = get_sharding_spec(tp_p) - # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") if len(shard_spec.sharding_sequence) >= 2: # Col Parallel if shard_spec.sharding_sequence[0] == "R": @@ -656,20 +636,21 @@ def exam_bert_test(test_config): test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel test_config["initial_scale"] = 2**10 # avoid overflow model_list = [ - "transformers_bert" - "transformers_bert_for_pretraining" - "transformers_bert_lm_head_model" - "transformers_bert_for_masked_lm" - "transformers_bert_for_sequence_classification" - "transformers_bert_for_token_classification" - "transformers_bert_for_next_sentence" - "transformers_bert_for_mcq" - "transformers_bert_for_question_answering" + "transformers_bert", + "transformers_bert_for_pretraining", + "transformers_bert_lm_head_model", + "transformers_bert_for_masked_lm", + "transformers_bert_for_sequence_classification", + "transformers_bert_for_token_classification", + "transformers_bert_for_next_sentence", + "transformers_bert_for_mcq", + "transformers_bert_for_question_answering", ] for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - + print(f"model name {name} {name in model_list}") if name in model_list: + print(f"{name} check start") org_model, org_optimizer, sharded_model, sharded_optimizer, 
criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor ) @@ -696,28 +677,23 @@ def exam_bert_test(test_config): atol, rtol = 5e-4, 5e-4 if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) - print(f"{name} check pass") # check optim states check_optim_states(org_optimizer, sharded_optimizer.optim) + print(f"{name} check pass") clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() - - - def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() - exam_dist_adafactor_zero() + # exam_dist_adafactor_zero() exam_bert_test() - - @pytest.mark.dist @rerun_if_address_is_in_use() def test_dist_adafactor(): From a9c5bf7edb6595dc108d8d4bed9761d270b0109b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 15:47:05 +0800 Subject: [PATCH 13/35] [feature] remove useless comment in testcase; --- tests/test_optimizer/test_dist_adafactor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index c950047e9806..764087635e2f 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -691,7 +691,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() - # exam_dist_adafactor_zero() + exam_dist_adafactor_zero() exam_bert_test() @pytest.mark.dist From 150ac1975f7e472b661373abd9b13036e4612514 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 17:13:28 +0800 Subject: [PATCH 14/35] [fix] Remove useless test; open zero test; remove fp16 test in bert exam; --- .../nn/optimizer/distributed_adafactor.py | 1 - tests/test_optimizer/test_dist_adafactor.py | 233 +----------------- 2 files changed, 11 insertions(+), 223 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index d8d28b89b152..fb4f09131227 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -95,7 +95,6 @@ def setup_distributed( else: self.grad_shape_dict[id(p)] = p.shape self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)]) - # if self.factored_dict[id(p)]: if self.param_is_dtensor_dict[id(p)]: self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) else: diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 764087635e2f..17da80521cbd 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -162,8 +162,8 @@ def forward(self, x): x = self.linear2(x) return x -@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 1)]) def 
exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size local_rank = dist.get_rank() @@ -202,20 +202,20 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): ) # flatten input(not dtensor) to optimizer bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) - base_param_group = setup_param_groups([weight, bias]) - cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten]) - rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten]) + # base_param_group = setup_param_groups([weight, bias]) + # cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten]) + # rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten]) # ============================== # Init Optimizer # ============================== # base - optimizer_base = Adafactor(base_param_group) - cp_dist_optim = DistributedAdaFactor(cp_param_group) - rp_dist_optim = DistributedAdaFactor(rp_param_group) + optimizer_base = Adafactor([weight, bias]) + cp_dist_optim = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten]) + rp_dist_optim = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten]) - shard_to_param_cp = set_master_param_to_shard_param(cp_dist_optim) + shard_to_param_cp = set_master_param_to_shard_param([weight_col_shard_flatten, bias_col_flatten]) cp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, @@ -223,7 +223,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): use_zero=use_zero, ) - shard_to_param_rp = set_master_param_to_shard_param(rp_dist_optim) + shard_to_param_rp = set_master_param_to_shard_param([weight_row_shard_flatten, bias_row_flatten]) rp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, @@ -274,84 +274,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): print(f"col corrness {col_correct} row correct {row_correct}") -@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) -def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int]): - tp_size, zero_size = tp_zero_size - use_zero = True if zero_size > 1 else False - local_rank = dist.get_rank() - - clear_layout_converter() - proc_mesh = ProcessGroupMesh(tp_size, zero_size) - tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) - - torch.set_default_dtype(dtype) - set_seed(42) - - # ============================== - # Model Init - # ============================== - base_model = MlpModel().to(local_rank) - tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) - - base_param_group = setup_param_groups(base_model) - tp_param_group = setup_param_groups(tp_model) - - # ============================== - # Optimizer Init - # ============================== - base_optim = Adafactor(base_param_group) - dist_optim = DistributedAdaFactor(tp_param_group) - - shard_to_param = set_master_param_to_shard_param(tp_param_group) - dist_optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, - ) - - # ============================== - # Correctness Verify - # ============================== - x = torch.randn(HEIGHT, 
WIDTH, device=local_rank) - - out = base_model(x) - out_tp = tp_model(x) - - if zero_size > 1: - dist_optim.backward(out_tp.sum()) - base_optim.backward(out.sum()) - else: - out_tp.sum().backward() - out.sum().backward() - - base_optim.step() - dist_optim.step() - - base_optim.zero_grad() - dist_optim.zero_grad() - - for p, tp_p in zip(base_param_group, tp_param_group): - param_is_distributed = is_distributed_tensor(tp_p) - if param_is_distributed: - shard_spec = get_sharding_spec(tp_p) - if len(shard_spec.sharding_sequence) >= 2: - # Col Parallel - if shard_spec.sharding_sequence[0] == "R": - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - # ROW Parallel - if shard_spec.sharding_sequence[-1] == "R": - tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather - else: - # TP bias - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - - else: - # No TP bias - pass - correctness = correctness_verify(p.data, tp_p.data, dtype) - @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (2, 2), (4, 1), (1, 4) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -462,120 +384,6 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): Randomizer.reset_index() torch.cuda.empty_cache() -@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) -def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): - tp_size, zero_size = tp_zero_size - local_rank = dist.get_rank() - use_zero = True if zero_size > 1 else False - - clear_layout_converter() - - proc_mesh = ProcessGroupMesh(tp_size, zero_size) - tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) - - torch.set_default_dtype(dtype) - set_seed(42) - - # ============================== - # Model Init - # ============================== - base_model = MlpModel().to(local_rank) - tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) - - base_param_group = setup_param_groups(base_model) - tp_param_group = setup_param_groups(tp_model) - - # ============================== - # Optimizer Init - # ============================== - base_optim = Adafactor(base_param_group) - dist_optim = DistributedAdaFactor(tp_param_group) - - # Setup distributed optimizer - if zero_size > 1: - base_optim = LowLevelZeroOptimizer( - base_optim, - overlap_communication=True, - initial_scale=128, - partition_grad=True, - dp_process_group=dp_group, - verbose=True, - ) - - dist_optim = LowLevelZeroOptimizer( - dist_optim, - overlap_communication=True, - initial_scale=128, - partition_grad=True, - dp_process_group=dp_group, - verbose=True, - ) - shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened - dist_optim.optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, - ) - else: - shard_to_param = set_master_param_to_shard_param(tp_param_group) - dist_optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, - ) - - # ============================== - # Booster Init - # 
============================== - plugin = TorchDDPPlugin() - booster = Booster(plugin=plugin) - criterion = lambda x: x.mean() - - tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) - - # ============================== - # Correctness Verify - # ============================== - x = torch.randn(HEIGHT, WIDTH, device=local_rank) - - out = base_model(x) - out_tp = tp_model(x) - - if zero_size > 1: - dist_optim.backward(out_tp.sum()) - base_optim.backward(out.sum()) - else: - out_tp.sum().backward() - out.sum().backward() - - base_optim.step() - dist_optim.step() - - base_optim.zero_grad() - dist_optim.zero_grad() - for p, tp_p in zip(base_param_group, tp_param_group): - param_is_distributed = is_distributed_tensor(tp_p) - if param_is_distributed: - shard_spec = get_sharding_spec(tp_p) - if len(shard_spec.sharding_sequence) >= 2: - # Col Parallel - if shard_spec.sharding_sequence[0] == "R": - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - # ROW Parallel - if shard_spec.sharding_sequence[-1] == "R": - tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather - else: - # TP bias - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - - else: - # No TP bias - pass - correctness = correctness_verify(p.data, tp_p.data, dtype) - @parameterize( "test_config", [ @@ -597,24 +405,6 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int "zero_stage": 2, "precision": "bf16", }, - { - "tp_size": 1, - "num_microbatches": 4, - "zero_stage": 2, - "precision": "fp16", - }, - { - "tp_size": 2, - "num_microbatches": 4, - "zero_stage": 2, - "precision": "fp16", - }, - { - "tp_size": 4, - "num_microbatches": 4, - "zero_stage": 2, - "precision": "fp16", - }, { "tp_size": 2, "num_microbatches": 4, @@ -689,8 +479,7 @@ def exam_bert_test(test_config): def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # exam_dist_adafactor_base() - # exam_dist_adafactor_fwd_bwd() + exam_dist_adafactor_base() exam_dist_adafactor_zero() exam_bert_test() From e783599672f9898c86de97d3e1330f4f27b43f49 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 18:01:17 +0800 Subject: [PATCH 15/35] [feature] Extract distributed rms function; --- .../nn/optimizer/distributed_adafactor.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index fb4f09131227..b67580fd0f28 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -125,8 +125,20 @@ def _get_options(param_group, param_shape): return factored, use_first_moment @staticmethod - def _rms(tensor): - return tensor.norm(2) / (tensor.numel() ** 0.5) + def _rms(tensor, param_is_dtensor, tp_size, dp_size, tp_group, dp_group): + tensor_sum = tensor.pow(2).sum() + num_of_element = tensor.numel() + # reduce sum on tp group if exist + if tp_size > 1 and param_is_dtensor: + dist.all_reduce(tensor_sum, group=tp_group) + num_of_element = num_of_element * tp_size + # reduce sum on dp group if exist + if dp_size > 1 and param_is_dtensor: + dist.all_reduce(tensor_sum, group=dp_group) + num_of_element = num_of_element * dp_size + # div num of element + rms = (tensor_sum / num_of_element).sqrt() + return rms @staticmethod def _approx_sq_grad(exp_avg_sq_row, 
exp_avg_sq_col): @@ -217,12 +229,12 @@ def step(self, closure=None): state["RMS"] = 0 else: if use_first_moment: - state["exp_avg"] = state["exp_avg"].to(grad) + state["exp_avg"] = state["exp_avg"] if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + state["exp_avg_sq_row"] = state["exp_avg_sq_row"] + state["exp_avg_sq_col"] = state["exp_avg_sq_col"] else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + state["exp_avg_sq"] = state["exp_avg_sq"] state["step"] += 1 lr = self._get_lr(group, state) @@ -306,20 +318,8 @@ def step(self, closure=None): exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) - # (Line No.8) RMS - # perform a sum on each device - update_sum = update.pow(2).sum() - num_of_element = update.numel() - # reduce sum on tp group if exist - if self.tensor_parallel_size > 1 and param_is_dtensor: - dist.all_reduce(update_sum, group=self.tensor_parallel_group) - num_of_element = num_of_element * self.tensor_parallel_size - # reduce sum on dp group if exist - if self.data_parallel_size > 1 and param_is_dtensor: - dist.all_reduce(update_sum, group=self.data_parallel_group) - num_of_element = num_of_element * self.data_parallel_size - # div num of element - rms = (update_sum / num_of_element).sqrt() + # # (Line No.8) RMS + rms = self._rms(update, param_is_dtensor, self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group) update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) update.mul_(lr) From 9d33a34bc6745e92a07bdf3e4f436606959493a8 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 18:57:34 +0800 Subject: [PATCH 16/35] [feature] add booster + lowlevelzeroPlugin in test; --- tests/test_optimizer/test_dist_adafactor.py | 128 +++++++++++++++++++- 1 file changed, 126 insertions(+), 2 deletions(-) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 17da80521cbd..a239e0e16555 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -9,7 +9,8 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin +from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin, LowLevelZeroPlugin + from colossalai.cluster import ProcessGroupMesh from colossalai.device.device_mesh import DeviceMesh from colossalai.nn.optimizer.adafactor import Adafactor @@ -272,7 +273,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): col_correct = correctness_verify(weight.data, weight_col_gather.data, dtype) row_correct = correctness_verify(weight.data, weight_row_gather.data, dtype) - print(f"col corrness {col_correct} row correct {row_correct}") + print(f"Base Test Pass") @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (2, 2), (4, 1), (1, 4) @@ -383,6 +384,127 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() + print(f"Zero Test Pass") + +@parameterize("dtype", [torch.float16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(2, 2), (1, 4)]) # (2, 2), (4, 1), (1, 4) +def 
exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + clear_layout_converter() + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Booster Init + # ============================== + plugin = LowLevelZeroPlugin() + booster = Booster(plugin=plugin) + criterion = lambda x: x.mean() + + tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + + else: + # No TP bias + pass + correctness = correctness_verify(p.data, tp_p.data, dtype) + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + print(f"Booster Test Pass") @parameterize( "test_config", @@ -475,12 +597,14 @@ def exam_bert_test(test_config): 
clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() + print(f"Bert Model Zoo Test Pass") def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_dist_adafactor_base() exam_dist_adafactor_zero() + exam_dist_adafactor_booster() exam_bert_test() @pytest.mark.dist From 419c1c0adee3867badcd1f2e52e8b831cfbf9c11 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 12 Apr 2024 11:04:27 +0800 Subject: [PATCH 17/35] [feature] add Start_with_booster_API case in md; add Supporting Information in md; --- .../en/features/distributed_adafactor.md | 100 ++++++-------- tests/test_optimizer/test_dist_adafactor.py | 129 +----------------- 2 files changed, 45 insertions(+), 184 deletions(-) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index 3ba609283def..725067bbcd28 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -14,7 +14,7 @@ Distributed Adafactor is an optimiser that supports hybrid optimisation, includi {{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} ## Hands-On Practice -We now demonstrate how to use Distributed Adafactor. +We now demonstrate how to start Distributed Adafactor with booster API. ### step 1. Import libraries ```python @@ -86,7 +86,15 @@ dist_optim.setup_distributed( ) ``` -### step 5. Perform a forward and backward propagation for model and step the gradient +### step 5.Init Booster + +```python +plugin = LowLevelZeroPlugin() +booster = Booster(plugin=plugin) +criterion = lambda x: x.mean() +tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) +``` +### step 6.Perform a forward and backward propagation for model and step the gradient ```python # Random initialise dataset @@ -104,60 +112,36 @@ dist_optim.step() dist_optim.zero_grad() ``` -## Run with booster -We highly recommend users to use booster, a simple, easy to use, and efficient interface. The Code Below is the Distributed Adafactor launched with booster. 
- -```python -# ============================== -# Model Init -# ============================== -tp_model = TPModel() - -# ============================== -# Optimizer Init -# ============================== -dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()]) - -# ============================== -# Booster Init -# ============================== -plugin = TorchDDPPlugin() -booster = Booster(plugin=plugin) -criterion = lambda x: x.mean() -tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) -``` - -## Performance - -| Parallel strategy | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | -| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | -| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | -| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | -| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | -| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | -| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | -| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | -| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | - - +## Supporting Information +Model/Feature Compatibility Matrix: +
+| Model/Feature | Transformers Bert | Transformers Bert For Pretraining | Transformers Bert Lm Head Model | Transformers Bert For Masked Lm | Transformers Bert For Sequence Classification | Transformers Bert For Token Classification | Transformers Bert For Next Sentence | Transformers Bert For Multiple-choice Question | Transformers Bert For Question Answering |
+| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
+| Distributed Adafactor | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+
+Plugin Compatibility Matrix:
+
+| Plugin | Transformers Bert | Transformers Bert For Pretraining | Transformers Bert Lm Head Model | Transformers Bert For Masked Lm | Transformers Bert For Sequence Classification | Transformers Bert For Token Classification | Transformers Bert For Next Sentence | Transformers Bert For Multiple-choice Question | Transformers Bert For Question Answering |
+| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
+| Hybrid Parallel Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Low Level Zero Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Torch DDP Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| Gemini Plugin | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| Moe Hybrid Plugin | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
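
For readers tracing the distributed `_rms` helper introduced in the patches above, the following is a minimal sketch of how the clipping statistic can be computed when the update tensor is sharded across tensor-parallel and ZeRO data-parallel groups. It assumes equal-sized shards and process groups like those passed to `setup_distributed`; the helper name `distributed_rms` is illustrative only, not part of the optimizer's public API.

```python
# Minimal sketch: RMS of the full (logical) update, computed from a local shard.
# Assumes every rank holds an equally sized shard; tp_group / dp_group are the
# process groups handed to DistributedAdaFactor.setup_distributed().
import torch
import torch.distributed as dist


def distributed_rms(update: torch.Tensor, tp_group=None, dp_group=None) -> torch.Tensor:
    sq_sum = update.pow(2).sum()   # local sum of squares
    numel = float(update.numel())  # local element count
    for group in (tp_group, dp_group):
        if group is not None and dist.get_world_size(group) > 1:
            dist.all_reduce(sq_sum, group=group)  # global sum of squares
            numel *= dist.get_world_size(group)   # global element count
    return (sq_sum / numel).sqrt()


# Used the same way as serial Adafactor's clipping rule, e.g.
# update.div_((distributed_rms(update, tp_group, dp_group) / clip_threshold).clamp_(min=1.0))
```

Because the element count is scaled by both group sizes before the square root, the clipping factor matches what a single device would compute on the unsharded update, up to the equal-shard assumption noted above.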