From f77ffc2c6b7485f01193148e8758b77386e89a6b Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Mon, 8 Apr 2024 17:35:01 +0800
Subject: [PATCH 01/35] [feature] solve conflict; update optimizer readme;

---
 colossalai/nn/optimizer/README.md           |  44 ++
 colossalai/nn/optimizer/adafactor.py        | 208 ++++
 .../nn/optimizer/distributed_adafactor.py   | 275 ++++
 tests/test_optimizer/test_dist_adafactor.py | 594 ++++++++++++++++++
 4 files changed, 1121 insertions(+)
 create mode 100644 colossalai/nn/optimizer/adafactor.py
 create mode 100644 colossalai/nn/optimizer/distributed_adafactor.py
 create mode 100644 tests/test_optimizer/test_dist_adafactor.py

diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md
index d3f8badc7313..b846f1cf55cb 100644
--- a/colossalai/nn/optimizer/README.md
+++ b/colossalai/nn/optimizer/README.md
@@ -81,3 +81,47 @@ If you wish to add an optimizer for a specific application, please follow the st
 If your PR is accepted, we may invite you to put up a tutorial or blog in [ColossalAI Documentation](https://colossalai.org/).
+
+
+## Optimizer
+
+A series of optimizers have been optimized and integrated.
+
+### Distributed Adafactor
+
+Distributed Adafactor supports tensor parallelism and ZeRO optimization. Here is a brief flowchart of how Adafactor implements tensor parallelism:
+
+[[Tensor Parallel Strategy in Distributed Adafactor]](adafactor_strategy.png)
+
+### Performance
+| Version | iter | Float Precision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate |
+| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: |
+| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 |
+| DistAdaFactor(Row Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 |
+| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 |
+| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 |
+| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 |
+| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 |
+| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: |
+| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 |
+| DistAdaFactor(Row Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 |
+| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 |
+| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 |
+| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - |
+| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 |
+| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 |
+| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: |
:---------------: | +| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | +| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | +| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py new file mode 100644 index 000000000000..c71676bf5459 --- /dev/null +++ b/colossalai/nn/optimizer/adafactor.py @@ -0,0 +1,208 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from torch.optim import Optimizer + +__all__ = ["Adafactor"] + + +# Adafactor +class Adafactor(Optimizer): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a 
single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + """ + # grad shape is same as weigh / bias + """ + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + """ + p is weight + state + {'step', + 'exp_avg_sq_row', + 'exp_avg_sq_col', + 'RMS' + } + + p is bias + state + {'step', + 'exp_avg_sq', + 'RMS' + } + """ + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + # state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + # RMS + # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + p_data_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py new file mode 100644 index 000000000000..50792ba9af0d --- /dev/null +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -0,0 +1,275 @@ +import math +from typing import Dict + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from 
colossalai.shardformer.layer._operation import _gather +from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor + +# DistributedAdaFactor (with Tensor parallel and Zero stage 2) +__all__ = ["DistributedAdaFactor"] + + +class DistributedAdaFactor(Optimizer): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + self.tensor_parallel_size = 1 + self.tensor_parallel_group = None + self.data_parallel_size = 1 + self.data_parallel_group = None + self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor} + self.shard_spec = None + self.grad_shape = None + self.factored = None # bool + self.use_first_moment = None # bool + self.use_zero = True + + def setup_distributed( + self, + tensor_parallel_group: dist.ProcessGroup = None, + data_parallel_group: dist.ProcessGroup = None, + shard_to_param: Dict = None, + use_zero: bool = True, + ) -> None: + """ + Inject features to the Optimizer + + Args: + tensor_parallel_group: The devices group for tensor parallel; + data_parallel_group: The devices group for data parallel; + sharding_spec_dict: ShardingSpecs of Each params; + param_shape: Paramater Shape of Each params; + use_zero: Whether or not to use zero; + + """ + self.tensor_parallel_group = tensor_parallel_group # "Expected row process group" + self.data_parallel_group = data_parallel_group + if self.tensor_parallel_group is not None: + self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_group) + if self.data_parallel_group is not None: + self.data_parallel_size = dist.get_world_size(self.data_parallel_group) + self.shard_to_param = shard_to_param + self.use_zero = use_zero + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + """ + Determines whether the current param is factored + Args: + param_group : param group + param_shape : Original Shape of param + + """ + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + # approx_sq_grad for row parallel weight + @staticmethod + def 
_approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): + # row_meam = sq_row_meam + r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization steps + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + state = self.state[p] + self.grad_shape = grad.shape # 1 dim shape + + param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p))) + + if param_is_dtensor: + self.grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) + + self.factored, self.use_first_moment = self._get_options(group, self.grad_shape) + if len(state) == 0: + state["step"] = 0 + if self.use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + if self.factored: + self.shard_spec = get_sharding_spec(self.shard_to_param.get(id(p))) + if self.shard_spec.sharding_sequence[0] == "R": # Col Parallel + state["exp_avg_sq_row"] = torch.zeros( + self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp] + state["exp_avg_sq_col"] = torch.zeros( + self.grad_shape[1], device=p.device, dtype=p.dtype + ) # [W/TP] + + if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel + state["exp_avg_sq_row"] = torch.zeros( + self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp/Tp] + state["exp_avg_sq_col"] = torch.zeros( + self.grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] + else: + state["exp_avg_sq"] = torch.zeros_like(p) + state["RMS"] = 0 + else: + if self.use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if self.factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p.float() + # if p.dtype in {torch.float16, torch.bfloat16}: + # p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + + if self.factored: + # ============================== + # First Dim is R, Last Dim is S{} means split dim -1 ---> + # Coloum Parallel ---> sq_row need Do (col) Reduce + # ============================== + self.shard_spec = get_sharding_spec(self.shard_to_param.get(id(p))) + if self.shard_spec.sharding_sequence[0] == "R": + update_reshape = update.view(-1, self.grad_shape[1]) + grad_reshape = grad.view(-1, self.grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group) 
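+                        # exp_avg_sq_row currently holds the mean over only this rank's W/tp columns;
+                        # the all-reduce above sums those partial means across the tensor-parallel group
+                        # and the division below rescales the sum into a mean over the full row, so the
+                        # factored second-moment statistic matches what single-device AdaFactor computes.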
+ exp_avg_sq_row.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) + update = update_reshape.view(-1) + # ============================== + # Last Dim is R, First Dim is S{} means split dim 0 ---> + # Row Parallel ---> sq_col need Do (row) Reduce + # ============================== + elif self.shard_spec.sharding_sequence[-1] == "R": + update_reshape = update.view(-1, self.grad_shape[1]) + grad_reshape = grad.view(-1, self.grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + # gather row + exp_avg_sq_row_gather = _gather( + input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group + ) + sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) + update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) + update_reshape.mul_(grad_reshape) + # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) + update = update_reshape.view(-1) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + # (Line No.8) RMS + # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + if self.use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update).flatten() + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + return loss diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py new file mode 100644 index 000000000000..f675f42301a8 --- /dev/null +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -0,0 +1,594 @@ +import copy +import os + +import pytest +import torch +import torch.distributed as dist +from torch import nn + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.cluster import ProcessGroupMesh +from colossalai.device.device_mesh import DeviceMesh +from colossalai.nn.optimizer.adafactor import Adafactor +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer._operation import _gather +from colossalai.tensor.d_tensor import ( + distribute_tensor, + get_layout, + get_sharding_spec, + is_distributed_tensor, + shard_colwise, + shard_rowwise, +) +from colossalai.tensor.d_tensor.sharding_spec import DimSpec +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed +from colossalai.zero import LowLevelZeroOptimizer + +HEIGHT = 4096 +WIDTH = 4096 +_TP_SPEC = DimSpec([0]) + + +def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is 
torch.float32: + rtol = 1e-05 + atol = 1e-05 + elif dtype is torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) + # assert_close(tensor1, tensor2, rtol=rtol, atol=atol) + + +# setup param groups; (For zero test optim) +def setup_param_groups_zero(model: nn.Module) -> list: + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": 0.1, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +# setup param groups; (For base optim) +def setup_param_groups(model: nn.Module) -> list: + optimizer_grouped_parameters = [p for n, p in model.named_parameters()] + return optimizer_grouped_parameters + + +# setup flatten param groups, sharding spec and shape; (For dist optim) +def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: + flatten_optimizer_grouped_parameters = [] + sharding_spec = {} # {id(flatten param): get_layout(p).global_shape} + param_shape = {} # {id(flatten param): get_sharding_spec(p)} + for n, p in model.named_parameters(): + # flatten_p = copy.deepcopy(p).flatten() + flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True)) + flatten_optimizer_grouped_parameters.append(flatten_p) + if is_distributed_tensor(p): + sharding_spec[id(flatten_p)] = get_sharding_spec(p) + param_shape[id(flatten_p)] = get_layout(p).global_shape + else: + sharding_spec[id(flatten_p)] = None + param_shape[id(flatten_p)] = p.shape + # print(f"sharding_spec {sharding_spec}") + # print(f"param_shape {param_shape}") + return flatten_optimizer_grouped_parameters, sharding_spec, param_shape + + +def set_dist_grad( + dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup +) -> None: + """ + Set split grads for Tensor Parallel or ZeRO DP. + We do not need a separate treatment for ZeRO, + as the wrapper takes care of reduce-scattering grads. 
+ """ + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()): + if torch_p.grad is None: + torch_p.grad = torch.zeros_like(torch_p) + + is_distributed = hasattr(p, "dist_layout") + if is_distributed: + sharding = p.dist_layout.sharding_spec.sharding_sequence + split_dim = sharding.index(_TP_SPEC) + shape = torch_p.split(world_size, dim=split_dim)[rank].shape + + indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1)) + # Generate grads only for the correctly split chunk + torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype)) + + else: + shape = torch_p.shape + torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype) + + # avoid inconsistent grad and param dtype error + orig_p = p.data + p.data = torch_p.grad.clone().to(g_dtype) + p.grad = p.data + p.data = orig_p + + +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(HEIGHT, WIDTH) + self.linear2 = nn.Linear(WIDTH, HEIGHT) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TPModel(nn.Module): + def __init__(self, linear1, linear2, tp_group=None): + super().__init__() + self.linear1 = Linear1D_Col.from_native_module( + linear1, process_group=tp_group, gather_output=False, overlap=True + ) + self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 +def exam_dist_adafactor_base(dtype: torch.dtype): + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + tensor_parallel_size = world_size + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Base Case + # ============================== + H, W = 4096, 4096 + model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight + weight, bias = model_col.weight, model_col.bias + device_mesh = DeviceMesh( + torch.Tensor([i for i in range(world_size)]), (1, tensor_parallel_size), init_process_group=True + ) + tp_group = device_mesh.get_process_group(axis=1) + # ============================== + # Col Parallel + # ============================== + weight_col_shard = shard_colwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_col_shard_layout = get_layout(weight_col_shard) # Layout info weight_col_shard_layout.global_shape + weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard) # Shard spec + weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True)) + bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) + col_params_shape = { + id(weight_col_shard_flatten): weight_col_shard_layout.global_shape, + id(bias_col_flatten): bias.shape, + } + col_sharding_spec_dict = {id(weight_col_shard_flatten): weight_col_shard_shard_spec, id(bias_col_flatten): None} + + # ============================== + # Row Parallel + # ============================== + weight_row_shard = shard_rowwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_row_shard_layout = get_layout(weight_row_shard) # Layout info weight_row_shard_layout.global_shape + weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard) # Shard spec + weight_row_shard_flatten = nn.Parameter( + 
weight_row_shard.clone().flatten().requires_grad_(True) + ) # flatten input(not dtensor) to optimizer + bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) + row_params_shape = { + id(weight_row_shard_flatten): weight_row_shard_layout.global_shape, + id(bias_row_flatten): bias.shape, + } + row_sharding_spec_dict = {id(weight_row_shard_flatten): weight_row_shard_shard_spec, id(bias_row_flatten): None} + + # ============================== + # Init Optimizer + # ============================== + + # base + optimizer_base = Adafactor([weight, bias]) + + # col parallel + optimizer_cp = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten]) + optimizer_cp.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=None, + sharding_spec_dict=col_sharding_spec_dict, + param_shape=col_params_shape, + ) + # row parallel + optimizer_rp = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten]) + optimizer_rp.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=None, + sharding_spec_dict=row_sharding_spec_dict, + param_shape=row_params_shape, + ) + + N_STEPS = 1 + for _ in range(N_STEPS): + # base step + optimizer_base.zero_grad() + weight.grad = torch.rand_like(weight) + bias.grad = torch.rand_like(bias) + optimizer_base.step() + + # col parallel step + optimizer_cp.zero_grad() + weight_col_shard_flatten.grad = ( + distribute_tensor(weight.grad, device_mesh, weight_col_shard_shard_spec).clone().flatten() + ) + bias_col_flatten.grad = bias.grad.clone().flatten() + optimizer_cp.step() + + # row parallel step + optimizer_rp.zero_grad() + weight_row_shard_flatten.grad = ( + distribute_tensor(weight.grad, device_mesh, weight_row_shard_shard_spec).clone().flatten() + ) + bias_row_flatten.grad = bias.grad.clone().flatten() + optimizer_rp.step() + + # gather result + weight_col_gather = _gather( + input_=weight_col_shard_flatten.data.view(-1, H // tensor_parallel_size), + dim=-1, + process_group=device_mesh.get_process_group(axis=1), + ) # gather + weight_row_gather = _gather( + input_=weight_row_shard_flatten.data, dim=-1, process_group=device_mesh.get_process_group(axis=1) + ).view( + -1, W + ) # gather + + # verify + col_correct = correctness_verify(weight.data, weight_col_gather.data, dtype) + row_correct = correctness_verify(weight.data, weight_row_gather.data, dtype) + + print(f"col corrness {col_correct} row correct {row_correct}") + + +@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 +def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype): + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + tensor_parallel_size = world_size + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + device_mesh = DeviceMesh( + torch.Tensor([i for i in range(world_size)]), (1, tensor_parallel_size), init_process_group=True + ) + base_model = MlpModel().to(local_rank) + tp_model = TPModel( + copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), device_mesh.get_process_group(axis=1) + ).to(local_rank) + tp_group = device_mesh.get_process_group(axis=1) + + base_param_group = setup_param_groups(base_model) + tp_param_group, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = 
DistributedAdaFactor(tp_param_group) + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=None, + sharding_spec_dict=tp_shard_spec, + param_shape=tp_param_shape, + ) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + loss_tp = tp_model(x).sum() + loss_tp.backward() + + loss = base_model(x).sum() + loss.backward() + + base_optim.zero_grad() + dist_optim.zero_grad() + + base_optim.step() + dist_optim.step() + + for p, tp_p in zip(base_param_group, tp_param_group): + if tp_shard_spec[id(tp_p)] is not None: + if len(tp_shard_spec[id(tp_p)].sharding_sequence) >= 2: + # print(f"device {local_rank} \n tp_p shard spec {tp_shard_spec[id(tp_p)]}\n len {len(tp_shard_spec[id(tp_p)].sharding_sequence)}") + # if tp_p tp_shard_spec is col tp --> view to (-1, H // tensor_parallel_size) then gather + if tp_shard_spec[id(tp_p)].sharding_sequence[0] == "R": + tp_p = _gather( + input_=tp_p.data.view(-1, HEIGHT // tensor_parallel_size), + dim=-1, + process_group=device_mesh.get_process_group(axis=1), + ) # gather + # if tp_p tp_shard_spec is row tp --> gather then view to (-1, H // tensor_parallel_size) + else: + tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)).view( + -1, WIDTH + ) # gather + else: + # bias parallel + tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)) + # print(f"device {local_rank} \n p {p}\n tp_p {tp_p}\n") + else: + # compare p and tp no need + pass + # print(f"device {local_rank} \n p {p}\n tp_p {tp_p}\n") + correctness_verify(p.data, tp_p.data, dtype) + # print(f"correct {correctness}") + + +@parameterize("dtype", [torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, 
+ shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + + else: + # No TP bias + pass + correctness = correctness_verify(p.data, tp_p.data, dtype) + print(f"Curr Param correct {correctness}") + # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") + + + + + +@parameterize("dtype", [torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + local_rank = dist.get_rank() + use_zero = True if zero_size > 1 else False + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + 
shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Booster Init + # ============================== + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + criterion = lambda x: x.mean() + + tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + else: + # No TP bias + pass + correctness = correctness_verify(p.data, tp_p.data, dtype) + print(f"Curr Param correct {correctness}") + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + # exam_dist_adafactor_base() + # exam_dist_adafactor_fwd_bwd() + exam_dist_adafactor_zero() + # exam_dist_adafactor_booster() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_adafactor(): + spawn(run_dist, nprocs=8) + + +if __name__ == "__main__": + test_dist_adafactor() From b75ac58272e649495761dcbae3bd6e8cc7a3083c Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 8 Apr 2024 17:37:26 +0800 Subject: [PATCH 02/35] [feature] update optimize readme; --- colossalai/nn/optimizer/README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index b846f1cf55cb..07c95143c74c 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -89,9 +89,7 @@ A series of optimizers have been optimized and integrated. ### Distributed Adafactor -Distributed Adafactor supports tensor parallelism and ZerO optimization. Here is a brief flowchart of how adafactor implements the tensor parallel: - -[[Tensor Parallel Strategy in Distributed Adafactor]](adafactor_strategy.png) +Distributed Adafactor supports tensor parallelism and ZerO optimization. 
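+
+A minimal usage sketch, mirroring the unit tests: the process groups, the ZeRO wrapper and the
+`shard_to_param` mapping are assumed to come from your own training setup (e.g. `LowLevelZeroOptimizer`).
+
+```python
+from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
+
+dist_optim = DistributedAdaFactor(model.parameters())
+# tp_group / dp_group: tensor-parallel and data-parallel process groups
+# shard_to_param: {id(master shard): working param}, e.g. zero_optim._param_store.master_to_working_param
+dist_optim.setup_distributed(
+    tensor_parallel_group=tp_group,
+    data_parallel_group=dp_group,
+    shard_to_param=shard_to_param,
+    use_zero=True,
+)
+```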
### Performance | Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | From d5f72fef4f2a0d4a2a4fbc34fbd09145b574bf7e Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 8 Apr 2024 19:28:11 +0800 Subject: [PATCH 03/35] [fix] fix testcase; --- .../nn/optimizer/distributed_adafactor.py | 17 +++++++--- tests/test_optimizer/test_dist_adafactor.py | 33 ++++++++++++------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 50792ba9af0d..89ed71d4d799 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -53,6 +53,7 @@ def __init__( self.factored = None # bool self.use_first_moment = None # bool self.use_zero = True + self.is_dist = {} def setup_distributed( self, @@ -78,8 +79,9 @@ def setup_distributed( self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_group) if self.data_parallel_group is not None: self.data_parallel_size = dist.get_world_size(self.data_parallel_group) - self.shard_to_param = shard_to_param self.use_zero = use_zero + + self.shard_to_param = shard_to_param if shard_to_param is not None else {} @staticmethod def _get_lr(param_group, param_state): @@ -227,7 +229,10 @@ def step(self, closure=None): update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) - update = update_reshape.view(-1) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape # ============================== # Last Dim is R, First Dim is S{} means split dim 0 ---> # Row Parallel ---> sq_col need Do (row) Reduce @@ -250,7 +255,10 @@ def step(self, closure=None): update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) update_reshape.mul_(grad_reshape) # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) - update = update_reshape.view(-1) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) @@ -267,7 +275,8 @@ def step(self, closure=None): if group["weight_decay"] != 0: p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - p_data_fp32.add_(-update).flatten() + p_data_fp32.add_(-update) + if p.dtype in {torch.float16, torch.bfloat16}: p.copy_(p_data_fp32) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index f675f42301a8..253b4568f48c 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist from torch import nn +from torch.testing import assert_close import colossalai from colossalai.booster import Booster @@ -27,9 +28,10 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer +from tests.test_optimizer._utils import run_bert_test -HEIGHT = 4096 -WIDTH = 4096 +HEIGHT = 4 +WIDTH = 4 _TP_SPEC = DimSpec([0]) @@ -128,6 +130,11 @@ def set_dist_grad( p.data = orig_p +def set_master_param_to_shard_param(master_param_list) -> dict: + master_param_to_shard_param ={id(p):p for p in 
master_param_list} + return master_param_to_shard_param + + class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() @@ -350,8 +357,8 @@ def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype): # print(f"correct {correctness}") -@parameterize("dtype", [torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False @@ -406,6 +413,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): use_zero=use_zero, ) else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, @@ -454,15 +462,13 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) - print(f"Curr Param correct {correctness}") + # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") + # print(f"Curr Param correct {correctness}") # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") - - - -@parameterize("dtype", [torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size local_rank = dist.get_rank() @@ -516,10 +522,11 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int use_zero=use_zero, ) else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, - shard_to_param=shard_to_param, + shard_to_param={}, use_zero=use_zero, ) @@ -582,12 +589,14 @@ def run_dist(rank, world_size, port): # exam_dist_adafactor_fwd_bwd() exam_dist_adafactor_zero() # exam_dist_adafactor_booster() + # run_bert_test(optim_class=Adafactor, sharded_optim_class=DistributedAdaFactor) + @pytest.mark.dist @rerun_if_address_is_in_use() def test_dist_adafactor(): - spawn(run_dist, nprocs=8) + spawn(run_dist, nprocs=4) if __name__ == "__main__": From 020ed547c4c846739bc52b7e9a970ab30204af82 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 9 Apr 2024 17:13:20 +0800 Subject: [PATCH 04/35] [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel); --- colossalai/nn/optimizer/adafactor.py | 1 + .../nn/optimizer/distributed_adafactor.py | 111 ++++--- tests/test_optimizer/test_dist_adafactor.py | 312 ++++++++++++------ 3 files changed, 279 insertions(+), 145 deletions(-) diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index c71676bf5459..0cedbb2512be 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -36,6 
+36,7 @@ def __init__( relative_step=True, warmup_init=False, ): + lr=None if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 89ed71d4d799..b1d313678ecb 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -3,16 +3,17 @@ import torch import torch.distributed as dist -from torch.optim import Optimizer +# from torch.optim import Optimizer +from colossalai.interface.optimizer import DistributedOptim -from colossalai.shardformer.layer._operation import _gather +from colossalai.shardformer.layer._operation import _gather, _split from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor # DistributedAdaFactor (with Tensor parallel and Zero stage 2) __all__ = ["DistributedAdaFactor"] -class DistributedAdaFactor(Optimizer): +class DistributedAdaFactor(DistributedOptim): def __init__( self, params, @@ -26,6 +27,7 @@ def __init__( relative_step=True, warmup_init=False, ): + lr=None if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: @@ -42,7 +44,6 @@ def __init__( "relative_step": relative_step, "warmup_init": warmup_init, } - super().__init__(params, defaults) self.tensor_parallel_size = 1 self.tensor_parallel_group = None self.data_parallel_size = 1 @@ -53,13 +54,14 @@ def __init__( self.factored = None # bool self.use_first_moment = None # bool self.use_zero = True - self.is_dist = {} + super().__init__(params, defaults) + def setup_distributed( self, tensor_parallel_group: dist.ProcessGroup = None, data_parallel_group: dist.ProcessGroup = None, - shard_to_param: Dict = None, + shard_to_param: Dict = {}, use_zero: bool = True, ) -> None: """ @@ -82,6 +84,7 @@ def setup_distributed( self.use_zero = use_zero self.shard_to_param = shard_to_param if shard_to_param is not None else {} + @staticmethod def _get_lr(param_group, param_state): @@ -161,7 +164,9 @@ def step(self, closure=None): raise RuntimeError("Adafactor does not support sparse gradients.") state = self.state[p] self.grad_shape = grad.shape # 1 dim shape - + + # print(f"self.shard_to_param {self.shard_to_param}") + param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p))) if param_is_dtensor: @@ -184,9 +189,16 @@ def step(self, closure=None): ) # [W/TP] if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel - state["exp_avg_sq_row"] = torch.zeros( + # Row Residual situation + if self.grad_shape[0] % self.data_parallel_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + self.grad_shape[0], device=p.device, dtype=p.dtype + ) # [H/dp/Tp] + else: + state["exp_avg_sq_row"] = torch.zeros( self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp/Tp] + ) # [H/dp/Tp] + state["exp_avg_sq_col"] = torch.zeros( self.grad_shape[1], device=p.device, dtype=p.dtype ) # [W] @@ -202,10 +214,6 @@ def step(self, closure=None): else: state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - p_data_fp32 = p.float() - # if p.dtype in {torch.float16, torch.bfloat16}: - # p_data_fp32 = p_data_fp32.float() - state["step"] += 1 lr = self._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) @@ -228,7 +236,6 @@ def step(self, closure=None): 
exp_avg_sq_row.div_(self.tensor_parallel_size) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) - # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) if self.use_zero: update = update_reshape.view(-1) else: @@ -238,27 +245,54 @@ def step(self, closure=None): # Row Parallel ---> sq_col need Do (row) Reduce # ============================== elif self.shard_spec.sharding_sequence[-1] == "R": - update_reshape = update.view(-1, self.grad_shape[1]) - grad_reshape = grad.view(-1, self.grad_shape[1]) - exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] - exp_avg_sq_col = state["exp_avg_sq_col"] # [W] - exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) - # reduce col - dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) - exp_avg_sq_col.div_(self.tensor_parallel_size) - # gather row - exp_avg_sq_row_gather = _gather( - input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group - ) - sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) - update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) - update_reshape.mul_(grad_reshape) - # update = update_reshape.view(update_reshape.shape[0]*update_reshape.shape[1]) - if self.use_zero: - update = update_reshape.view(-1) + # Row Residual situation + if self.grad_shape[0] % self.data_parallel_size != 0: + # gather update[flatten] along dp group then reshape to [H/tp, W] + update = _gather( + input_=update, dim=-1, process_group=self.data_parallel_group + ) + # view update to origin[tp] shape + update_reshape = update.view(-1, self.grad_shape[1]) + + # gather grad[flatten] along dp group then reshape to [H/tp, W] + grad = _gather( + input_=grad, dim=-1, process_group=self.data_parallel_group + ) + grad_reshape = grad.view(-1, self.grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group) + else: + update = update_reshape else: - update = update_reshape + update_reshape = update.view(-1, self.grad_shape[1]) + grad_reshape = grad.view(-1, self.grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + # gather row + exp_avg_sq_row_gather = _gather( + input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group + ) + sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) + update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = 
update_reshape.view(-1) + else: + update = update_reshape else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) @@ -273,12 +307,9 @@ def step(self, closure=None): update = exp_avg if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - - p_data_fp32.add_(-update) - + p.add_(p, alpha=(-group["weight_decay"] * lr)) + + p.add_(-update) - if p.dtype in {torch.float16, torch.bfloat16}: - p.copy_(p_data_fp32) return loss diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 253b4568f48c..f5424ee1738a 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -9,7 +9,7 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin from colossalai.cluster import ProcessGroupMesh from colossalai.device.device_mesh import DeviceMesh from colossalai.nn.optimizer.adafactor import Adafactor @@ -20,15 +20,24 @@ distribute_tensor, get_layout, get_sharding_spec, + get_device_mesh, is_distributed_tensor, shard_colwise, shard_rowwise, ) +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.tensor.d_tensor.sharding_spec import DimSpec from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer +from tests.kit.model_zoo import model_zoo from tests.test_optimizer._utils import run_bert_test +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_weight, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) HEIGHT = 4 WIDTH = 4 @@ -39,8 +48,8 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc rtol = None atol = None if dtype is torch.float32: - rtol = 1e-05 - atol = 1e-05 + rtol = 5e-04 + atol = 5e-04 elif dtype is torch.float16: rtol = 5e-2 atol = 5e-4 @@ -48,8 +57,8 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc rtol = 4e-3 atol = 4e-3 - return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) - # assert_close(tensor1, tensor2, rtol=rtol, atol=atol) + # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) + assert_close(tensor1, tensor2, rtol=rtol, atol=atol) # setup param groups; (For zero test optim) @@ -161,12 +170,18 @@ def forward(self, x): return x + + @parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 -def exam_dist_adafactor_base(dtype: torch.dtype): - local_rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) +@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + local_rank = dist.get_rank() + use_zero = True if zero_size > 1 else False + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) - tensor_parallel_size = world_size torch.set_default_dtype(dtype) set_seed(42) @@ -176,64 +191,57 @@ def exam_dist_adafactor_base(dtype: torch.dtype): H, W = 4096, 4096 model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight weight, bias = model_col.weight, model_col.bias - device_mesh = DeviceMesh( - torch.Tensor([i for i in 
range(world_size)]), (1, tensor_parallel_size), init_process_group=True - ) - tp_group = device_mesh.get_process_group(axis=1) + # ============================== # Col Parallel # ============================== - weight_col_shard = shard_colwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_col_shard = shard_colwise(weight.clone(), tp_group) weight_col_shard_layout = get_layout(weight_col_shard) # Layout info weight_col_shard_layout.global_shape weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard) # Shard spec weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True)) bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) - col_params_shape = { - id(weight_col_shard_flatten): weight_col_shard_layout.global_shape, - id(bias_col_flatten): bias.shape, - } - col_sharding_spec_dict = {id(weight_col_shard_flatten): weight_col_shard_shard_spec, id(bias_col_flatten): None} # ============================== # Row Parallel # ============================== - weight_row_shard = shard_rowwise(weight.clone(), device_mesh.get_process_group(axis=1)) + weight_row_shard = shard_rowwise(weight.clone(), tp_group) weight_row_shard_layout = get_layout(weight_row_shard) # Layout info weight_row_shard_layout.global_shape weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard) # Shard spec weight_row_shard_flatten = nn.Parameter( weight_row_shard.clone().flatten().requires_grad_(True) ) # flatten input(not dtensor) to optimizer bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) - row_params_shape = { - id(weight_row_shard_flatten): weight_row_shard_layout.global_shape, - id(bias_row_flatten): bias.shape, - } - row_sharding_spec_dict = {id(weight_row_shard_flatten): weight_row_shard_shard_spec, id(bias_row_flatten): None} + + base_param_group = setup_param_groups([weight, bias]) + cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten]) + rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten]) # ============================== # Init Optimizer # ============================== # base - optimizer_base = Adafactor([weight, bias]) - - # col parallel - optimizer_cp = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten]) - optimizer_cp.setup_distributed( + optimizer_base = Adafactor(base_param_group) + cp_dist_optim = DistributedAdaFactor(cp_param_group) + rp_dist_optim = DistributedAdaFactor(rp_param_group) + + shard_to_param_cp = set_master_param_to_shard_param(cp_dist_optim) + cp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, - data_parallel_group=None, - sharding_spec_dict=col_sharding_spec_dict, - param_shape=col_params_shape, + data_parallel_group=dp_group, + shard_to_param=shard_to_param_cp, + use_zero=use_zero, ) - # row parallel - optimizer_rp = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten]) - optimizer_rp.setup_distributed( + + shard_to_param_rp = set_master_param_to_shard_param(rp_dist_optim) + rp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, - data_parallel_group=None, - sharding_spec_dict=row_sharding_spec_dict, - param_shape=row_params_shape, + data_parallel_group=dp_group, + shard_to_param=shard_to_param_rp, + use_zero=use_zero, ) + N_STEPS = 1 for _ in range(N_STEPS): # base step @@ -243,29 +251,29 @@ def exam_dist_adafactor_base(dtype: torch.dtype): optimizer_base.step() # col parallel step - optimizer_cp.zero_grad() + cp_dist_optim.zero_grad() 
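        # The baseline optimizer above has just stepped on the full [H, W] weight with a random gradient.
        # Below, that same gradient is re-sharded with the column-parallel spec and flattened so the
        # distributed optimizer sees an equivalent gradient for its flattened master shard; the bias copy
        # in this test is kept whole on every rank, so its gradient is only flattened.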
weight_col_shard_flatten.grad = ( - distribute_tensor(weight.grad, device_mesh, weight_col_shard_shard_spec).clone().flatten() + distribute_tensor(weight.grad, get_device_mesh(weight_col_shard), weight_col_shard_shard_spec).clone().flatten() ) bias_col_flatten.grad = bias.grad.clone().flatten() - optimizer_cp.step() + cp_dist_optim.step() # row parallel step - optimizer_rp.zero_grad() + rp_dist_optim.zero_grad() weight_row_shard_flatten.grad = ( - distribute_tensor(weight.grad, device_mesh, weight_row_shard_shard_spec).clone().flatten() + distribute_tensor(weight.grad, get_device_mesh(weight_row_shard), weight_row_shard_shard_spec).clone().flatten() ) bias_row_flatten.grad = bias.grad.clone().flatten() - optimizer_rp.step() + rp_dist_optim.step() # gather result weight_col_gather = _gather( - input_=weight_col_shard_flatten.data.view(-1, H // tensor_parallel_size), + input_=weight_col_shard_flatten.data.view(-1, H // tp_size), dim=-1, - process_group=device_mesh.get_process_group(axis=1), + process_group=tp_group, ) # gather weight_row_gather = _gather( - input_=weight_row_shard_flatten.data, dim=-1, process_group=device_mesh.get_process_group(axis=1) + input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group ).view( -1, W ) # gather @@ -278,91 +286,96 @@ def exam_dist_adafactor_base(dtype: torch.dtype): @parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 -def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype): - local_rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - tensor_parallel_size = world_size +@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + clear_layout_converter() + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + torch.set_default_dtype(dtype) set_seed(42) # ============================== # Model Init # ============================== - device_mesh = DeviceMesh( - torch.Tensor([i for i in range(world_size)]), (1, tensor_parallel_size), init_process_group=True - ) base_model = MlpModel().to(local_rank) - tp_model = TPModel( - copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), device_mesh.get_process_group(axis=1) - ).to(local_rank) - tp_group = device_mesh.get_process_group(axis=1) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) base_param_group = setup_param_groups(base_model) - tp_param_group, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + tp_param_group = setup_param_groups(tp_model) # ============================== # Optimizer Init # ============================== base_optim = Adafactor(base_param_group) dist_optim = DistributedAdaFactor(tp_param_group) + + shard_to_param = set_master_param_to_shard_param(tp_param_group) dist_optim.setup_distributed( tensor_parallel_group=tp_group, - data_parallel_group=None, - sharding_spec_dict=tp_shard_spec, - param_shape=tp_param_shape, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, ) - + # ============================== # Correctness Verify # ============================== x = torch.randn(HEIGHT, WIDTH, device=local_rank) - loss_tp = tp_model(x).sum() - loss_tp.backward() - - loss = 
base_model(x).sum() - loss.backward() + out = base_model(x) + out_tp = tp_model(x) - base_optim.zero_grad() - dist_optim.zero_grad() + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() base_optim.step() dist_optim.step() + base_optim.zero_grad() + dist_optim.zero_grad() + for p, tp_p in zip(base_param_group, tp_param_group): - if tp_shard_spec[id(tp_p)] is not None: - if len(tp_shard_spec[id(tp_p)].sharding_sequence) >= 2: - # print(f"device {local_rank} \n tp_p shard spec {tp_shard_spec[id(tp_p)]}\n len {len(tp_shard_spec[id(tp_p)].sharding_sequence)}") - # if tp_p tp_shard_spec is col tp --> view to (-1, H // tensor_parallel_size) then gather - if tp_shard_spec[id(tp_p)].sharding_sequence[0] == "R": - tp_p = _gather( - input_=tp_p.data.view(-1, HEIGHT // tensor_parallel_size), - dim=-1, - process_group=device_mesh.get_process_group(axis=1), - ) # gather - # if tp_p tp_shard_spec is row tp --> gather then view to (-1, H // tensor_parallel_size) - else: - tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)).view( - -1, WIDTH - ) # gather + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather else: - # bias parallel - tp_p = _gather(input_=tp_p.data, dim=-1, process_group=device_mesh.get_process_group(axis=1)) - # print(f"device {local_rank} \n p {p}\n tp_p {tp_p}\n") + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + else: - # compare p and tp no need + # No TP bias pass - # print(f"device {local_rank} \n p {p}\n tp_p {tp_p}\n") - correctness_verify(p.data, tp_p.data, dtype) - # print(f"correct {correctness}") + correctness = correctness_verify(p.data, tp_p.data, dtype) + # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") + # print(f"Curr Param correct {correctness}") + # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") -@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False local_rank = dist.get_rank() + + clear_layout_converter() proc_mesh = ProcessGroupMesh(tp_size, zero_size) tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) @@ -442,6 +455,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): base_optim.zero_grad() dist_optim.zero_grad() + print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): 
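        # Reassemble each TP-sharded parameter before comparing it with the single-device baseline:
        # a 2-D weight whose sharding sequence starts with "R" is column-parallel and is gathered along
        # the last dim, one whose sequence ends with "R" is row-parallel and is gathered along dim 0,
        # and 1-D sharded biases are gathered along their only dim; non-distributed params are compared as-is.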
param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: @@ -462,17 +476,20 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) - # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") + # print(f"Curr Param correct {correctness}") - # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") + # if not correctness: + # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") -@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size local_rank = dist.get_rank() use_zero = True if zero_size > 1 else False + + clear_layout_converter() proc_mesh = ProcessGroupMesh(tp_size, zero_size) tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) @@ -526,7 +543,7 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, - shard_to_param={}, + shard_to_param=shard_to_param, use_zero=use_zero, ) @@ -559,7 +576,7 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int base_optim.zero_grad() dist_optim.zero_grad() - + print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: @@ -575,11 +592,97 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int else: # TP bias tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + else: # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) - print(f"Curr Param correct {correctness}") + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 2, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 4, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "bf16", + }, + { + "tp_size": 1, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "fp16", + }, + { + "tp_size": 2, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "fp16", + }, + { + "tp_size": 4, + "num_microbatches": 4, + "zero_stage": 2, + "precision": "fp16", + }, + { + "tp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + } + ], +) +def exam_bert_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + test_config["use_lazy_init"] = False + test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel + test_config["initial_scale"] = 2**15 # avoid overflow + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + + if name == "transformers_bert": + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, 
test_config, Adafactor, DistributedAdaFactor + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] + + org_optimizer.step() + sharded_optimizer.step() + + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 1e-4 + else: + atol, rtol = 5e-4, 5e-4 + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + clear_layout_converter() + torch.cuda.empty_cache() + def run_dist(rank, world_size, port): @@ -588,9 +691,8 @@ def run_dist(rank, world_size, port): # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() exam_dist_adafactor_zero() - # exam_dist_adafactor_booster() - # run_bert_test(optim_class=Adafactor, sharded_optim_class=DistributedAdaFactor) - + exam_bert_test() + @pytest.mark.dist From ce58adfd276d505cda9b5da50ced0b8a7808b767 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 9 Apr 2024 18:01:59 +0800 Subject: [PATCH 05/35] [feature] Add transformers_bert model zoo in testcase; --- .../nn/optimizer/distributed_adafactor.py | 2 +- tests/test_optimizer/test_dist_adafactor.py | 45 +++++++++---------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index b1d313678ecb..794dda755a16 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -189,7 +189,7 @@ def step(self, closure=None): ) # [W/TP] if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel - # Row Residual situation + # Row indivisible shape situation if self.grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( self.grad_shape[0], device=p.device, dtype=p.dtype diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index f5424ee1738a..4b7cb6c161df 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -653,35 +653,34 @@ def exam_bert_test(test_config): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == "transformers_bert": - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( - model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor - ) + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor + ) - org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster - ) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) - stage_manager = booster.plugin.stage_manager - tp_group = 
booster.plugin.tp_group + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - bert = unwrap_model(org_model, "BertModel", "bert") - sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") - weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] - org_optimizer.step() - sharded_optimizer.step() + org_optimizer.step() + sharded_optimizer.step() - # check weights - if test_config["precision"] == "bf16": - atol, rtol = 5e-4, 1e-4 - else: - atol, rtol = 5e-4, 5e-4 - if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): - check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) - clear_layout_converter() - torch.cuda.empty_cache() + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 1e-4 + else: + atol, rtol = 5e-4, 5e-4 + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + clear_layout_converter() + torch.cuda.empty_cache() From efac2a1854f611453c4390ecfc1f03af7c498f97 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 14:30:11 +0800 Subject: [PATCH 06/35] [feature] add user documentation to docs/source/feature. --- colossalai/nn/optimizer/README.md | 4 +- .../nn/optimizer/distributed_adafactor.py | 1 - .../en/features/distributed_adafactor.md | 138 ++++++++++++++++++ 3 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/features/distributed_adafactor.md diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index 07c95143c74c..14de2e541508 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -89,7 +89,9 @@ A series of optimizers have been optimized and integrated. ### Distributed Adafactor -Distributed Adafactor supports tensor parallelism and ZerO optimization. +Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. 
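The key to the tensor-parallel path is that Adafactor's factored second moment only needs row and column means of the squared gradient, and a sharded rank can complete those means with a single all-reduce over the tensor-parallel group. The snippet below is a minimal single-process illustration of that reduction, not the optimizer's actual implementation; it assumes an even column split and uses plain PyTorch only.

```python
import torch

tp_size = 2
full_grad_sq = torch.rand(4, 8) ** 2          # squared gradient of a [4, 8] weight
shards = full_grad_sq.chunk(tp_size, dim=1)   # column-parallel shards, each [4, 8 // tp_size]

# Each rank can only compute the row mean of its own shard ...
local_row_means = [shard.mean(dim=-1) for shard in shards]

# ... but averaging the per-rank results (what all_reduce followed by div_(tp_size)
# does inside the optimizer) recovers the row mean of the full gradient.
row_mean_from_shards = torch.stack(local_row_means).mean(dim=0)
torch.testing.assert_close(row_mean_from_shards, full_grad_sq.mean(dim=-1))
```

The column statistic stays sharded and needs no tensor-parallel communication in this case; the row-parallel layout is symmetric, with the reduction applied to the column statistic instead.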
+ +### Distributed Adafactor API ### Performance | Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 794dda755a16..c73ca46d8934 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -253,7 +253,6 @@ def step(self, closure=None): ) # view update to origin[tp] shape update_reshape = update.view(-1, self.grad_shape[1]) - # gather grad[flatten] along dp group then reshape to [H/tp, W] grad = _gather( input_=grad, dim=-1, process_group=self.data_parallel_group diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md new file mode 100644 index 000000000000..d15f03bece84 --- /dev/null +++ b/docs/source/en/features/distributed_adafactor.md @@ -0,0 +1,138 @@ +# Distributed Adafactor + +Author: + +**Related Paper** +- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) + +## Introduction + +Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. + + +## Performance + +| Parallel strategy | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | +| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | +| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | +| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | +| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | +| DistAdaFactor(Row 
Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | +| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | +| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | + + +## Hands-On Practice +We now demonstrate how to use Distributed Adafactor. +### step 1. Import libraries + +```python +import torch +from torch import nn +import torch.distributed as dist + +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor + +``` + +### step 2. Initialize Distributed Environment and Parallism Group +We then need to initialize distributed environment. For demo purpose, we uses `colossalai.launch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) +for other initialization methods. We use `ProcessGroupMesh` to create tensor parallelism group and data parallelism group. + +```python +# Distributed Enviroment +config = {} +colossalai.launch(config=config, rank=rank, world_size=world_size,host="localhost", port=port, backend="nccl") + +# Parallism Group +tp_size, zero_size = 2, 2 +use_zero = True if zero_size > 1 else False +proc_mesh = ProcessGroupMesh(tp_size, zero_size) +tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) +``` + +### step 3. Initialize Module +Build our model. We created an MLP using two Linear Layer. + +```python +# Init a Tensor Paralleled Module +class TPModel(nn.Module): + def __init__(self, linear1, linear2, tp_group=None): + super().__init__() + self.linear1 = Linear1D_Col.from_native_module( + linear1, process_group=tp_group, gather_output=False, overlap=True + ) + self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x +HEIGHT = 4096 +WIDTH = 4096 +tp_model = TPModel(copy.deepcopy(nn.Linear(HEIGHT, WIDTH)), copy.deepcopy(nn.Linear(HEIGHT, WIDTH)), tp_group).to(local_rank) + +# Get Module parameter +tp_param_group = [p for n, p in tp_model.named_parameters()] +``` + +### step 4. Initialize Optimizer +Then, We initialise the optimiser using the model parameter. Then, we set the distributed information for optimiser. + +```python +# Init a Optimizer +dist_optim = DistributedAdaFactor(tp_param_group) +shard_to_param = {id(p):p for p in tp_param_group} + +# Setup distributed information for Optimizer +dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, +) +``` + +### step 5. 
Perform a forward and backward propagation for model and step the gradient + +```python +# Random initialise dataset +x = torch.randn(HEIGHT, WIDTH, device=local_rank) + +# Fwd and Bwd +out_tp = tp_model(x) +if zero_size > 1: + dist_optim.backward(out_tp.sum()) +else: + out_tp.sum().backward() + +# perform step for param and grad +dist_optim.step() +dist_optim.zero_grad() +``` + From 40a55287f1aee0944ac80d89557d30a9716a9cb1 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 14:58:50 +0800 Subject: [PATCH 07/35] [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam; --- colossalai/nn/optimizer/README.md | 26 ++++++++++++++++++- .../nn/optimizer/distributed_adafactor.py | 9 +++---- .../en/features/distributed_adafactor.md | 2 +- tests/test_optimizer/test_dist_adafactor.py | 8 +++++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index 14de2e541508..fdf8f128cf0a 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -91,7 +91,31 @@ A series of optimizers have been optimized and integrated. Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. -### Distributed Adafactor API +### API Reference + +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} + +### Sample: Init with booster + +```python +# ============================== +# Model Init +# ============================== +tp_model = TPModel() + +# ============================== +# Optimizer Init +# ============================== +dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()]) + +# ============================== +# Booster Init +# ============================== +plugin = TorchDDPPlugin() +booster = Booster(plugin=plugin) +criterion = lambda x: x.mean() +tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) +``` ### Performance | Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index c73ca46d8934..8fa90c6cb66d 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -27,7 +27,7 @@ def __init__( relative_step=True, warmup_init=False, ): - lr=None + lr = None if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: @@ -162,17 +162,14 @@ def step(self, closure=None): grad = p.grad if grad.is_sparse: raise RuntimeError("Adafactor does not support sparse gradients.") + state = self.state[p] self.grad_shape = grad.shape # 1 dim shape - - # print(f"self.shard_to_param {self.shard_to_param}") - param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p))) - if param_is_dtensor: self.grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) - self.factored, 
self.use_first_moment = self._get_options(group, self.grad_shape) + if len(state) == 0: state["step"] = 0 if self.use_first_moment: diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index d15f03bece84..fbddf7d7d2be 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -135,4 +135,4 @@ else: dist_optim.step() dist_optim.zero_grad() ``` - + diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 4b7cb6c161df..0813c0594b7e 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -31,7 +31,7 @@ from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import run_bert_test +from tests.test_optimizer._utils import run_bert_test, check_optim_states from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, check_weight, @@ -39,6 +39,7 @@ unwrap_model, ) + HEIGHT = 4 WIDTH = 4 _TP_SPEC = DimSpec([0]) @@ -679,6 +680,11 @@ def exam_bert_test(test_config): atol, rtol = 5e-4, 5e-4 if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + + # check optim states + check_optim_states(org_optimizer, sharded_optimizer.optim) + + clear_layout_converter() torch.cuda.empty_cache() From 1c9bb930b59d0f363a1b749e592b2bed5d1c1d51 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 15:22:18 +0800 Subject: [PATCH 08/35] [feature] modify user documentation; --- colossalai/nn/optimizer/README.md | 68 ------------- .../en/features/distributed_adafactor.md | 95 ++++++++++++------- 2 files changed, 61 insertions(+), 102 deletions(-) diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index fdf8f128cf0a..d3f8badc7313 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -81,71 +81,3 @@ If you wish to add an optimizer for a specific application, please follow the st If your PR is accepted, we may invite you to put up a tutorial or blog in [ColossalAI Documentation](https://colossalai.org/). - - -## Optimizer - -A series of optimizers have been optimized and integrated. - -### Distributed Adafactor - -Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. 
- -### API Reference - -{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} - -### Sample: Init with booster - -```python -# ============================== -# Model Init -# ============================== -tp_model = TPModel() - -# ============================== -# Optimizer Init -# ============================== -dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()]) - -# ============================== -# Booster Init -# ============================== -plugin = TorchDDPPlugin() -booster = Booster(plugin=plugin) -criterion = lambda x: x.mean() -tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) -``` - -### Performance -| Version | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | -| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | -| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | -| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | -| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | -| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | -| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 
| diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index fbddf7d7d2be..b607e5d0aa9e 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -9,41 +9,9 @@ Author: Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. +### API Reference -## Performance - -| Parallel strategy | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | -| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | -| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | -| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | -| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | -| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | -| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | 
[4096 , 4096] | 0.54 | 1.31 | 27.28 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | - +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} ## Hands-On Practice We now demonstrate how to use Distributed Adafactor. @@ -135,4 +103,63 @@ else: dist_optim.step() dist_optim.zero_grad() ``` + +## Run with booster +We highly recommend users to use booster, a simple, easy to use, and efficient interface. The Code Below is the Distributed Adafactor launched with booster. + +```python +# ============================== +# Model Init +# ============================== +tp_model = TPModel() + +# ============================== +# Optimizer Init +# ============================== +dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()]) + +# ============================== +# Booster Init +# ============================== +plugin = TorchDDPPlugin() +booster = Booster(plugin=plugin) +criterion = lambda x: x.mean() +tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) +``` + +## Performance + +| Parallel strategy | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | +| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | +| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | +| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | +| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | +| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | +| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | +| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | +| 
DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | +| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | +| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | +| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | +| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | + + From 1039f3426aa3eaf8fc172c4cee35a3510e345cc7 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 15:25:47 +0800 Subject: [PATCH 09/35] [fix] fix readme format issue; --- docs/source/en/features/distributed_adafactor.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index b607e5d0aa9e..3ba609283def 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -140,7 +140,6 @@ tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, crit | AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | | DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | | DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | | AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | | DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | | DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | @@ -150,7 +149,6 @@ tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, crit | AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | | DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | | DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | | AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | | DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | | DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | From 2ffca49595b4b6eb75c8260f7b6da63ce2758211 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 10 Apr 2024 20:05:07 +0800 Subject: [PATCH 10/35] [fix] add zero=0 in testcase; cached augment in dict; --- colossalai/nn/optimizer/adafactor.py | 13 +-- .../nn/optimizer/distributed_adafactor.py | 85 +++++++++------- tests/test_optimizer/test_dist_adafactor.py | 96 +++++++++++-------- 3 files changed, 109 insertions(+), 85 deletions(-) diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index 0cedbb2512be..0a9230d7c585 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -120,8 +120,6 @@ def step(self, closure=None): # grad shape is same as weigh / bias """ grad = p.grad - if grad.dtype in {torch.float16, torch.bfloat16}: - grad = grad.float() if grad.is_sparse: raise RuntimeError("Adafactor does not support sparse gradients.") @@ -168,10 +166,6 @@ def step(self, closure=None): else: state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - p_data_fp32 = 
p - if p.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() - state["step"] += 1 # state["RMS"] = self._rms(p_data_fp32) lr = self._get_lr(group, state) @@ -201,9 +195,8 @@ def step(self, closure=None): update = exp_avg if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - p_data_fp32.add_(-update) - if p.dtype in {torch.float16, torch.bfloat16}: - p.copy_(p_data_fp32) + p.add_(p, alpha=(-group["weight_decay"] * lr)) + p.add_(-update) + return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 8fa90c6cb66d..ba9caad29739 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -49,11 +49,13 @@ def __init__( self.data_parallel_size = 1 self.data_parallel_group = None self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor} - self.shard_spec = None - self.grad_shape = None - self.factored = None # bool - self.use_first_moment = None # bool self.use_zero = True + + self.param_is_dtensor_dict = {} # {id(p): True/False} + self.grad_shape_dict = {} # {id(p): master param shape} + self.factored_dict = {} # {id(p): True/False} + self.use_first_moment_dict = {} # {id(p): True/False} + self.shard_spec_dict = {} # {id(p): ShardSpec} super().__init__(params, defaults) @@ -84,8 +86,21 @@ def setup_distributed( self.use_zero = use_zero self.shard_to_param = shard_to_param if shard_to_param is not None else {} - - + # grad is None, cause we dont setup now + for group in self.param_groups: + for p in group["params"]: + self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p))) + if self.param_is_dtensor_dict[id(p)]: + self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape + else: + self.grad_shape_dict[id(p)] = p.shape + self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)]) + # if self.factored_dict[id(p)]: + if self.param_is_dtensor_dict[id(p)]: + self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) + else: + self.shard_spec_dict[id(p)] = None + @staticmethod def _get_lr(param_group, param_state): rel_step_sz = param_group["lr"] @@ -164,48 +179,47 @@ def step(self, closure=None): raise RuntimeError("Adafactor does not support sparse gradients.") state = self.state[p] - self.grad_shape = grad.shape # 1 dim shape - param_is_dtensor = is_distributed_tensor(self.shard_to_param.get(id(p))) + grad_shape = self.grad_shape_dict[id(p)] + param_is_dtensor = self.param_is_dtensor_dict[id(p)] if param_is_dtensor: - self.grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) - self.factored, self.use_first_moment = self._get_options(group, self.grad_shape) - + grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) + factored, use_first_moment = self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] + shard_spec = self.shard_spec_dict[id(p)] if len(state) == 0: state["step"] = 0 - if self.use_first_moment: + if use_first_moment: # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p) - if self.factored: - self.shard_spec = get_sharding_spec(self.shard_to_param.get(id(p))) - if self.shard_spec.sharding_sequence[0] == "R": # Col Parallel + if factored and param_is_dtensor: + if shard_spec.sharding_sequence[0] == "R": # Col Parallel state["exp_avg_sq_row"] = torch.zeros( - self.grad_shape[0] // 
self.data_parallel_size, device=p.device, dtype=p.dtype + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype ) # [H/dp] state["exp_avg_sq_col"] = torch.zeros( - self.grad_shape[1], device=p.device, dtype=p.dtype + grad_shape[1], device=p.device, dtype=p.dtype ) # [W/TP] - if self.shard_spec.sharding_sequence[-1] == "R": # Row Parallel + if shard_spec.sharding_sequence[-1] == "R": # Row Parallel # Row indivisible shape situation - if self.grad_shape[0] % self.data_parallel_size != 0: + if grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( - self.grad_shape[0], device=p.device, dtype=p.dtype + grad_shape[0], device=p.device, dtype=p.dtype ) # [H/dp/Tp] else: state["exp_avg_sq_row"] = torch.zeros( - self.grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype ) # [H/dp/Tp] state["exp_avg_sq_col"] = torch.zeros( - self.grad_shape[1], device=p.device, dtype=p.dtype + grad_shape[1], device=p.device, dtype=p.dtype ) # [W] else: state["exp_avg_sq"] = torch.zeros_like(p) state["RMS"] = 0 else: - if self.use_first_moment: + if use_first_moment: state["exp_avg"] = state["exp_avg"].to(grad) - if self.factored: + if factored: state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) else: @@ -216,15 +230,14 @@ def step(self, closure=None): beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) update = (grad**2) + group["eps"][0] - if self.factored: + if factored and param_is_dtensor: # ============================== # First Dim is R, Last Dim is S{} means split dim -1 ---> # Coloum Parallel ---> sq_row need Do (col) Reduce # ============================== - self.shard_spec = get_sharding_spec(self.shard_to_param.get(id(p))) - if self.shard_spec.sharding_sequence[0] == "R": - update_reshape = update.view(-1, self.grad_shape[1]) - grad_reshape = grad.view(-1, self.grad_shape[1]) + if shard_spec.sharding_sequence[0] == "R": + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -241,20 +254,20 @@ def step(self, closure=None): # Last Dim is R, First Dim is S{} means split dim 0 ---> # Row Parallel ---> sq_col need Do (row) Reduce # ============================== - elif self.shard_spec.sharding_sequence[-1] == "R": + elif shard_spec.sharding_sequence[-1] == "R": # Row Residual situation - if self.grad_shape[0] % self.data_parallel_size != 0: + if grad_shape[0] % self.data_parallel_size != 0: # gather update[flatten] along dp group then reshape to [H/tp, W] update = _gather( input_=update, dim=-1, process_group=self.data_parallel_group ) # view update to origin[tp] shape - update_reshape = update.view(-1, self.grad_shape[1]) + update_reshape = update.view(-1, grad_shape[1]) # gather grad[flatten] along dp group then reshape to [H/tp, W] grad = _gather( input_=grad, dim=-1, process_group=self.data_parallel_group ) - grad_reshape = grad.view(-1, self.grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -269,8 +282,8 @@ def step(self, closure=None): else: update = update_reshape else: - 
update_reshape = update.view(-1, self.grad_shape[1]) - grad_reshape = grad.view(-1, self.grad_shape[1]) + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -297,7 +310,7 @@ def step(self, closure=None): # (Line No.8) RMS # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) update.mul_(lr) - if self.use_first_moment: + if use_first_moment: exp_avg = state["exp_avg"] exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) update = exp_avg diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 0813c0594b7e..218eebd0d4da 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -38,6 +38,7 @@ run_forward_backward_with_hybrid_plugin, unwrap_model, ) +from colossalai.shardformer.layer.utils import Randomizer HEIGHT = 4 @@ -370,7 +371,7 @@ def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("tp_zero_size", [(2, 2), (4, 1),(1, 4)]) # (2, 2), (4, 1),(1, 4), def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False @@ -456,7 +457,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): base_optim.zero_grad() dist_optim.zero_grad() - print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") + # print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: @@ -478,10 +479,6 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): pass correctness = correctness_verify(p.data, tp_p.data, dtype) - # print(f"Curr Param correct {correctness}") - # if not correctness: - # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") - @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) @@ -643,7 +640,14 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int "num_microbatches": 4, "zero_stage": 1, "precision": "bf16", - } + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 4, + "zero_stage": 0, + "precision": "bf16", + }, ], ) def exam_bert_test(test_config): @@ -651,42 +655,56 @@ def exam_bert_test(test_config): test_config["use_lazy_init"] = False test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel test_config["initial_scale"] = 2**15 # avoid overflow + model_list = [ + "transformers_bert" + "transformers_bert_for_pretraining" + "transformers_bert_lm_head_model" + "transformers_bert_for_masked_lm" + "transformers_bert_for_sequence_classification" + # "transformers_bert_for_token_classification" + "transformers_bert_for_next_sentence" + "transformers_bert_for_mcq" + "transformers_bert_for_question_answering" + ] for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, 
_) in sub_model_zoo.items(): - - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( - model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor - ) - - org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster - ) - - - stage_manager = booster.plugin.stage_manager - tp_group = booster.plugin.tp_group - - bert = unwrap_model(org_model, "BertModel", "bert") - sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") - weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] - - org_optimizer.step() - sharded_optimizer.step() - - # check weights - if test_config["precision"] == "bf16": - atol, rtol = 5e-4, 1e-4 - else: - atol, rtol = 5e-4, 5e-4 - if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): - check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) - - # check optim states - check_optim_states(org_optimizer, sharded_optimizer.optim) - + + if name in model_list: + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] + + org_optimizer.step() + sharded_optimizer.step() + + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 5e-4 + else: + atol, rtol = 5e-4, 5e-4 + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + print(f"{name} check pass") + # check optim states + check_optim_states(org_optimizer, sharded_optimizer.optim) + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() + @@ -695,7 +713,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() - exam_dist_adafactor_zero() + # exam_dist_adafactor_zero() exam_bert_test() From 0fd62a022d380d55d757a5ef7cc69edfd4319224 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 11:13:43 +0800 Subject: [PATCH 11/35] [fix] fix percision issue; --- tests/test_optimizer/test_dist_adafactor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 218eebd0d4da..a98fca5c47d2 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -654,14 +654,14 @@ def exam_bert_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") test_config["use_lazy_init"] = False test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel - test_config["initial_scale"] = 2**15 # avoid overflow + 
test_config["initial_scale"] = 2**10 # avoid overflow model_list = [ "transformers_bert" "transformers_bert_for_pretraining" "transformers_bert_lm_head_model" "transformers_bert_for_masked_lm" "transformers_bert_for_sequence_classification" - # "transformers_bert_for_token_classification" + "transformers_bert_for_token_classification" "transformers_bert_for_next_sentence" "transformers_bert_for_mcq" "transformers_bert_for_question_answering" @@ -713,7 +713,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() - # exam_dist_adafactor_zero() + exam_dist_adafactor_zero() exam_bert_test() From 28c3a409df4cadafb347f2f50d839cbaf45fdc43 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 15:34:30 +0800 Subject: [PATCH 12/35] [feature] add distributed rms; --- colossalai/nn/optimizer/adafactor.py | 3 +- .../nn/optimizer/distributed_adafactor.py | 20 ++++++- tests/test_optimizer/test_dist_adafactor.py | 60 ++++++------------- 3 files changed, 36 insertions(+), 47 deletions(-) diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index 0a9230d7c585..3b7d998c9f28 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -186,7 +186,7 @@ def step(self, closure=None): exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) # RMS - # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) update.mul_(lr) if use_first_moment: @@ -198,5 +198,4 @@ def step(self, closure=None): p.add_(p, alpha=(-group["weight_decay"] * lr)) p.add_(-update) - return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index ba9caad29739..d8d28b89b152 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -306,9 +306,23 @@ def step(self, closure=None): exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) - + # (Line No.8) RMS - # update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + # perform a sum on each device + update_sum = update.pow(2).sum() + num_of_element = update.numel() + # reduce sum on tp group if exist + if self.tensor_parallel_size > 1 and param_is_dtensor: + dist.all_reduce(update_sum, group=self.tensor_parallel_group) + num_of_element = num_of_element * self.tensor_parallel_size + # reduce sum on dp group if exist + if self.data_parallel_size > 1 and param_is_dtensor: + dist.all_reduce(update_sum, group=self.data_parallel_group) + num_of_element = num_of_element * self.data_parallel_size + # div num of element + rms = (update_sum / num_of_element).sqrt() + update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) if use_first_moment: exp_avg = state["exp_avg"] @@ -319,6 +333,6 @@ def step(self, closure=None): p.add_(p, alpha=(-group["weight_decay"] * lr)) p.add_(-update) - + return loss diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index a98fca5c47d2..c950047e9806 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -45,7 +45,6 @@ WIDTH = 4 _TP_SPEC = 
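The distributed-RMS hunk above replaces the previously commented-out clipping with a sharded version: every rank contributes the squared sum of its local update, so the clipping threshold sees the RMS of the full tensor rather than of a single shard. Below is a stripped-down sketch of the same idea using only `torch.distributed`; the helper name and signature are mine, not part of the patch, and the patch itself multiplies the local element count by the group size instead of reducing it, which is equivalent when all shards are equal-sized:

```python
import torch
import torch.distributed as dist

def sharded_rms(local_shard: torch.Tensor, process_groups=()) -> torch.Tensor:
    # Local contribution of this rank.
    sq_sum = local_shard.pow(2).sum()
    numel = torch.tensor(float(local_shard.numel()), device=local_shard.device)
    # Reduce over every group the parameter is sharded across (tp, and dp when ZeRO is on).
    for group in process_groups:
        dist.all_reduce(sq_sum, group=group)
        dist.all_reduce(numel, group=group)
    return (sq_sum / numel).sqrt()
```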
DimSpec([0]) - def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32): rtol = None atol = None @@ -62,7 +61,6 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) assert_close(tensor1, tensor2, rtol=rtol, atol=atol) - # setup param groups; (For zero test optim) def setup_param_groups_zero(model: nn.Module) -> list: no_decay = ["bias", "LayerNorm.weight"] @@ -78,13 +76,11 @@ def setup_param_groups_zero(model: nn.Module) -> list: ] return optimizer_grouped_parameters - # setup param groups; (For base optim) def setup_param_groups(model: nn.Module) -> list: optimizer_grouped_parameters = [p for n, p in model.named_parameters()] return optimizer_grouped_parameters - # setup flatten param groups, sharding spec and shape; (For dist optim) def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: flatten_optimizer_grouped_parameters = [] @@ -100,8 +96,6 @@ def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: else: sharding_spec[id(flatten_p)] = None param_shape[id(flatten_p)] = p.shape - # print(f"sharding_spec {sharding_spec}") - # print(f"param_shape {param_shape}") return flatten_optimizer_grouped_parameters, sharding_spec, param_shape @@ -140,12 +134,10 @@ def set_dist_grad( p.grad = p.data p.data = orig_p - def set_master_param_to_shard_param(master_param_list) -> dict: master_param_to_shard_param ={id(p):p for p in master_param_list} return master_param_to_shard_param - class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() @@ -157,7 +149,6 @@ def forward(self, x): x = self.linear2(x) return x - class TPModel(nn.Module): def __init__(self, linear1, linear2, tp_group=None): super().__init__() @@ -171,9 +162,6 @@ def forward(self, x): x = self.linear2(x) return x - - - @parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -190,7 +178,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # ============================== # Base Case # ============================== - H, W = 4096, 4096 + H, W = HEIGHT, WIDTH model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight weight, bias = model_col.weight, model_col.bias @@ -286,7 +274,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): print(f"col corrness {col_correct} row correct {row_correct}") - @parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -349,7 +336,6 @@ def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: shard_spec = get_sharding_spec(tp_p) - # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") if len(shard_spec.sharding_sequence) >= 2: # Col Parallel if shard_spec.sharding_sequence[0] == "R": @@ -365,13 +351,9 @@ def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) - # print(f"{correctness}\n p.data {p.data}\n tp_p.data{tp_p.data}\n") - # 
print(f"Curr Param correct {correctness}") - # print(f"device {local_rank} base_optim state dict {base_optim.optim.state_dict()['state'].items()} \n dist_optim state dict {dist_optim.optim.state_dict()['state'].items()} \n") - @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(2, 2), (4, 1),(1, 4)]) # (2, 2), (4, 1),(1, 4), +@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (2, 2), (4, 1), (1, 4) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False @@ -457,12 +439,10 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): base_optim.zero_grad() dist_optim.zero_grad() - # print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: shard_spec = get_sharding_spec(tp_p) - # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") if len(shard_spec.sharding_sequence) >= 2: # Col Parallel if shard_spec.sharding_sequence[0] == "R": @@ -478,8 +458,10 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # No TP bias pass correctness = correctness_verify(p.data, tp_p.data, dtype) + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() - @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -574,12 +556,10 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int base_optim.zero_grad() dist_optim.zero_grad() - print(f"data type {dtype},tp size {tp_size}, dp size {zero_size}\n") for p, tp_p in zip(base_param_group, tp_param_group): param_is_distributed = is_distributed_tensor(tp_p) if param_is_distributed: shard_spec = get_sharding_spec(tp_p) - # print(f"device {local_rank} shard spec{shard_spec} len {len(shard_spec.sharding_sequence)}\n") if len(shard_spec.sharding_sequence) >= 2: # Col Parallel if shard_spec.sharding_sequence[0] == "R": @@ -656,20 +636,21 @@ def exam_bert_test(test_config): test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel test_config["initial_scale"] = 2**10 # avoid overflow model_list = [ - "transformers_bert" - "transformers_bert_for_pretraining" - "transformers_bert_lm_head_model" - "transformers_bert_for_masked_lm" - "transformers_bert_for_sequence_classification" - "transformers_bert_for_token_classification" - "transformers_bert_for_next_sentence" - "transformers_bert_for_mcq" - "transformers_bert_for_question_answering" + "transformers_bert", + "transformers_bert_for_pretraining", + "transformers_bert_lm_head_model", + "transformers_bert_for_masked_lm", + "transformers_bert_for_sequence_classification", + "transformers_bert_for_token_classification", + "transformers_bert_for_next_sentence", + "transformers_bert_for_mcq", + "transformers_bert_for_question_answering", ] for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - + print(f"model name {name} {name in model_list}") if name in model_list: + print(f"{name} check start") org_model, org_optimizer, sharded_model, sharded_optimizer, 
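All of the correctness loops in this test file follow the same pattern: inspect the sharding spec, gather the tensor-parallel shard back to its full shape, then compare elementwise against the single-device baseline. A stand-alone sketch of that gather using plain `torch.distributed` (the tests themselves rely on ColossalAI's `_gather` helper) might look like:

```python
import torch
import torch.distributed as dist

def gather_tp_shard(tp_shard: torch.Tensor, dim: int, tp_group=None) -> torch.Tensor:
    # All-gather the per-rank shards and concatenate them along the sharded dimension,
    # so the result can be compared elementwise with the unsharded parameter.
    world_size = dist.get_world_size(tp_group)
    pieces = [torch.empty_like(tp_shard) for _ in range(world_size)]
    dist.all_gather(pieces, tp_shard.contiguous(), group=tp_group)
    return torch.cat(pieces, dim=dim)

# dim=-1 for a column-parallel weight (or a sharded bias), dim=0 for a row-parallel weight.
```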
criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor ) @@ -696,28 +677,23 @@ def exam_bert_test(test_config): atol, rtol = 5e-4, 5e-4 if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) - print(f"{name} check pass") # check optim states check_optim_states(org_optimizer, sharded_optimizer.optim) + print(f"{name} check pass") clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() - - - def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() - exam_dist_adafactor_zero() + # exam_dist_adafactor_zero() exam_bert_test() - - @pytest.mark.dist @rerun_if_address_is_in_use() def test_dist_adafactor(): From a9c5bf7edb6595dc108d8d4bed9761d270b0109b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 15:47:05 +0800 Subject: [PATCH 13/35] [feature] remove useless comment in testcase; --- tests/test_optimizer/test_dist_adafactor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index c950047e9806..764087635e2f 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -691,7 +691,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # exam_dist_adafactor_base() # exam_dist_adafactor_fwd_bwd() - # exam_dist_adafactor_zero() + exam_dist_adafactor_zero() exam_bert_test() @pytest.mark.dist From 150ac1975f7e472b661373abd9b13036e4612514 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 17:13:28 +0800 Subject: [PATCH 14/35] [fix] Remove useless test; open zero test; remove fp16 test in bert exam; --- .../nn/optimizer/distributed_adafactor.py | 1 - tests/test_optimizer/test_dist_adafactor.py | 233 +----------------- 2 files changed, 11 insertions(+), 223 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index d8d28b89b152..fb4f09131227 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -95,7 +95,6 @@ def setup_distributed( else: self.grad_shape_dict[id(p)] = p.shape self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)]) - # if self.factored_dict[id(p)]: if self.param_is_dtensor_dict[id(p)]: self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) else: diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 764087635e2f..17da80521cbd 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -162,8 +162,8 @@ def forward(self, x): x = self.linear2(x) return x -@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) +@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(4, 1)]) def 
exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size local_rank = dist.get_rank() @@ -202,20 +202,20 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): ) # flatten input(not dtensor) to optimizer bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) - base_param_group = setup_param_groups([weight, bias]) - cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten]) - rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten]) + # base_param_group = setup_param_groups([weight, bias]) + # cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten]) + # rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten]) # ============================== # Init Optimizer # ============================== # base - optimizer_base = Adafactor(base_param_group) - cp_dist_optim = DistributedAdaFactor(cp_param_group) - rp_dist_optim = DistributedAdaFactor(rp_param_group) + optimizer_base = Adafactor([weight, bias]) + cp_dist_optim = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten]) + rp_dist_optim = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten]) - shard_to_param_cp = set_master_param_to_shard_param(cp_dist_optim) + shard_to_param_cp = set_master_param_to_shard_param([weight_col_shard_flatten, bias_col_flatten]) cp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, @@ -223,7 +223,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): use_zero=use_zero, ) - shard_to_param_rp = set_master_param_to_shard_param(rp_dist_optim) + shard_to_param_rp = set_master_param_to_shard_param([weight_row_shard_flatten, bias_row_flatten]) rp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, @@ -274,84 +274,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): print(f"col corrness {col_correct} row correct {row_correct}") -@parameterize("dtype", [torch.float32]) # , torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) -def exam_dist_adafactor_fwd_bwd(dtype: torch.dtype, tp_zero_size: tuple[int, int]): - tp_size, zero_size = tp_zero_size - use_zero = True if zero_size > 1 else False - local_rank = dist.get_rank() - - clear_layout_converter() - proc_mesh = ProcessGroupMesh(tp_size, zero_size) - tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) - - torch.set_default_dtype(dtype) - set_seed(42) - - # ============================== - # Model Init - # ============================== - base_model = MlpModel().to(local_rank) - tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) - - base_param_group = setup_param_groups(base_model) - tp_param_group = setup_param_groups(tp_model) - - # ============================== - # Optimizer Init - # ============================== - base_optim = Adafactor(base_param_group) - dist_optim = DistributedAdaFactor(tp_param_group) - - shard_to_param = set_master_param_to_shard_param(tp_param_group) - dist_optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, - ) - - # ============================== - # Correctness Verify - # ============================== - x = torch.randn(HEIGHT, 
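For readers following `exam_dist_adafactor_base`, the column- and row-parallel shards it hands to the two distributed optimizers are simply slices of one `[H, W]` weight along opposite dimensions. A toy version, with plain `torch.chunk` standing in for ColossalAI's sharding utilities and placeholder sizes and rank:

```python
import torch

H, W, tp_size, rank = 8, 8, 4, 1          # toy values; the test uses HEIGHT x WIDTH
weight = torch.randn(H, W)

col_shard = torch.chunk(weight, tp_size, dim=-1)[rank]   # [H, W/tp]  column parallel
row_shard = torch.chunk(weight, tp_size, dim=0)[rank]    # [H/tp, W]  row parallel

# ZeRO-style optimizers see a flattened view of the local shard, which is why the test
# wraps each shard in a flattened nn.Parameter before constructing DistributedAdaFactor.
flat_col_shard = col_shard.flatten()
```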
WIDTH, device=local_rank) - - out = base_model(x) - out_tp = tp_model(x) - - if zero_size > 1: - dist_optim.backward(out_tp.sum()) - base_optim.backward(out.sum()) - else: - out_tp.sum().backward() - out.sum().backward() - - base_optim.step() - dist_optim.step() - - base_optim.zero_grad() - dist_optim.zero_grad() - - for p, tp_p in zip(base_param_group, tp_param_group): - param_is_distributed = is_distributed_tensor(tp_p) - if param_is_distributed: - shard_spec = get_sharding_spec(tp_p) - if len(shard_spec.sharding_sequence) >= 2: - # Col Parallel - if shard_spec.sharding_sequence[0] == "R": - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - # ROW Parallel - if shard_spec.sharding_sequence[-1] == "R": - tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather - else: - # TP bias - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - - else: - # No TP bias - pass - correctness = correctness_verify(p.data, tp_p.data, dtype) - @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (2, 2), (4, 1), (1, 4) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -462,120 +384,6 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): Randomizer.reset_index() torch.cuda.empty_cache() -@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)]) # (2, 2), (4, 1),(1, 4), (2, 4), (4, 2) -def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): - tp_size, zero_size = tp_zero_size - local_rank = dist.get_rank() - use_zero = True if zero_size > 1 else False - - clear_layout_converter() - - proc_mesh = ProcessGroupMesh(tp_size, zero_size) - tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) - - torch.set_default_dtype(dtype) - set_seed(42) - - # ============================== - # Model Init - # ============================== - base_model = MlpModel().to(local_rank) - tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) - - base_param_group = setup_param_groups(base_model) - tp_param_group = setup_param_groups(tp_model) - - # ============================== - # Optimizer Init - # ============================== - base_optim = Adafactor(base_param_group) - dist_optim = DistributedAdaFactor(tp_param_group) - - # Setup distributed optimizer - if zero_size > 1: - base_optim = LowLevelZeroOptimizer( - base_optim, - overlap_communication=True, - initial_scale=128, - partition_grad=True, - dp_process_group=dp_group, - verbose=True, - ) - - dist_optim = LowLevelZeroOptimizer( - dist_optim, - overlap_communication=True, - initial_scale=128, - partition_grad=True, - dp_process_group=dp_group, - verbose=True, - ) - shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened - dist_optim.optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, - ) - else: - shard_to_param = set_master_param_to_shard_param(tp_param_group) - dist_optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, - ) - - # ============================== - # Booster Init - # 
============================== - plugin = TorchDDPPlugin() - booster = Booster(plugin=plugin) - criterion = lambda x: x.mean() - - tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) - - # ============================== - # Correctness Verify - # ============================== - x = torch.randn(HEIGHT, WIDTH, device=local_rank) - - out = base_model(x) - out_tp = tp_model(x) - - if zero_size > 1: - dist_optim.backward(out_tp.sum()) - base_optim.backward(out.sum()) - else: - out_tp.sum().backward() - out.sum().backward() - - base_optim.step() - dist_optim.step() - - base_optim.zero_grad() - dist_optim.zero_grad() - for p, tp_p in zip(base_param_group, tp_param_group): - param_is_distributed = is_distributed_tensor(tp_p) - if param_is_distributed: - shard_spec = get_sharding_spec(tp_p) - if len(shard_spec.sharding_sequence) >= 2: - # Col Parallel - if shard_spec.sharding_sequence[0] == "R": - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - # ROW Parallel - if shard_spec.sharding_sequence[-1] == "R": - tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather - else: - # TP bias - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - - else: - # No TP bias - pass - correctness = correctness_verify(p.data, tp_p.data, dtype) - @parameterize( "test_config", [ @@ -597,24 +405,6 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int "zero_stage": 2, "precision": "bf16", }, - { - "tp_size": 1, - "num_microbatches": 4, - "zero_stage": 2, - "precision": "fp16", - }, - { - "tp_size": 2, - "num_microbatches": 4, - "zero_stage": 2, - "precision": "fp16", - }, - { - "tp_size": 4, - "num_microbatches": 4, - "zero_stage": 2, - "precision": "fp16", - }, { "tp_size": 2, "num_microbatches": 4, @@ -689,8 +479,7 @@ def exam_bert_test(test_config): def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # exam_dist_adafactor_base() - # exam_dist_adafactor_fwd_bwd() + exam_dist_adafactor_base() exam_dist_adafactor_zero() exam_bert_test() From e783599672f9898c86de97d3e1330f4f27b43f49 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 18:01:17 +0800 Subject: [PATCH 15/35] [feature] Extract distributed rms function; --- .../nn/optimizer/distributed_adafactor.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index fb4f09131227..b67580fd0f28 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -125,8 +125,20 @@ def _get_options(param_group, param_shape): return factored, use_first_moment @staticmethod - def _rms(tensor): - return tensor.norm(2) / (tensor.numel() ** 0.5) + def _rms(tensor, param_is_dtensor, tp_size, dp_size, tp_group, dp_group): + tensor_sum = tensor.pow(2).sum() + num_of_element = tensor.numel() + # reduce sum on tp group if exist + if tp_size > 1 and param_is_dtensor: + dist.all_reduce(tensor_sum, group=tp_group) + num_of_element = num_of_element * tp_size + # reduce sum on dp group if exist + if dp_size > 1 and param_is_dtensor: + dist.all_reduce(tensor_sum, group=dp_group) + num_of_element = num_of_element * dp_size + # div num of element + rms = (tensor_sum / num_of_element).sqrt() + return rms @staticmethod def _approx_sq_grad(exp_avg_sq_row, 
exp_avg_sq_col): @@ -217,12 +229,12 @@ def step(self, closure=None): state["RMS"] = 0 else: if use_first_moment: - state["exp_avg"] = state["exp_avg"].to(grad) + state["exp_avg"] = state["exp_avg"] if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + state["exp_avg_sq_row"] = state["exp_avg_sq_row"] + state["exp_avg_sq_col"] = state["exp_avg_sq_col"] else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + state["exp_avg_sq"] = state["exp_avg_sq"] state["step"] += 1 lr = self._get_lr(group, state) @@ -306,20 +318,8 @@ def step(self, closure=None): exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) - # (Line No.8) RMS - # perform a sum on each device - update_sum = update.pow(2).sum() - num_of_element = update.numel() - # reduce sum on tp group if exist - if self.tensor_parallel_size > 1 and param_is_dtensor: - dist.all_reduce(update_sum, group=self.tensor_parallel_group) - num_of_element = num_of_element * self.tensor_parallel_size - # reduce sum on dp group if exist - if self.data_parallel_size > 1 and param_is_dtensor: - dist.all_reduce(update_sum, group=self.data_parallel_group) - num_of_element = num_of_element * self.data_parallel_size - # div num of element - rms = (update_sum / num_of_element).sqrt() + # # (Line No.8) RMS + rms = self._rms(update, param_is_dtensor, self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group) update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) update.mul_(lr) From 9d33a34bc6745e92a07bdf3e4f436606959493a8 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 11 Apr 2024 18:57:34 +0800 Subject: [PATCH 16/35] [feature] add booster + lowlevelzeroPlugin in test; --- tests/test_optimizer/test_dist_adafactor.py | 128 +++++++++++++++++++- 1 file changed, 126 insertions(+), 2 deletions(-) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 17da80521cbd..a239e0e16555 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -9,7 +9,8 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin +from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin, LowLevelZeroPlugin + from colossalai.cluster import ProcessGroupMesh from colossalai.device.device_mesh import DeviceMesh from colossalai.nn.optimizer.adafactor import Adafactor @@ -272,7 +273,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): col_correct = correctness_verify(weight.data, weight_col_gather.data, dtype) row_correct = correctness_verify(weight.data, weight_row_gather.data, dtype) - print(f"col corrness {col_correct} row correct {row_correct}") + print(f"Base Test Pass") @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (2, 2), (4, 1), (1, 4) @@ -383,6 +384,127 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() + print(f"Zero Test Pass") + +@parameterize("dtype", [torch.float16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(2, 2), (1, 4)]) # (2, 2), (4, 1), (1, 4) +def 
exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + clear_layout_converter() + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Booster Init + # ============================== + plugin = LowLevelZeroPlugin() + booster = Booster(plugin=plugin) + criterion = lambda x: x.mean() + + tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + + else: + # No TP bias + pass + correctness = correctness_verify(p.data, tp_p.data, dtype) + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + print(f"Booster Test Pass") @parameterize( "test_config", @@ -475,12 +597,14 @@ def exam_bert_test(test_config): 
clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() + print(f"Bert Model Zoo Test Pass") def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_dist_adafactor_base() exam_dist_adafactor_zero() + exam_dist_adafactor_booster() exam_bert_test() @pytest.mark.dist From 419c1c0adee3867badcd1f2e52e8b831cfbf9c11 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 12 Apr 2024 11:04:27 +0800 Subject: [PATCH 17/35] [feature] add Start_with_booster_API case in md; add Supporting Information in md; --- .../en/features/distributed_adafactor.md | 100 ++++++-------- tests/test_optimizer/test_dist_adafactor.py | 129 +----------------- 2 files changed, 45 insertions(+), 184 deletions(-) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index 3ba609283def..725067bbcd28 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -14,7 +14,7 @@ Distributed Adafactor is an optimiser that supports hybrid optimisation, includi {{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} ## Hands-On Practice -We now demonstrate how to use Distributed Adafactor. +We now demonstrate how to start Distributed Adafactor with booster API. ### step 1. Import libraries ```python @@ -86,7 +86,15 @@ dist_optim.setup_distributed( ) ``` -### step 5. Perform a forward and backward propagation for model and step the gradient +### step 5.Init Booster + +```python +plugin = LowLevelZeroPlugin() +booster = Booster(plugin=plugin) +criterion = lambda x: x.mean() +tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) +``` +### step 6.Perform a forward and backward propagation for model and step the gradient ```python # Random initialise dataset @@ -104,60 +112,36 @@ dist_optim.step() dist_optim.zero_grad() ``` -## Run with booster -We highly recommend users to use booster, a simple, easy to use, and efficient interface. The Code Below is the Distributed Adafactor launched with booster. 
- -```python -# ============================== -# Model Init -# ============================== -tp_model = TPModel() - -# ============================== -# Optimizer Init -# ============================== -dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()]) - -# ============================== -# Booster Init -# ============================== -plugin = TorchDDPPlugin() -booster = Booster(plugin=plugin) -criterion = lambda x: x.mean() -tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) -``` - -## Performance - -| Parallel strategy | iter | Float Percision | Device Nums | weight shape | Avg runtime(ms) | Avg Speed Up Rate | Best Speed Up Rate | -| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: | -| AdaFactor | 50 | float32 | 2 | [4096 , 4096] | 0.58 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.41 | 1.39 | 56.91 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096 , 4096] | 0.61 | 0.96 | 18.69 | -| AdaFactor | 50 | float16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.54 | 1.33 | 26.03 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096 , 4096] | 0.67 | 1.08 | 20.55 | -| AdaFactor | 50 | bfloat16 | 2 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.55 | 1.31 | 26.11 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096 , 4096] | 0.67 | 1.07 | 21.86 | -| AdaFactor | 50 | float32 | 4 | [4096 , 4096] | 0.57 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.38 | 1.48 | 53.99 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096 , 4096] | 0.60 | 0.95 | 16.53 | -| AdaFactor | 50 | float16 | 4 | [4096 , 4096] | 0.70 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.50 | 1.44 | 21.98 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096 , 4096] | 0.64 | 1.12 | 15.35 | -| AdaFactor | 50 | bfloat16 | 4 | [4096 , 4096] | 0.72 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.56 | 1.29 | 25.63 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096 , 4096] | 0.71 | 1.09 | 21.52 | -| AdaFactor | 50 | float32 | 8 | [4096 , 4096] | 0.56 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.38 | 1.50 | 54.51 | -| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096 , 4096] | 0.91 | 0.67 | 15.68 | -| AdaFactor | 50 | float16 | 8 | [4096 , 4096] | 0.74 | - | - | -| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096 , 4096] | 0.84 | 0.87 | 9.21 | -| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096 , 4096] | 1.03 | 0.75 | 16.12 | -| AdaFactor | 50 | bfloat16 | 8 | [4096 , 4096] | 0.71 | - | - | -| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.54 | 1.31 | 27.28 | -| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096 , 4096] | 0.73 | 1.03 | 25.01 | - - +## Supporting Information +Model/Feature Compatibility Matrix: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model/FeatureTransformers
Bert
Transformers Bert
For Pretraining
Transformers Bert
Lm Head Model
Transformers Bert
For Masked Lm
Transformers Bert
For Sequence Classification
Transformers Bert
For Token Classification
Transformers Bert
For Next Sentence
Transformers Bert
For Multiple-choice Question
Transformers Bert
For Question Answering
Distributedt
Adafactor
✔️✔️✔️✔️✔️✔️✔️✔️✔️
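Pulling the steps of the updated guide together, a condensed, illustrative sketch of the intended wiring is shown below. The model, sizes and plugin arguments are placeholders, and the tensor-parallel sharding plus the `setup_distributed` call from step 4 are omitted for brevity, so treat this as the shape of the API rather than a runnable recipe for a sharded model:

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor

colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024).cuda()                  # placeholder for a TP-sharded model
dist_optim = DistributedAdaFactor(model.parameters())

plugin = LowLevelZeroPlugin()
booster = Booster(plugin=plugin)
criterion = lambda x: x.mean()
model, dist_optim, criterion, _, _ = booster.boost(model, dist_optim, criterion)

x = torch.randn(16, 1024).cuda()
loss = criterion(model(x))
booster.backward(loss, dist_optim)
dist_optim.step()
dist_optim.zero_grad()
```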
diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index a239e0e16555..b5f3260f94a5 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -385,127 +385,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): Randomizer.reset_index() torch.cuda.empty_cache() print(f"Zero Test Pass") - -@parameterize("dtype", [torch.float16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(2, 2), (1, 4)]) # (2, 2), (4, 1), (1, 4) -def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): - tp_size, zero_size = tp_zero_size - use_zero = True if zero_size > 1 else False - local_rank = dist.get_rank() - - clear_layout_converter() - - proc_mesh = ProcessGroupMesh(tp_size, zero_size) - tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) - - torch.set_default_dtype(dtype) - set_seed(42) - - # ============================== - # Model Init - # ============================== - base_model = MlpModel().to(local_rank) - tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) - - base_param_group = setup_param_groups(base_model) - tp_param_group = setup_param_groups(tp_model) - tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) - - # ============================== - # Optimizer Init - # ============================== - base_optim = Adafactor(base_param_group) - dist_optim = DistributedAdaFactor(tp_param_group) - - # Setup distributed optimizer - if zero_size > 1: - base_optim = LowLevelZeroOptimizer( - base_optim, - overlap_communication=True, - initial_scale=128, - partition_grad=True, - dp_process_group=dp_group, - verbose=True, - ) - - dist_optim = LowLevelZeroOptimizer( - dist_optim, - overlap_communication=True, - initial_scale=128, - partition_grad=True, - dp_process_group=dp_group, - verbose=True, - ) - shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened - dist_optim.optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, - ) - else: - shard_to_param = set_master_param_to_shard_param(tp_param_group) - dist_optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, - ) - - # ============================== - # Booster Init - # ============================== - plugin = LowLevelZeroPlugin() - booster = Booster(plugin=plugin) - criterion = lambda x: x.mean() - - tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) - - # ============================== - # Correctness Verify - # ============================== - x = torch.randn(HEIGHT, WIDTH, device=local_rank) - - out = base_model(x) - out_tp = tp_model(x) - - if zero_size > 1: - dist_optim.backward(out_tp.sum()) - base_optim.backward(out.sum()) - else: - out_tp.sum().backward() - out.sum().backward() - - base_optim.step() - dist_optim.step() - - base_optim.zero_grad() - dist_optim.zero_grad() - - for p, tp_p in zip(base_param_group, tp_param_group): - param_is_distributed = is_distributed_tensor(tp_p) - if param_is_distributed: - shard_spec = get_sharding_spec(tp_p) - if len(shard_spec.sharding_sequence) >= 2: - # Col Parallel - if shard_spec.sharding_sequence[0] == "R": - 
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - # ROW Parallel - if shard_spec.sharding_sequence[-1] == "R": - tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather - else: - # TP bias - tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - - else: - # No TP bias - pass - correctness = correctness_verify(p.data, tp_p.data, dtype) - clear_layout_converter() - Randomizer.reset_index() - torch.cuda.empty_cache() - print(f"Booster Test Pass") - + @parameterize( "test_config", [ @@ -560,6 +440,7 @@ def exam_bert_test(test_config): ] for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + clear_layout_converter() print(f"model name {name} {name in model_list}") if name in model_list: print(f"{name} check start") @@ -570,8 +451,7 @@ def exam_bert_test(test_config): org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster ) - - + stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group @@ -593,8 +473,6 @@ def exam_bert_test(test_config): check_optim_states(org_optimizer, sharded_optimizer.optim) print(f"{name} check pass") - - clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() print(f"Bert Model Zoo Test Pass") @@ -604,7 +482,6 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_dist_adafactor_base() exam_dist_adafactor_zero() - exam_dist_adafactor_booster() exam_bert_test() @pytest.mark.dist From c84fb52cce6f3af28bcd1b36c17276deb6d3aaa2 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 12 Apr 2024 11:06:55 +0800 Subject: [PATCH 18/35] [fix] Also remove state movement in base adafactor; --- colossalai/nn/optimizer/adafactor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index 3b7d998c9f28..57d677ef0059 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -159,12 +159,12 @@ def step(self, closure=None): state["RMS"] = 0 else: if use_first_moment: - state["exp_avg"] = state["exp_avg"].to(grad) + state["exp_avg"] = state["exp_avg"] if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + state["exp_avg_sq_row"] = state["exp_avg_sq_row"] + state["exp_avg_sq_col"] = state["exp_avg_sq_col"] else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + state["exp_avg_sq"] = state["exp_avg_sq"] state["step"] += 1 # state["RMS"] = self._rms(p_data_fp32) From 2eb069d3e8fc8e01f5229177937b8eab3e36bc1c Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 12 Apr 2024 14:43:35 +0800 Subject: [PATCH 19/35] [feature] extract factor function; --- .../nn/optimizer/distributed_adafactor.py | 216 ++++++++++-------- .../en/features/distributed_adafactor.md | 2 +- tests/test_optimizer/test_dist_adafactor.py | 3 - 3 files changed, 122 insertions(+), 99 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index b67580fd0f28..3b71a6ba872e 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -153,7 +153,85 @@ def 
_approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() return torch.mul(r_factor, c_factor) - + + def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group) + exp_avg_sq_row.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + return update + + def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): + if grad_shape[0] % self.data_parallel_size != 0: + # gather update[flatten] along dp group then reshape to [H/tp, W] + update = _gather( + input_=update, dim=-1, process_group=self.data_parallel_group + ) + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H/tp, W] + grad = _gather( + input_=grad, dim=-1, process_group=self.data_parallel_group + ) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group) + else: + update = update_reshape + else: + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + # gather row + exp_avg_sq_row_gather = _gather( + input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group + ) + sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) + update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + return update + + def _base_factor(self, update, grad, state, grad_shape, beta2t): + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), 
alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + return update + @torch.no_grad() def step(self, closure=None): """ @@ -201,29 +279,33 @@ def step(self, closure=None): if use_first_moment: # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p) - if factored and param_is_dtensor: - if shard_spec.sharding_sequence[0] == "R": # Col Parallel - state["exp_avg_sq_row"] = torch.zeros( - grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp] - state["exp_avg_sq_col"] = torch.zeros( - grad_shape[1], device=p.device, dtype=p.dtype - ) # [W/TP] - - if shard_spec.sharding_sequence[-1] == "R": # Row Parallel - # Row indivisible shape situation - if grad_shape[0] % self.data_parallel_size != 0: - state["exp_avg_sq_row"] = torch.zeros( - grad_shape[0], device=p.device, dtype=p.dtype - ) # [H/dp/Tp] - else: + if factored: + if param_is_dtensor: + if shard_spec.sharding_sequence[0] == "R": # Col Parallel state["exp_avg_sq_row"] = torch.zeros( - grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp/Tp] - - state["exp_avg_sq_col"] = torch.zeros( - grad_shape[1], device=p.device, dtype=p.dtype - ) # [W] + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp] + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W/TP] + + if shard_spec.sharding_sequence[-1] == "R": # Row Parallel + # Row indivisible shape situation + if grad_shape[0] % self.data_parallel_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H/dp/Tp] + else: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp/Tp] + + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] + else: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype) else: state["exp_avg_sq"] = torch.zeros_like(p) state["RMS"] = 0 @@ -241,78 +323,22 @@ def step(self, closure=None): beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) update = (grad**2) + group["eps"][0] - if factored and param_is_dtensor: - # ============================== - # First Dim is R, Last Dim is S{} means split dim -1 ---> - # Coloum Parallel ---> sq_row need Do (col) Reduce - # ============================== - if shard_spec.sharding_sequence[0] == "R": - update_reshape = update.view(-1, grad_shape[1]) - grad_reshape = grad.view(-1, grad_shape[1]) - exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] - exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] - exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) - dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group) - exp_avg_sq_row.div_(self.tensor_parallel_size) - update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update_reshape.mul_(grad_reshape) - if self.use_zero: - update = update_reshape.view(-1) - else: - update = update_reshape - # ============================== - # Last Dim is R, First Dim is S{} means split dim 0 ---> - # Row Parallel ---> sq_col need Do (row) Reduce - # ============================== - elif 
shard_spec.sharding_sequence[-1] == "R": - # Row Residual situation - if grad_shape[0] % self.data_parallel_size != 0: - # gather update[flatten] along dp group then reshape to [H/tp, W] - update = _gather( - input_=update, dim=-1, process_group=self.data_parallel_group - ) - # view update to origin[tp] shape - update_reshape = update.view(-1, grad_shape[1]) - # gather grad[flatten] along dp group then reshape to [H/tp, W] - grad = _gather( - input_=grad, dim=-1, process_group=self.data_parallel_group - ) - grad_reshape = grad.view(-1, grad_shape[1]) - exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] - exp_avg_sq_col = state["exp_avg_sq_col"] # [W] - exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) - # reduce col - dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) - exp_avg_sq_col.div_(self.tensor_parallel_size) - update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update_reshape.mul_(grad_reshape) - if self.use_zero: - update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group) - else: - update = update_reshape - else: - update_reshape = update.view(-1, grad_shape[1]) - grad_reshape = grad.view(-1, grad_shape[1]) - exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] - exp_avg_sq_col = state["exp_avg_sq_col"] # [W] - exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) - # reduce col - dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) - exp_avg_sq_col.div_(self.tensor_parallel_size) - # gather row - exp_avg_sq_row_gather = _gather( - input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group - ) - sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) - update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) - update_reshape.mul_(grad_reshape) - if self.use_zero: - update = update_reshape.view(-1) - else: - update = update_reshape + if factored: + if param_is_dtensor: + # ============================== + # First Dim is R, Last Dim is S{} means split dim -1 ---> + # Coloum Parallel ---> sq_row need Do (col) Reduce + # ============================== + if shard_spec.sharding_sequence[0] == "R": + update = self._col_parallel_factor(update, grad, state, grad_shape, beta2t) + # ============================== + # Last Dim is R, First Dim is S{} means split dim 0 ---> + # Row Parallel ---> sq_col need Do (row) Reduce + # ============================== + elif shard_spec.sharding_sequence[-1] == "R": + update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t) + else: + update = self._base_factor(update, grad, state, grad_shape, beta2t) else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index 725067bbcd28..4c25aa1fbdb9 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -128,7 +128,7 @@ Model/Feature Compatibility Matrix: Transformers Bert
For Question Answering - Distributedt
Adafactor + Distributed
Adafactor ✔️ ✔️ ✔️ diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index b5f3260f94a5..c07944d6f213 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -415,7 +415,6 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): }, { "tp_size": 4, - "pp_size": 1, "num_microbatches": 4, "zero_stage": 0, "precision": "bf16", @@ -441,9 +440,7 @@ def exam_bert_test(test_config): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): clear_layout_converter() - print(f"model name {name} {name in model_list}") if name in model_list: - print(f"{name} check start") org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor ) From 6303291c4a4e367ee86313a64c727515bc022d60 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 12 Apr 2024 15:19:57 +0800 Subject: [PATCH 20/35] [feature] add LowLevelZeroPlugin test; --- tests/test_optimizer/test_dist_adafactor.py | 129 +++++++++++++++++++- 1 file changed, 125 insertions(+), 4 deletions(-) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index c07944d6f213..25be432705f8 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -10,7 +10,7 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin, LowLevelZeroPlugin - +from colossalai.logging import disable_existing_loggers from colossalai.cluster import ProcessGroupMesh from colossalai.device.device_mesh import DeviceMesh from colossalai.nn.optimizer.adafactor import Adafactor @@ -386,6 +386,126 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): torch.cuda.empty_cache() print(f"Zero Test Pass") +@parameterize("dtype", [torch.float16]) +@parameterize("tp_zero_size", [(2, 2), (1, 4)]) +def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): + tp_size, zero_size = tp_zero_size + use_zero = True if zero_size > 1 else False + local_rank = dist.get_rank() + + clear_layout_converter() + + proc_mesh = ProcessGroupMesh(tp_size, zero_size) + tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) + + torch.set_default_dtype(dtype) + set_seed(42) + + # ============================== + # Model Init + # ============================== + base_model = MlpModel().to(local_rank) + tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + + base_param_group = setup_param_groups(base_model) + tp_param_group = setup_param_groups(tp_model) + tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model) + + # ============================== + # Optimizer Init + # ============================== + base_optim = Adafactor(base_param_group) + dist_optim = DistributedAdaFactor(tp_param_group) + + # Setup distributed optimizer + if zero_size > 1: + base_optim = LowLevelZeroOptimizer( + base_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + dp_process_group=dp_group, + verbose=True, + ) + + dist_optim = LowLevelZeroOptimizer( + dist_optim, + overlap_communication=True, + initial_scale=128, + partition_grad=True, + 
dp_process_group=dp_group, + verbose=True, + ) + shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + dist_optim.optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + else: + shard_to_param = set_master_param_to_shard_param(tp_param_group) + dist_optim.setup_distributed( + tensor_parallel_group=tp_group, + data_parallel_group=dp_group, + shard_to_param=shard_to_param, + use_zero=use_zero, + ) + + # ============================== + # Booster Init + # ============================== + plugin = LowLevelZeroPlugin() + booster = Booster(plugin=plugin) + criterion = lambda x: x.mean() + + tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) + + # ============================== + # Correctness Verify + # ============================== + x = torch.randn(HEIGHT, WIDTH, device=local_rank) + + out = base_model(x) + out_tp = tp_model(x) + + if zero_size > 1: + dist_optim.backward(out_tp.sum()) + base_optim.backward(out.sum()) + else: + out_tp.sum().backward() + out.sum().backward() + + base_optim.step() + dist_optim.step() + + base_optim.zero_grad() + dist_optim.zero_grad() + + for p, tp_p in zip(base_param_group, tp_param_group): + param_is_distributed = is_distributed_tensor(tp_p) + if param_is_distributed: + shard_spec = get_sharding_spec(tp_p) + if len(shard_spec.sharding_sequence) >= 2: + # Col Parallel + if shard_spec.sharding_sequence[0] == "R": + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + # ROW Parallel + if shard_spec.sharding_sequence[-1] == "R": + tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather + else: + # TP bias + tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather + + else: + # No TP bias + pass + correctness = correctness_verify(p.data, tp_p.data, dtype) + Randomizer.reset_index() + torch.cuda.empty_cache() + print(f"Booster Test Pass") + + @parameterize( "test_config", [ @@ -425,7 +545,7 @@ def exam_bert_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") test_config["use_lazy_init"] = False test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel - test_config["initial_scale"] = 2**10 # avoid overflow + test_config["initial_scale"] = 2**16 # avoid overflow model_list = [ "transformers_bert", "transformers_bert_for_pretraining", @@ -437,9 +557,8 @@ def exam_bert_test(test_config): "transformers_bert_for_mcq", "transformers_bert_for_question_answering", ] - + clear_layout_converter() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - clear_layout_converter() if name in model_list: org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor @@ -475,11 +594,13 @@ def exam_bert_test(test_config): print(f"Bert Model Zoo Test Pass") def run_dist(rank, world_size, port): + disable_existing_loggers() config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_dist_adafactor_base() exam_dist_adafactor_zero() exam_bert_test() + exam_dist_adafactor_booster() @pytest.mark.dist @rerun_if_address_is_in_use() From 60489ab39e55867b6d96ae8e548e8a219ebf3c70 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 12 Apr 2024 20:15:46 +0800 Subject: [PATCH 21/35] [fix] add tp=False and 
zero=True in logic; --- .../nn/optimizer/distributed_adafactor.py | 46 ++++++++--- .../en/features/distributed_adafactor.md | 77 ++++++------------- tests/test_optimizer/test_dist_adafactor.py | 40 +++++++--- 3 files changed, 87 insertions(+), 76 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 3b71a6ba872e..2669948c08cd 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -93,7 +93,9 @@ def setup_distributed( if self.param_is_dtensor_dict[id(p)]: self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape else: - self.grad_shape_dict[id(p)] = p.shape + # no tp; could be zero or not zero + # self.grad_shape_dict[id(p)] = p.shape + self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)]) if self.param_is_dtensor_dict[id(p)]: self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) @@ -120,6 +122,7 @@ def _get_options(param_group, param_shape): param_shape : Original Shape of param """ + print(f"param_shape {param_shape}") factored = len(param_shape) >= 2 use_first_moment = param_group["beta1"] is not None return factored, use_first_moment @@ -221,17 +224,29 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): return update def _base_factor(self, update, grad, state, grad_shape, beta2t): - exp_avg_sq_row = state["exp_avg_sq_row"] - exp_avg_sq_col = state["exp_avg_sq_col"] - # Exponential average of row indexes - exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) - # Exponential average of columns indexes - exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) - # Approximation of exponential moving average of square of gradient - update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update.mul_(grad) + if self.use_zero: + # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) + + # row mean no change + + # col mean need reduce and div + + print(f"origin shape {grad_shape}, true shape {update.shape}") + pass + else: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) return update + + @torch.no_grad() def step(self, closure=None): """ @@ -273,6 +288,9 @@ def step(self, closure=None): if param_is_dtensor: grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) factored, use_first_moment = self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] + + # print(f"factored {factored} param_is_dtensor {param_is_dtensor} shape {grad_shape}") + shard_spec = self.shard_spec_dict[id(p)] if len(state) == 0: state["step"] = 0 @@ -304,7 +322,13 @@ def step(self, closure=None): grad_shape[1], device=p.device, dtype=p.dtype ) # [W] else: - state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + if self.use_zero: + # [H // dp] + state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size, 
device=grad.device, dtype=p.dtype) + else: + # [H] + state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + # Alaways [W] state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype) else: state["exp_avg_sq"] = torch.zeros_like(p) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index 4c25aa1fbdb9..5a8d8ebade85 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -21,11 +21,13 @@ We now demonstrate how to start Distributed Adafactor with booster API. import torch from torch import nn import torch.distributed as dist +from transformers import LlamaModel, LlamaConfig from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor - +from colossal_llama2.dataset.loader import load_tokenized_dataset +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin ``` ### step 2. Initialize Distributed Environment and Parallism Group @@ -44,72 +46,37 @@ proc_mesh = ProcessGroupMesh(tp_size, zero_size) tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) ``` -### step 3. Initialize Module +### step 3. Initialize Module and Optimizer Build our model. We created an MLP using two Linear Layer. ```python -# Init a Tensor Paralleled Module -class TPModel(nn.Module): - def __init__(self, linear1, linear2, tp_group=None): - super().__init__() - self.linear1 = Linear1D_Col.from_native_module( - linear1, process_group=tp_group, gather_output=False, overlap=True - ) - self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x -HEIGHT = 4096 -WIDTH = 4096 -tp_model = TPModel(copy.deepcopy(nn.Linear(HEIGHT, WIDTH)), copy.deepcopy(nn.Linear(HEIGHT, WIDTH)), tp_group).to(local_rank) - -# Get Module parameter -tp_param_group = [p for n, p in tp_model.named_parameters()] -``` - -### step 4. Initialize Optimizer -Then, We initialise the optimiser using the model parameter. Then, we set the distributed information for optimiser. 
+# Init Llama from huggingface +configuration = LlamaConfig() +model = LlamaModel(configuration) +dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") +dataloader = plugin.prepare_dataloader(dataset, batch_size=8) +criterion = lambda x: x.mean() +dist_optim = DistributedAdaFactor(model.parameters()) -```python -# Init a Optimizer -dist_optim = DistributedAdaFactor(tp_param_group) -shard_to_param = {id(p):p for p in tp_param_group} - -# Setup distributed information for Optimizer -dist_optim.setup_distributed( - tensor_parallel_group=tp_group, - data_parallel_group=dp_group, - shard_to_param=shard_to_param, - use_zero=use_zero, -) ``` -### step 5.Init Booster +### step 4.Init Booster ```python plugin = LowLevelZeroPlugin() booster = Booster(plugin=plugin) -criterion = lambda x: x.mean() -tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) +model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) ``` -### step 6.Perform a forward and backward propagation for model and step the gradient - +### step 5.Train Your Model ```python -# Random initialise dataset -x = torch.randn(HEIGHT, WIDTH, device=local_rank) - -# Fwd and Bwd -out_tp = tp_model(x) -if zero_size > 1: - dist_optim.backward(out_tp.sum()) -else: - out_tp.sum().backward() - -# perform step for param and grad -dist_optim.step() -dist_optim.zero_grad() +for epoch in range(max_epochs): + for input_ids, attention_mask in dataloader: + outputs = model(input_ids.cuda(), attention_mask.cuda()) + loss = criterion(outputs.logits, input_ids) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() ``` ## Supporting Information diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 25be432705f8..9f3a46ab8add 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -275,8 +275,8 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): print(f"Base Test Pass") -@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (2, 2), (4, 1), (1, 4) +@parameterize("dtype", [torch.float16]) # torch.float32, torch.float16, torch.bfloat16 +@parameterize("tp_zero_size", [(1, 4)]) # (2, 2), (4, 1), (1, 4) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False @@ -387,7 +387,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): print(f"Zero Test Pass") @parameterize("dtype", [torch.float16]) -@parameterize("tp_zero_size", [(2, 2), (1, 4)]) +@parameterize("tp_zero_size", [(1, 4)]) def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False @@ -405,7 +405,8 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int # Model Init # ============================== base_model = MlpModel().to(local_rank) - tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + # tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank) + tp_model = copy.deepcopy(base_model).to(local_rank) base_param_group = 
setup_param_groups(base_model) tp_param_group = setup_param_groups(tp_model) @@ -437,6 +438,7 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int verbose=True, ) shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened + # print(f"shard_to_param {shard_to_param}") dist_optim.optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, @@ -496,7 +498,6 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int else: # TP bias tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - else: # No TP bias pass @@ -560,6 +561,7 @@ def exam_bert_test(test_config): clear_layout_converter() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if name in model_list: + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor ) @@ -586,8 +588,26 @@ def exam_bert_test(test_config): if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) # check optim states - check_optim_states(org_optimizer, sharded_optimizer.optim) - print(f"{name} check pass") + # print(f"{org_optimizer.param_groups} {sharded_optimizer.optim.param_groups}") + + # for group, tp_group in zip(org_optimizer.param_groups, sharded_optimizer.optim.param_groups): + # for i in range(len(group["params"])): + # p, tp = group["params"], tp_group["params"] + + # sharded_state = sharded_optimizer.optim.state[tp] + # state = org_optimizer.state[p] + + # print(f"sharded_state {sharded_state}\n state {state}") + + # for p in group["params"]: + # sharded_state = sharded_optimizer.optim.state[p] + # state = org_optimizer.state[p] + # print(sharded_state) + # for key in sharded_state: + # print(state[key], sharded_state[key]) + + # check_optim_states(org_optimizer, sharded_optimizer.optim) + # print(f"{name} check pass") Randomizer.reset_index() torch.cuda.empty_cache() @@ -597,9 +617,9 @@ def run_dist(rank, world_size, port): disable_existing_loggers() config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - exam_dist_adafactor_base() - exam_dist_adafactor_zero() - exam_bert_test() + # exam_dist_adafactor_base() + # exam_dist_adafactor_zero() + # exam_bert_test() exam_dist_adafactor_booster() @pytest.mark.dist From 02ea83e6c3ae5dffb8d3bd2f052a439074be9f3b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Sun, 14 Apr 2024 00:36:36 +0800 Subject: [PATCH 22/35] [fix] fix use zero logic; --- .../nn/optimizer/distributed_adafactor.py | 64 +++++++++++++++---- tests/test_optimizer/test_dist_adafactor.py | 6 +- 2 files changed, 54 insertions(+), 16 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 2669948c08cd..956bf07604e3 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -122,7 +122,6 @@ def _get_options(param_group, param_shape): param_shape : Original Shape of param """ - print(f"param_shape {param_shape}") factored = len(param_shape) >= 2 use_first_moment = param_group["beta1"] is not None return factored, use_first_moment @@ -225,15 +224,49 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, 
beta2t): def _base_factor(self, update, grad, state, grad_shape, beta2t): if self.use_zero: - # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) - - # row mean no change - - # col mean need reduce and div - - print(f"origin shape {grad_shape}, true shape {update.shape}") - pass + # only zero + if grad_shape[0] % self.data_parallel_size != 0: + # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) + # row mean no change + # col mean need reduce and div + # gather update[flatten] along dp group then reshape to [H, W] + update = _gather( + input_=update, dim=-1, process_group=self.data_parallel_group + ) + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H, W] + grad = _gather( + input_=grad, dim=-1, process_group=self.data_parallel_group + ) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group) + else: + # no residual row + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] + grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) + exp_avg_sq_col.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + update = update_reshape.view(-1) else: + # base factor; no tp, no dp exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] # Exponential average of row indexes @@ -323,12 +356,17 @@ def step(self, closure=None): ) # [W] else: if self.use_zero: - # [H // dp] - state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype) + # param grad [H // dp] + if grad_shape[0] % self.data_parallel_size != 0: + # save all exp_avg_sq_row [H] + state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + else: + # exp_avg_sq_row [H // dp] + state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype) else: - # [H] + # exp_avg_sq_row [H] state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) - # Alaways [W] + # exp_avg_sq_col alaways [W] state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype) else: state["exp_avg_sq"] = torch.zeros_like(p) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 9f3a46ab8add..f93508d7723d 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ 
b/tests/test_optimizer/test_dist_adafactor.py @@ -617,9 +617,9 @@ def run_dist(rank, world_size, port): disable_existing_loggers() config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # exam_dist_adafactor_base() - # exam_dist_adafactor_zero() - # exam_bert_test() + exam_bert_test() + exam_dist_adafactor_base() + exam_dist_adafactor_zero() exam_dist_adafactor_booster() @pytest.mark.dist From fb141253b0f06fca8373d68514742447fcb932d1 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Sun, 14 Apr 2024 21:06:09 +0800 Subject: [PATCH 23/35] [feature] add row residue logic in column parallel factor; --- .../nn/optimizer/distributed_adafactor.py | 72 +++++++++++++------ 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 956bf07604e3..9331da396a65 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -157,20 +157,43 @@ def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): return torch.mul(r_factor, c_factor) def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): - update_reshape = update.view(-1, grad_shape[1]) - grad_reshape = grad.view(-1, grad_shape[1]) - exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] - exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] - exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) - dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group) - exp_avg_sq_row.div_(self.tensor_parallel_size) - update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update_reshape.mul_(grad_reshape) - if self.use_zero: - update = update_reshape.view(-1) + if grad_shape[0] % self.data_parallel_size != 0: + # gather update[flatten] along dp group then reshape to [H, W/tp] + update = _gather( + input_=update, dim=-1, process_group=self.data_parallel_group + ) + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H, W/tp] + grad = _gather( + input_=grad, dim=-1, process_group=self.data_parallel_group + ) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + else: - update = update_reshape + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + dist.all_reduce(exp_avg_sq_row, group=self.tensor_parallel_group) + exp_avg_sq_row.div_(self.tensor_parallel_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = update_reshape.view(-1) + else: 
+ update = update_reshape return update def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): @@ -333,9 +356,14 @@ def step(self, closure=None): if factored: if param_is_dtensor: if shard_spec.sharding_sequence[0] == "R": # Col Parallel - state["exp_avg_sq_row"] = torch.zeros( - grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp] + if grad_shape[0] % self.data_parallel_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H] + else: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp] state["exp_avg_sq_col"] = torch.zeros( grad_shape[1], device=p.device, dtype=p.dtype ) # [W/TP] @@ -344,16 +372,16 @@ def step(self, closure=None): # Row indivisible shape situation if grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( - grad_shape[0], device=p.device, dtype=p.dtype - ) # [H/dp/Tp] + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H/tp] else: state["exp_avg_sq_row"] = torch.zeros( - grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp/Tp] + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp/tp] state["exp_avg_sq_col"] = torch.zeros( - grad_shape[1], device=p.device, dtype=p.dtype - ) # [W] + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] else: if self.use_zero: # param grad [H // dp] From 2dc0341af57b01fb40222c1d0377e92bf762c99c Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 15 Apr 2024 15:59:47 +0800 Subject: [PATCH 24/35] [feature] add check optim state func; --- .../nn/optimizer/distributed_adafactor.py | 26 +++++---- tests/test_optimizer/_utils.py | 56 ++++++++++++++++++- tests/test_optimizer/test_dist_adafactor.py | 26 ++------- 3 files changed, 74 insertions(+), 34 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 9331da396a65..ccda1f3e65b1 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -130,15 +130,22 @@ def _get_options(param_group, param_shape): def _rms(tensor, param_is_dtensor, tp_size, dp_size, tp_group, dp_group): tensor_sum = tensor.pow(2).sum() num_of_element = tensor.numel() - # reduce sum on tp group if exist - if tp_size > 1 and param_is_dtensor: + + if param_is_dtensor: + # reduce tensor_sum from tp_group dist.all_reduce(tensor_sum, group=tp_group) num_of_element = num_of_element * tp_size - # reduce sum on dp group if exist - if dp_size > 1 and param_is_dtensor: - dist.all_reduce(tensor_sum, group=dp_group) - num_of_element = num_of_element * dp_size - # div num of element + if dp_size > 1: + dist.all_reduce(tensor_sum, group=dp_group) + num_of_element = num_of_element * dp_size + else: + pass + else: + if dp_size > 1: + dist.all_reduce(tensor_sum, group=dp_group) + num_of_element = num_of_element * dp_size + else: + pass rms = (tensor_sum / num_of_element).sqrt() return rms @@ -344,9 +351,7 @@ def step(self, closure=None): if param_is_dtensor: grad_shape = self.shard_to_param.get(id(p)).shape # tp shape (2 dim) factored, use_first_moment = self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] - - # print(f"factored {factored} param_is_dtensor {param_is_dtensor} shape {grad_shape}") - + shard_spec = self.shard_spec_dict[id(p)] if len(state) == 0: state["step"] = 0 @@ -384,7 +389,6 @@ def step(self, 
closure=None): ) # [W] else: if self.use_zero: - # param grad [H // dp] if grad_shape[0] % self.data_parallel_size != 0: # save all exp_avg_sq_row [H] state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 7865366e6e0b..49ab3ee4562c 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -12,7 +12,7 @@ run_forward_backward_with_hybrid_plugin, unwrap_model, ) - +from colossalai.shardformer.layer._operation import _gather def check_optim_states(org_optim, sharded_optim): for group in org_optim.param_groups: @@ -135,3 +135,57 @@ def _run_bert_test(rank, world_size, port, optim_class, sharded_optim_class): def check_optim_on_bert(optim_class, sharded_optim_class): spawn(_run_bert_test, 4, optim_class, sharded_optim_class) + + +def check_dist_optim_state(org_optimizer, sharded_optimizer): + for group, tp_group in zip(org_optimizer.param_groups, sharded_optimizer.param_groups): + for p, tp in zip(group["params"], tp_group["params"]): + p_state = org_optimizer.state[p] + tp_state = sharded_optimizer.state[tp] + for key in ["exp_avg_sq_col", "exp_avg_sq_row"]: + if key in tp_state.keys() and type(tp_state[key]) is torch.Tensor: + tp_is_dtensor = sharded_optimizer.param_is_dtensor_dict[id(tp)] + use_zero = sharded_optimizer.use_zero + tp_optim_state = tp_state[key] + p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape + # we start init model as first tensor parallel then zero; + # we gather model as first zero then tensor parallel + if use_zero: + # gather from dp group + if p_state_shape != tp_state_shape: + tp_optim_state = _gather( + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.data_parallel_group + ) + tp_state_shape = tp_optim_state.shape + else: + pass + + # check tp + if tp_is_dtensor: + # gather from tp group + if p_state_shape != tp_state_shape: + tp_optim_state = _gather( + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group + ) + tp_state_shape = tp_optim_state.shape + else: + pass + else: + pass + + else: + # check tp + if tp_is_dtensor: + # gather from tp group + if p_state_shape != tp_state_shape: + tp_optim_state = _gather( + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group + ) + tp_state_shape = tp_optim_state.shape + else: + pass + else: + pass + print(f"{key} \np_state {p_state[key].shape} \ntp_optim_state {tp_state[key].shape} {tp_state_shape}\n") + + assert_close(p_state[key], tp_optim_state, atol=5e-3, rtol=1.6e-2) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index f93508d7723d..ab6cde370bc3 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -32,7 +32,8 @@ from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import run_bert_test, check_optim_states +from tests.test_optimizer._utils import run_bert_test, check_dist_optim_state +from colossalai.shardformer.layer._operation import _gather from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, check_weight, @@ -587,27 +588,8 @@ def exam_bert_test(test_config): atol, rtol = 5e-4, 5e-4 if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, 
atol=atol, rtol=rtol, dim=1) - # check optim states - # print(f"{org_optimizer.param_groups} {sharded_optimizer.optim.param_groups}") - - # for group, tp_group in zip(org_optimizer.param_groups, sharded_optimizer.optim.param_groups): - # for i in range(len(group["params"])): - # p, tp = group["params"], tp_group["params"] - - # sharded_state = sharded_optimizer.optim.state[tp] - # state = org_optimizer.state[p] - - # print(f"sharded_state {sharded_state}\n state {state}") - - # for p in group["params"]: - # sharded_state = sharded_optimizer.optim.state[p] - # state = org_optimizer.state[p] - # print(sharded_state) - # for key in sharded_state: - # print(state[key], sharded_state[key]) - - # check_optim_states(org_optimizer, sharded_optimizer.optim) - # print(f"{name} check pass") + # check optim states + # check_dist_optim_state(org_optimizer, sharded_optimizer.optim) Randomizer.reset_index() torch.cuda.empty_cache() From 3168a595aad05900ed3794e5fae08a849e733a57 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 15 Apr 2024 16:44:55 +0800 Subject: [PATCH 25/35] [feature] Remove duplicate logic; --- .../nn/optimizer/distributed_adafactor.py | 21 ++++++------------- tests/test_optimizer/test_dist_adafactor.py | 4 ++-- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index ccda1f3e65b1..7bd7c8cafd48 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -90,12 +90,7 @@ def setup_distributed( for group in self.param_groups: for p in group["params"]: self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p))) - if self.param_is_dtensor_dict[id(p)]: - self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape - else: - # no tp; could be zero or not zero - # self.grad_shape_dict[id(p)] = p.shape - self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape + self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)]) if self.param_is_dtensor_dict[id(p)]: self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) @@ -181,11 +176,6 @@ def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) - if self.use_zero: - update = update_reshape.view(-1) - else: - update = update_reshape - else: update_reshape = update.view(-1, grad_shape[1]) grad_reshape = grad.view(-1, grad_shape[1]) @@ -197,10 +187,11 @@ def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): exp_avg_sq_row.div_(self.tensor_parallel_size) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) - if self.use_zero: - update = update_reshape.view(-1) - else: - update = update_reshape + + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape return update def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index ab6cde370bc3..221a236ddc10 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ 
b/tests/test_optimizer/test_dist_adafactor.py @@ -550,8 +550,8 @@ def exam_bert_test(test_config): test_config["initial_scale"] = 2**16 # avoid overflow model_list = [ "transformers_bert", - "transformers_bert_for_pretraining", - "transformers_bert_lm_head_model", + "transformers_bert_for_pretraining", + "transformers_bert_lm_head_model", "transformers_bert_for_masked_lm", "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification", From 3bca49168ed03599dfad8c1a9eedc583950ff8db Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 15 Apr 2024 17:36:24 +0800 Subject: [PATCH 26/35] [feature] update optim state check func and percision test bug; --- tests/test_optimizer/_utils.py | 97 +++++++++++++++------ tests/test_optimizer/test_dist_adafactor.py | 2 +- 2 files changed, 69 insertions(+), 30 deletions(-) diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 49ab3ee4562c..6f0bf456559e 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -145,47 +145,86 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): for key in ["exp_avg_sq_col", "exp_avg_sq_row"]: if key in tp_state.keys() and type(tp_state[key]) is torch.Tensor: tp_is_dtensor = sharded_optimizer.param_is_dtensor_dict[id(tp)] + shard_spec = sharded_optimizer.shard_spec_dict[id(tp)] use_zero = sharded_optimizer.use_zero tp_optim_state = tp_state[key] p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape - # we start init model as first tensor parallel then zero; - # we gather model as first zero then tensor parallel - if use_zero: - # gather from dp group - if p_state_shape != tp_state_shape: - tp_optim_state = _gather( - input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.data_parallel_group - ) - tp_state_shape = tp_optim_state.shape - else: - pass - - # check tp - if tp_is_dtensor: + dp_size, tp_size = sharded_optimizer.data_parallel_size, sharded_optimizer.tensor_parallel_size, + # we start init model with first tensor parallel then zero; + # So, we gather model with first zero then tensor parallel + + if tp_is_dtensor: + # col parallel + if shard_spec.sharding_sequence[0] == "R": + if use_zero: + # sq_row need gather alone dp group + if key == "exp_avg_sq_row": + tp_optim_state = _gather( + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.data_parallel_group + ) + tp_state_shape = tp_optim_state.shape + # sq_col don't need gather alone dp group + if key == "exp_avg_sq_col": + pass + else: + pass + # gather from tp group - if p_state_shape != tp_state_shape: + # sq_row don need gather alone tp group + if key == "exp_avg_sq_row": + pass + # sq_col need gather alone dp group + if key == "exp_avg_sq_col": tp_optim_state = _gather( - input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group - ) + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group + ) tp_state_shape = tp_optim_state.shape + + # row parallel + if shard_spec.sharding_sequence[-1] == "R": + if use_zero: + # sq_row need gather alone dp group + if key == "exp_avg_sq_row": + if p_state[key].shape[0] // tp_size % dp_size != 0: + pass + else: + tp_optim_state = _gather( + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.data_parallel_group + ) + tp_state_shape = tp_optim_state.shape + # sq_col don't need gather alone dp group + if key == "exp_avg_sq_col": + pass else: pass - else: - pass - - else: - # check tp - if tp_is_dtensor: # gather 
from tp group - if p_state_shape != tp_state_shape: + # sq_row need gather alone tp group + if key == "exp_avg_sq_row": tp_optim_state = _gather( - input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group - ) + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group + ) tp_state_shape = tp_optim_state.shape - else: + # sq_col don't need gather alone dp group + if key == "exp_avg_sq_col": + pass + else: + if use_zero: + # sq_row need gather alone dp group + if key == "exp_avg_sq_row": + # row residule; no gather + + if p_state[key].shape[0] % dp_size != 0: + pass + else: + tp_optim_state = _gather( + input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.data_parallel_group + ) + tp_state_shape = tp_optim_state.shape + # sq_col don't need gather alone dp group + if key == "exp_avg_sq_col": pass else: pass - print(f"{key} \np_state {p_state[key].shape} \ntp_optim_state {tp_state[key].shape} {tp_state_shape}\n") + print(f"{key} is_dtensor {tp_is_dtensor} shard_spec {shard_spec} dp_size {dp_size} tp_size {tp_size}\np_state {p_state[key].shape} \ntp_optim_state {tp_state[key].shape} {tp_state_shape}\ndp res {p_state[key].shape[0] // tp_size % dp_size}\n") - assert_close(p_state[key], tp_optim_state, atol=5e-3, rtol=1.6e-2) + # assert_close(p_state[key], tp_optim_state, atol=5e-3, rtol=1.6e-2) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 221a236ddc10..17b00f5e68e1 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -589,7 +589,7 @@ def exam_bert_test(test_config): if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) # check optim states - # check_dist_optim_state(org_optimizer, sharded_optimizer.optim) + check_dist_optim_state(org_optimizer, sharded_optimizer.optim) Randomizer.reset_index() torch.cuda.empty_cache() From 1357dd12798f83a45ea854739c7236835bbc38fc Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 16 Apr 2024 15:56:25 +0800 Subject: [PATCH 27/35] [fix] update/fix optim state; Still exist percision issue; --- .../nn/optimizer/distributed_adafactor.py | 14 ++++++-------- tests/test_optimizer/_utils.py | 18 +++++++++++------- tests/test_optimizer/test_dist_adafactor.py | 10 +++++++++- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index 7bd7c8cafd48..e1fbfb296001 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist -# from torch.optim import Optimizer from colossalai.interface.optimizer import DistributedOptim from colossalai.shardformer.layer._operation import _gather, _split @@ -58,7 +57,6 @@ def __init__( self.shard_spec_dict = {} # {id(p): ShardSpec} super().__init__(params, defaults) - def setup_distributed( self, tensor_parallel_group: dist.ProcessGroup = None, @@ -276,6 +274,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): # view update to origin[tp] shape update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] + # print(f"grad_shape {grad_shape} update shape {update.shape} grad shape {grad.shape}\n update {update}\n") exp_avg_sq_row = 
state["exp_avg_sq_row"] # [H/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -286,6 +285,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) update = update_reshape.view(-1) + # print(f"No res factor exp_avg_sq_col is_dtensor {False} shard_spec {None} use_zero {self.use_zero} dp_size {self.data_parallel_size} tp_size {self.tensor_parallel_size}\n {state['exp_avg_sq_col']}\n") else: # base factor; no tp, no dp exp_avg_sq_row = state["exp_avg_sq_row"] @@ -298,9 +298,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update.mul_(grad) return update - - - + @torch.no_grad() def step(self, closure=None): """ @@ -400,6 +398,7 @@ def step(self, closure=None): if factored: state["exp_avg_sq_row"] = state["exp_avg_sq_row"] state["exp_avg_sq_col"] = state["exp_avg_sq_col"] + else: state["exp_avg_sq"] = state["exp_avg_sq"] @@ -407,7 +406,7 @@ def step(self, closure=None): lr = self._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) update = (grad**2) + group["eps"][0] - + if factored: if param_is_dtensor: # ============================== @@ -428,7 +427,7 @@ def step(self, closure=None): exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) - + # # (Line No.8) RMS rms = self._rms(update, param_is_dtensor, self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group) update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) @@ -444,5 +443,4 @@ def step(self, closure=None): p.add_(-update) - return loss diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 6f0bf456559e..8670aa531406 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -1,4 +1,5 @@ import torch +import torch.distributed from torch.testing import assert_close import colossalai @@ -142,7 +143,8 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): for p, tp in zip(group["params"], tp_group["params"]): p_state = org_optimizer.state[p] tp_state = sharded_optimizer.state[tp] - for key in ["exp_avg_sq_col", "exp_avg_sq_row"]: + # TODO "exp_avg_sq_col", "exp_avg_sq_row", "exp_avg_sq" + for key in ["exp_avg_sq_row"]: if key in tp_state.keys() and type(tp_state[key]) is torch.Tensor: tp_is_dtensor = sharded_optimizer.param_is_dtensor_dict[id(tp)] shard_spec = sharded_optimizer.shard_spec_dict[id(tp)] @@ -168,7 +170,6 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): pass else: pass - # gather from tp group # sq_row don need gather alone tp group if key == "exp_avg_sq_row": @@ -212,7 +213,6 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): # sq_row need gather alone dp group if key == "exp_avg_sq_row": # row residule; no gather - if p_state[key].shape[0] % dp_size != 0: pass else: @@ -222,9 +222,13 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): tp_state_shape = tp_optim_state.shape # sq_col don't need gather alone dp group if key == "exp_avg_sq_col": - pass + tp_optim_state = tp_optim_state.div_(dp_size) + # need a div; + # if dp group is [] else: pass - print(f"{key} is_dtensor {tp_is_dtensor} shard_spec {shard_spec} dp_size {dp_size} tp_size {tp_size}\np_state {p_state[key].shape} 
\ntp_optim_state {tp_state[key].shape} {tp_state_shape}\ndp res {p_state[key].shape[0] // tp_size % dp_size}\n") - - # assert_close(p_state[key], tp_optim_state, atol=5e-3, rtol=1.6e-2) + res = torch.allclose(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2) + # print(f"device {torch.distributed.get_rank(sharded_optimizer.data_parallel_group)} {key} is_dtensor {tp_is_dtensor} shard_spec {shard_spec} use_zero {use_zero} dp_size {dp_size} tp_size {tp_size}\np_state {p_state[key].shape} \ntp_optim_state {tp_state[key].shape} {tp_state_shape}\ndp res {p_state[key].shape[0] // tp_size % dp_size} Close {res}\n") + # if not res: + # print(f"p_state {p_state[key]}\ntp_optim_state {tp_optim_state}\n") + assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 17b00f5e68e1..883d808711af 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -47,6 +47,7 @@ WIDTH = 4 _TP_SPEC = DimSpec([0]) + def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32): rtol = None atol = None @@ -63,6 +64,7 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol)) assert_close(tensor1, tensor2, rtol=rtol, atol=atol) + # setup param groups; (For zero test optim) def setup_param_groups_zero(model: nn.Module) -> list: no_decay = ["bias", "LayerNorm.weight"] @@ -78,11 +80,13 @@ def setup_param_groups_zero(model: nn.Module) -> list: ] return optimizer_grouped_parameters + # setup param groups; (For base optim) def setup_param_groups(model: nn.Module) -> list: optimizer_grouped_parameters = [p for n, p in model.named_parameters()] return optimizer_grouped_parameters + # setup flatten param groups, sharding spec and shape; (For dist optim) def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict: flatten_optimizer_grouped_parameters = [] @@ -136,10 +140,12 @@ def set_dist_grad( p.grad = p.data p.data = orig_p + def set_master_param_to_shard_param(master_param_list) -> dict: master_param_to_shard_param ={id(p):p for p in master_param_list} return master_param_to_shard_param + class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() @@ -164,6 +170,7 @@ def forward(self, x): x = self.linear2(x) return x + @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(4, 1)]) def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -276,6 +283,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): print(f"Base Test Pass") + @parameterize("dtype", [torch.float16]) # torch.float32, torch.float16, torch.bfloat16 @parameterize("tp_zero_size", [(1, 4)]) # (2, 2), (4, 1), (1, 4) def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): @@ -377,7 +385,6 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): else: # TP bias tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather - else: # No TP bias pass @@ -386,6 +393,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): Randomizer.reset_index() torch.cuda.empty_cache() print(f"Zero Test Pass") + @parameterize("dtype", [torch.float16]) @parameterize("tp_zero_size", [(1, 4)]) From 
1038b233632c5ad04c40262a7027ba7205a14e65 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 17 Apr 2024 11:10:08 +0800 Subject: [PATCH 28/35] [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info; --- colossalai/nn/optimizer/__init__.py | 4 +- .../nn/optimizer/distributed_adafactor.py | 15 ++--- .../en/features/distributed_adafactor.md | 57 ++++++++++++++++--- 3 files changed, 56 insertions(+), 20 deletions(-) diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 250f9a3b8545..18375947db9d 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -6,5 +6,7 @@ from .hybrid_adam import HybridAdam from .lamb import Lamb from .lars import Lars +from .adafactor import Adafactor +from .distributed_adafactor import DistributedAdaFactor -__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam", "DistributedLamb"] +__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam", "DistributedLamb", "Adafactor", "DistributedAdaFactor"] diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index e1fbfb296001..b4b7b86b98d0 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -120,7 +120,7 @@ def _get_options(param_group, param_shape): return factored, use_first_moment @staticmethod - def _rms(tensor, param_is_dtensor, tp_size, dp_size, tp_group, dp_group): + def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, dp_group): tensor_sum = tensor.pow(2).sum() num_of_element = tensor.numel() @@ -128,17 +128,13 @@ def _rms(tensor, param_is_dtensor, tp_size, dp_size, tp_group, dp_group): # reduce tensor_sum from tp_group dist.all_reduce(tensor_sum, group=tp_group) num_of_element = num_of_element * tp_size - if dp_size > 1: + if use_zero: dist.all_reduce(tensor_sum, group=dp_group) num_of_element = num_of_element * dp_size - else: - pass else: - if dp_size > 1: + if use_zero: dist.all_reduce(tensor_sum, group=dp_group) num_of_element = num_of_element * dp_size - else: - pass rms = (tensor_sum / num_of_element).sqrt() return rms @@ -274,7 +270,6 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): # view update to origin[tp] shape update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] - # print(f"grad_shape {grad_shape} update shape {update.shape} grad shape {grad.shape}\n update {update}\n") exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -285,7 +280,6 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) update = update_reshape.view(-1) - # print(f"No res factor exp_avg_sq_col is_dtensor {False} shard_spec {None} use_zero {self.use_zero} dp_size {self.data_parallel_size} tp_size {self.tensor_parallel_size}\n {state['exp_avg_sq_col']}\n") else: # base factor; no tp, no dp exp_avg_sq_row = state["exp_avg_sq_row"] @@ -398,7 +392,6 @@ def step(self, closure=None): if factored: state["exp_avg_sq_row"] = state["exp_avg_sq_row"] state["exp_avg_sq_col"] = state["exp_avg_sq_col"] - else: state["exp_avg_sq"] = state["exp_avg_sq"] @@ -429,7 +422,7 @@ def 
step(self, closure=None): update = exp_avg_sq.rsqrt().mul_(grad) # # (Line No.8) RMS - rms = self._rms(update, param_is_dtensor, self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group) + rms = self._rms(update, param_is_dtensor, self.use_zero,self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group) update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) update.mul_(lr) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index 5a8d8ebade85..fa848851f812 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -38,12 +38,6 @@ for other initialization methods. We use `ProcessGroupMesh` to create tensor par # Distributed Enviroment config = {} colossalai.launch(config=config, rank=rank, world_size=world_size,host="localhost", port=port, backend="nccl") - -# Parallism Group -tp_size, zero_size = 2, 2 -use_zero = True if zero_size > 1 else False -proc_mesh = ProcessGroupMesh(tp_size, zero_size) -tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1) ``` ### step 3. Initialize Module and Optimizer @@ -95,7 +89,7 @@ Model/Feature Compatibility Matrix: Transformers Bert
For Question Answering - Distributed
Adafactor + Hybrid Parallel
Plugin ✔️ ✔️ ✔️ @@ -106,7 +100,54 @@ Model/Feature Compatibility Matrix: ✔️ ✔️ - + + Low Level Zero
Plugin + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + + + Torch DDP
Plugin + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + + + Gemini
Plugin + ❌ + ❌ + ❌ + ❌ + ❌ + ❌ + ❌ + ❌ + ❌ + + + Moe Hybrid
Plugin + ❌ + ❌ + ❌ + ❌ + ❌ + ❌ + ❌ + ❌ + ❌ + From 87746ecf5f2cb781f9dcefda4206084a695e9642 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 17 Apr 2024 11:33:21 +0800 Subject: [PATCH 29/35] [feature] removed print & comments in utils; --- tests/test_optimizer/_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 922900011a45..5475f9b93cce 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -224,11 +224,7 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): if key == "exp_avg_sq_col": tp_optim_state = tp_optim_state.div_(dp_size) # need a div; - # if dp group is [] else: pass res = torch.allclose(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2) - # print(f"device {torch.distributed.get_rank(sharded_optimizer.data_parallel_group)} {key} is_dtensor {tp_is_dtensor} shard_spec {shard_spec} use_zero {use_zero} dp_size {dp_size} tp_size {tp_size}\np_state {p_state[key].shape} \ntp_optim_state {tp_state[key].shape} {tp_state_shape}\ndp res {p_state[key].shape[0] // tp_size % dp_size} Close {res}\n") - # if not res: - # print(f"p_state {p_state[key]}\ntp_optim_state {tp_optim_state}\n") assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2) From 42b8cf5832126c76087206a61a9c2d87b1addf70 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 17 Apr 2024 14:14:27 +0800 Subject: [PATCH 30/35] [feature] uodate Readme; --- docs/source/en/features/distributed_adafactor.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index fa848851f812..1abf7df027e7 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -9,7 +9,7 @@ Author: Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. 
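The memory saving behind Adafactor comes from factoring the second-moment statistics: instead of a full `H x W` matrix of squared-gradient averages, only a row statistic of length `H` and a column statistic of length `W` are kept, and their outer product reconstructs the preconditioner on the fly. The snippet below is a minimal single-device sketch of that reconstruction (the shapes `H`, `W` and the helper name `approx_sq_grad` are illustrative, not part of the API); the distributed optimizer applies the same arithmetic to each tensor-parallel shard and reduces or gathers the row/column statistics across the TP and ZeRO groups where the sharding requires it.

```python
import torch

def approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    # Rebuild the (H, W) second-moment estimate from an (H,) row statistic and a
    # (W,) column statistic, returning its inverse square root as the preconditioner.
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    return r_factor * c_factor

H, W, beta2t = 4096, 4096, 0.999
grad = torch.randn(H, W)
sq = grad.pow(2) + 1e-30                      # squared gradient plus eps, as in Adafactor

exp_avg_sq_row = torch.zeros(H)               # O(H) state instead of O(H * W)
exp_avg_sq_col = torch.zeros(W)               # O(W) state
exp_avg_sq_row.mul_(beta2t).add_(sq.mean(dim=-1), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(sq.mean(dim=-2), alpha=1.0 - beta2t)

# The full-size update is only materialized transiently; it is then RMS-clipped
# and scaled by the learning rate before being applied to the parameter.
update = approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) * grad
```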
-### API Reference +## API Reference {{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} From 510d4c0a94b50680b3420ef2ddfc21af1aad65c4 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 18 Apr 2024 06:40:58 +0000 Subject: [PATCH 31/35] [feature] add LowLevelZeroPlugin test with Bert model zoo; --- .../booster/plugin/low_level_zero_plugin.py | 5 +- .../en/features/distributed_adafactor.md | 3 +- tests/test_optimizer/_utils.py | 15 +++- tests/test_optimizer/test_dist_adafactor.py | 81 ++++++++++++++++--- tests/test_shardformer/test_model/_utils.py | 67 ++++++++++++++- 5 files changed, 156 insertions(+), 15 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 8fc390414484..19faf80b0e81 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -6,6 +6,7 @@ from typing import Callable, Iterator, List, Optional, Tuple import torch +import torch.distributed import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler @@ -328,8 +329,8 @@ def configure( model.update_master_params = MethodType(optimizer.update_master_params, model) # Setup optimizers that require global states if isinstance(optimizer.optim, DistributedOptim): - tp_group = self.tp_group - dp_group = self.dp_group + tp_group = None + dp_group = torch.distributed.distributed_c10d._get_default_group() shard_to_param = optimizer.get_master_to_working_map() is_zero = True optimizer.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index 1abf7df027e7..858a91567afd 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -152,4 +152,5 @@ Model/Feature Compatibility Matrix: - + + diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 5475f9b93cce..ad6749dd8f0e 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -139,6 +139,7 @@ def check_optim_on_bert(optim_class, sharded_optim_class): def check_dist_optim_state(org_optimizer, sharded_optimizer): + torch.set_default_dtype(torch.bfloat16) for group, tp_group in zip(org_optimizer.param_groups, sharded_optimizer.param_groups): for p, tp in zip(group["params"], tp_group["params"]): p_state = org_optimizer.state[p] @@ -226,5 +227,17 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): # need a div; else: pass - res = torch.allclose(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2) + # Sovled a New issus: different dtype; + # So far, only happen in H100 env; + # Seem torch.set_default_dtype(torch.bfloat16) not act on booster.percision; + # Or assert_close just update to check dtype; + if p_state[key].dtype != tp_optim_state.dtype: + tp_optim_state = tp_optim_state.type(p_state[key].dtype) assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2) + + +def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol): + for (org_name, org_param), (sharded_name, sharded_param) in zip(org_model.named_parameters(), sharded_model.named_parameters()): + if org_name in weight_layer_for_check: + # print(f"org_name {org_name} shape {org_param.shape} {org_param}\n sharded_name {sharded_name} shape {sharded_param.shape} {sharded_param}\n") + 
assert_close(org_param, sharded_param, atol=atol, rtol=rtol) \ No newline at end of file diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 883d808711af..9c406852da26 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -32,16 +32,18 @@ from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import run_bert_test, check_dist_optim_state +from tests.test_optimizer._utils import run_bert_test, check_dist_optim_state, check_dist_param, check_optim_states from colossalai.shardformer.layer._operation import _gather from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, + build_model_from_low_level_zero_plugin, check_weight, run_forward_backward_with_hybrid_plugin, + run_forward_backward_with_low_level_zero_plugin, unwrap_model, ) from colossalai.shardformer.layer.utils import Randomizer - +from colossalai.accelerator import get_accelerator HEIGHT = 4 WIDTH = 4 @@ -514,8 +516,66 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int Randomizer.reset_index() torch.cuda.empty_cache() print(f"Booster Test Pass") - - + +@parameterize( + "test_config", + [ + { + "stage": 1, + "precision": "bf16", + }, + { + "stage": 2, + "precision": "bf16", + }, + ], +) +def exam_bert_test_on_lowlevelzero_plugin(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + model_list = [ + "transformers_bert", + "transformers_bert_for_pretraining", + "transformers_bert_lm_head_model", + "transformers_bert_for_masked_lm", + "transformers_bert_for_sequence_classification", + "transformers_bert_for_token_classification", + "transformers_bert_for_next_sentence", + "transformers_bert_for_mcq", + "transformers_bert_for_question_answering", + ] + clear_layout_converter() + torch.set_default_dtype(torch.bfloat16) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name in model_list: + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_low_level_zero_plugin( + model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + # LowLevelZero not need warp + # bert = unwrap_model(org_model, "BertModel", "bert") + # sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + weight_layer_for_check = ["bert.encoder.layer.0.output.dense.weight", "bert.encoder.layer.0.output.dense.weight"] + + org_optimizer.step() + sharded_optimizer.step() + + # check weights + if test_config["precision"] == "bf16": + atol, rtol = 5e-4, 5e-4 + else: + atol, rtol = 5e-4, 5e-4 + + check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol) + check_optim_states(org_optimizer, sharded_optimizer.optim) + + Randomizer.reset_index() + torch.cuda.empty_cache() + print(f"Bert Model Zoo Test Pass") + @parameterize( "test_config", [ @@ -551,7 +611,7 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int }, ], ) -def exam_bert_test(test_config): +def exam_bert_test_on_hybrid_plugin(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") test_config["use_lazy_init"] = False 
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel @@ -568,17 +628,17 @@ def exam_bert_test(test_config): "transformers_bert_for_question_answering", ] clear_layout_converter() + torch.set_default_dtype(torch.bfloat16) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if name in model_list: - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor - ) - + ) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster ) - + stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group @@ -607,7 +667,8 @@ def run_dist(rank, world_size, port): disable_existing_loggers() config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - exam_bert_test() + exam_bert_test_on_lowlevelzero_plugin() + exam_bert_test_on_hybrid_plugin() exam_dist_adafactor_base() exam_dist_adafactor_zero() exam_dist_adafactor_booster() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 4719fa0b0546..2eb365c1a007 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -13,7 +13,7 @@ from torch.testing import assert_close from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager @@ -21,6 +21,7 @@ from colossalai.shardformer._utils import getattr_ from colossalai.shardformer.policies.auto_policy import Policy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.accelerator import get_accelerator def build_model( @@ -137,6 +138,32 @@ def build_model_from_hybrid_plugin( return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster +def build_model_from_low_level_zero_plugin( + model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam +): + use_lazy_init = False + if "use_lazy_init" in test_config: + use_lazy_init = test_config.pop("use_lazy_init") + + ctx = LazyInitContext() if use_lazy_init else nullcontext() + with ctx: + org_model = model_fn() + sharded_model = copy.deepcopy(org_model) + if use_lazy_init: + ctx.materialize(org_model) + + org_model = org_model.cuda() + org_optimizer = optim_class(org_model.parameters(), lr=1e-3) + sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) + criterion = loss_fn + + plugin = LowLevelZeroPlugin(**test_config, max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) + + sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) + return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster + + def run_forward_backward_with_hybrid_plugin( org_model: Module, sharded_model: Module, @@ -195,6 +222,44 @@ def _criterion(outputs, inputs): return org_loss, org_output, sharded_loss, sharded_output +def 
run_forward_backward_with_low_level_zero_plugin( + org_model: Module, + sharded_model: Module, + sharded_optimizer: Optimizer, + data_gen_fn: Callable, + output_transform_fn: Callable, + criterion: Callable, + booster: Booster, +): + device = get_accelerator().get_current_device() + org_model.cuda() + sharded_model.cuda() + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + data = data_gen_fn() + + # data = { + # k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + # } + data = {k: v.cuda() for k, v in data.items()} + + sharded_model.train() + sharded_output = sharded_model(**data) + sharded_loss = criterion(sharded_output) + sharded_optimizer.backward(sharded_loss) + + org_model.train() + org_output = org_model(**data) + org_loss = criterion(org_output) + org_loss.backward() + + return org_loss, org_output, sharded_loss, sharded_output + + def check_output_hidden_state( org_output: Tensor, sharded_output: Tensor, From 0a7f68269077347655d5c2bc284320bac29d1116 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 18 Apr 2024 07:54:23 +0000 Subject: [PATCH 32/35] [fix] fix logic in _rms; --- colossalai/nn/optimizer/distributed_adafactor.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index b4b7b86b98d0..cafdcca316e5 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -128,13 +128,9 @@ def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, dp_grou # reduce tensor_sum from tp_group dist.all_reduce(tensor_sum, group=tp_group) num_of_element = num_of_element * tp_size - if use_zero: - dist.all_reduce(tensor_sum, group=dp_group) - num_of_element = num_of_element * dp_size - else: - if use_zero: - dist.all_reduce(tensor_sum, group=dp_group) - num_of_element = num_of_element * dp_size + if use_zero: + dist.all_reduce(tensor_sum, group=dp_group) + num_of_element = num_of_element * dp_size rms = (tensor_sum / num_of_element).sqrt() return rms From 40c5f51edbf6c9580c8d7017bcb58bc495120c94 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 07:56:24 +0000 Subject: [PATCH 33/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/nn/optimizer/__init__.py | 8 +- colossalai/nn/optimizer/adafactor.py | 2 +- .../nn/optimizer/distributed_adafactor.py | 120 +++++++-------- .../en/features/distributed_adafactor.md | 18 +-- tests/test_optimizer/_utils.py | 50 ++++--- tests/test_optimizer/test_dist_adafactor.py | 141 ++++++++++-------- tests/test_shardformer/test_model/_utils.py | 8 +- 7 files changed, 187 insertions(+), 160 deletions(-) diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 87bd08de17d0..155051f04516 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -1,5 +1,7 @@ +from .adafactor import Adafactor from .came import CAME from .cpu_adam import CPUAdam +from .distributed_adafactor import DistributedAdaFactor from .distributed_came import DistributedCAME from .distributed_lamb import DistributedLamb from .fused_adam import FusedAdam @@ -8,8 +10,6 @@ from .hybrid_adam import HybridAdam from .lamb import Lamb from .lars import 
Lars -from .adafactor import Adafactor -from .distributed_adafactor import DistributedAdaFactor __all__ = [ "FusedLAMB", @@ -22,6 +22,6 @@ "DistributedLamb", "CAME", "DistributedCAME", - "Adafactor", - "DistributedAdaFactor" + "Adafactor", + "DistributedAdaFactor", ] diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py index 57d677ef0059..22a6c8f4d3ce 100644 --- a/colossalai/nn/optimizer/adafactor.py +++ b/colossalai/nn/optimizer/adafactor.py @@ -36,7 +36,7 @@ def __init__( relative_step=True, warmup_init=False, ): - lr=None + lr = None if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py index cafdcca316e5..d0794f450d8a 100644 --- a/colossalai/nn/optimizer/distributed_adafactor.py +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -3,8 +3,8 @@ import torch import torch.distributed as dist -from colossalai.interface.optimizer import DistributedOptim +from colossalai.interface.optimizer import DistributedOptim from colossalai.shardformer.layer._operation import _gather, _split from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor @@ -49,14 +49,14 @@ def __init__( self.data_parallel_group = None self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor} self.use_zero = True - - self.param_is_dtensor_dict = {} # {id(p): True/False} - self.grad_shape_dict = {} # {id(p): master param shape} - self.factored_dict = {} # {id(p): True/False} + + self.param_is_dtensor_dict = {} # {id(p): True/False} + self.grad_shape_dict = {} # {id(p): master param shape} + self.factored_dict = {} # {id(p): True/False} self.use_first_moment_dict = {} # {id(p): True/False} self.shard_spec_dict = {} # {id(p): ShardSpec} super().__init__(params, defaults) - + def setup_distributed( self, tensor_parallel_group: dist.ProcessGroup = None, @@ -82,19 +82,21 @@ def setup_distributed( if self.data_parallel_group is not None: self.data_parallel_size = dist.get_world_size(self.data_parallel_group) self.use_zero = use_zero - + self.shard_to_param = shard_to_param if shard_to_param is not None else {} # grad is None, cause we dont setup now for group in self.param_groups: for p in group["params"]: self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p))) - self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape - self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)]) + self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape + self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options( + group, self.grad_shape_dict[id(p)] + ) if self.param_is_dtensor_dict[id(p)]: self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p))) else: self.shard_spec_dict[id(p)] = None - + @staticmethod def _get_lr(param_group, param_state): rel_step_sz = param_group["lr"] @@ -123,7 +125,7 @@ def _get_options(param_group, param_shape): def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, dp_group): tensor_sum = tensor.pow(2).sum() num_of_element = tensor.numel() - + if param_is_dtensor: # reduce tensor_sum from tp_group dist.all_reduce(tensor_sum, group=tp_group) @@ -147,25 +149,21 @@ def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): r_factor = 
(exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() return torch.mul(r_factor, c_factor) - + def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): if grad_shape[0] % self.data_parallel_size != 0: # gather update[flatten] along dp group then reshape to [H, W/tp] - update = _gather( - input_=update, dim=-1, process_group=self.data_parallel_group - ) + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) update_reshape = update.view(-1, grad_shape[1]) # gather grad[flatten] along dp group then reshape to [H, W/tp] - grad = _gather( - input_=grad, dim=-1, process_group=self.data_parallel_group - ) + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H] exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update_reshape.mul_(grad_reshape) + update_reshape.mul_(grad_reshape) else: update_reshape = update.view(-1, grad_shape[1]) grad_reshape = grad.view(-1, grad_shape[1]) @@ -177,25 +175,21 @@ def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): exp_avg_sq_row.div_(self.tensor_parallel_size) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) - + if self.use_zero: update = update_reshape.view(-1) else: update = update_reshape return update - + def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): if grad_shape[0] % self.data_parallel_size != 0: # gather update[flatten] along dp group then reshape to [H/tp, W] - update = _gather( - input_=update, dim=-1, process_group=self.data_parallel_group - ) + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) # view update to origin[tp] shape update_reshape = update.view(-1, grad_shape[1]) # gather grad[flatten] along dp group then reshape to [H/tp, W] - grad = _gather( - input_=grad, dim=-1, process_group=self.data_parallel_group - ) + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] @@ -221,9 +215,7 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group) exp_avg_sq_col.div_(self.tensor_parallel_size) # gather row - exp_avg_sq_row_gather = _gather( - input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group - ) + exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group) sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) update_reshape.mul_(grad_reshape) @@ -232,24 +224,20 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): else: update = update_reshape return update - + def _base_factor(self, update, grad, state, grad_shape, beta2t): if self.use_zero: # only zero if grad_shape[0] % self.data_parallel_size != 0: - # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) - # row mean no change + # view update to origin shape 
update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) + # row mean no change # col mean need reduce and div # gather update[flatten] along dp group then reshape to [H, W] - update = _gather( - input_=update, dim=-1, process_group=self.data_parallel_group - ) + update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group) # view update to origin[tp] shape update_reshape = update.view(-1, grad_shape[1]) # gather grad[flatten] along dp group then reshape to [H, W] - grad = _gather( - input_=grad, dim=-1, process_group=self.data_parallel_group - ) + grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group) grad_reshape = grad.view(-1, grad_shape[1]) exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] @@ -264,8 +252,8 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): else: # no residual row # view update to origin[tp] shape - update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] - grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] + update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] + grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] exp_avg_sq_col = state["exp_avg_sq_col"] # [W] exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) @@ -275,7 +263,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): exp_avg_sq_col.div_(self.tensor_parallel_size) update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update_reshape.mul_(grad_reshape) - update = update_reshape.view(-1) + update = update_reshape.view(-1) else: # base factor; no tp, no dp exp_avg_sq_row = state["exp_avg_sq_row"] @@ -288,7 +276,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t): update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update.mul_(grad) return update - + @torch.no_grad() def step(self, closure=None): """ @@ -323,7 +311,7 @@ def step(self, closure=None): grad = p.grad if grad.is_sparse: raise RuntimeError("Adafactor does not support sparse gradients.") - + state = self.state[p] grad_shape = self.grad_shape_dict[id(p)] param_is_dtensor = self.param_is_dtensor_dict[id(p)] @@ -343,11 +331,11 @@ def step(self, closure=None): if grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( grad_shape[0], device=p.device, dtype=p.dtype - ) # [H] + ) # [H] else: state["exp_avg_sq_row"] = torch.zeros( - grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp] + grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype + ) # [H/dp] state["exp_avg_sq_col"] = torch.zeros( grad_shape[1], device=p.device, dtype=p.dtype ) # [W/TP] @@ -357,23 +345,27 @@ def step(self, closure=None): if grad_shape[0] % self.data_parallel_size != 0: state["exp_avg_sq_row"] = torch.zeros( grad_shape[0], device=p.device, dtype=p.dtype - ) # [H/tp] + ) # [H/tp] else: state["exp_avg_sq_row"] = torch.zeros( grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype - ) # [H/dp/tp] - + ) # [H/dp/tp] + state["exp_avg_sq_col"] = torch.zeros( - grad_shape[1], device=p.device, dtype=p.dtype - ) # [W] + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] else: if self.use_zero: if grad_shape[0] % self.data_parallel_size != 0: # save all exp_avg_sq_row [H] - state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], 
device=grad.device, dtype=p.dtype + ) else: # exp_avg_sq_row [H // dp] - state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype) + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype + ) else: # exp_avg_sq_row [H] state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) @@ -395,7 +387,7 @@ def step(self, closure=None): lr = self._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) update = (grad**2) + group["eps"][0] - + if factored: if param_is_dtensor: # ============================== @@ -411,16 +403,24 @@ def step(self, closure=None): elif shard_spec.sharding_sequence[-1] == "R": update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t) else: - update = self._base_factor(update, grad, state, grad_shape, beta2t) + update = self._base_factor(update, grad, state, grad_shape, beta2t) else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) # # (Line No.8) RMS - rms = self._rms(update, param_is_dtensor, self.use_zero,self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group) + rms = self._rms( + update, + param_is_dtensor, + self.use_zero, + self.tensor_parallel_size, + self.data_parallel_size, + self.tensor_parallel_group, + self.data_parallel_group, + ) update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) - + update.mul_(lr) if use_first_moment: exp_avg = state["exp_avg"] @@ -429,7 +429,7 @@ def step(self, closure=None): if group["weight_decay"] != 0: p.add_(p, alpha=(-group["weight_decay"] * lr)) - + p.add_(-update) - + return loss diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md index 858a91567afd..8d3691177ad6 100644 --- a/docs/source/en/features/distributed_adafactor.md +++ b/docs/source/en/features/distributed_adafactor.md @@ -1,20 +1,20 @@ # Distributed Adafactor -Author: +Author: **Related Paper** - [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) ## Introduction -Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. +Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. ## API Reference {{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} ## Hands-On Practice -We now demonstrate how to start Distributed Adafactor with booster API. +We now demonstrate how to start Distributed Adafactor with booster API. ### step 1. 
Import libraries ```python @@ -59,9 +59,9 @@ dist_optim = DistributedAdaFactor(model.parameters()) ```python plugin = LowLevelZeroPlugin() booster = Booster(plugin=plugin) -model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) +model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) ``` -### step 5.Train Your Model +### step 5.Train Your Model ```python for epoch in range(max_epochs): for input_ids, attention_mask in dataloader: @@ -111,7 +111,7 @@ Model/Feature Compatibility Matrix: ✔️ ✔️ ✔️ - + Torch DDP
Plugin ✔️ @@ -123,7 +123,7 @@ Model/Feature Compatibility Matrix: ✔️ ✔️ ✔️ - + Gemini
Plugin ❌ @@ -135,7 +135,7 @@ Model/Feature Compatibility Matrix: ❌ ❌ ❌ - + Moe Hybrid
Plugin ❌ @@ -147,7 +147,7 @@ Model/Feature Compatibility Matrix: ❌ ❌ ❌ - + diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index ad6749dd8f0e..75b57db134ec 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -3,6 +3,7 @@ from torch.testing import assert_close import colossalai +from colossalai.shardformer.layer._operation import _gather from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import parameterize, spawn @@ -13,7 +14,7 @@ run_forward_backward_with_hybrid_plugin, unwrap_model, ) -from colossalai.shardformer.layer._operation import _gather + def check_optim_states(org_optim, sharded_optim): for group in org_optim.param_groups: @@ -152,10 +153,13 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): use_zero = sharded_optimizer.use_zero tp_optim_state = tp_state[key] p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape - dp_size, tp_size = sharded_optimizer.data_parallel_size, sharded_optimizer.tensor_parallel_size, - # we start init model with first tensor parallel then zero; + dp_size, tp_size = ( + sharded_optimizer.data_parallel_size, + sharded_optimizer.tensor_parallel_size, + ) + # we start init model with first tensor parallel then zero; # So, we gather model with first zero then tensor parallel - + if tp_is_dtensor: # col parallel if shard_spec.sharding_sequence[0] == "R": @@ -163,9 +167,11 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): # sq_row need gather alone dp group if key == "exp_avg_sq_row": tp_optim_state = _gather( - input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.data_parallel_group + input_=tp_optim_state, + dim=-1, + process_group=sharded_optimizer.data_parallel_group, ) - tp_state_shape = tp_optim_state.shape + tp_optim_state.shape # sq_col don't need gather alone dp group if key == "exp_avg_sq_col": pass @@ -180,8 +186,8 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): tp_optim_state = _gather( input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group ) - tp_state_shape = tp_optim_state.shape - + tp_optim_state.shape + # row parallel if shard_spec.sharding_sequence[-1] == "R": if use_zero: @@ -191,9 +197,11 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): pass else: tp_optim_state = _gather( - input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.data_parallel_group + input_=tp_optim_state, + dim=-1, + process_group=sharded_optimizer.data_parallel_group, ) - tp_state_shape = tp_optim_state.shape + tp_optim_state.shape # sq_col don't need gather alone dp group if key == "exp_avg_sq_col": pass @@ -205,7 +213,7 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): tp_optim_state = _gather( input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tensor_parallel_group ) - tp_state_shape = tp_optim_state.shape + tp_optim_state.shape # sq_col don't need gather alone dp group if key == "exp_avg_sq_col": pass @@ -218,18 +226,20 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): pass else: tp_optim_state = _gather( - input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.data_parallel_group + input_=tp_optim_state, + dim=-1, + process_group=sharded_optimizer.data_parallel_group, ) - tp_state_shape = tp_optim_state.shape + tp_optim_state.shape # sq_col don't need gather alone dp group if key == "exp_avg_sq_col": tp_optim_state = 
tp_optim_state.div_(dp_size) - # need a div; + # need a div; else: pass - # Sovled a New issus: different dtype; - # So far, only happen in H100 env; - # Seem torch.set_default_dtype(torch.bfloat16) not act on booster.percision; + # Sovled a New issus: different dtype; + # So far, only happen in H100 env; + # Seem torch.set_default_dtype(torch.bfloat16) not act on booster.percision; # Or assert_close just update to check dtype; if p_state[key].dtype != tp_optim_state.dtype: tp_optim_state = tp_optim_state.type(p_state[key].dtype) @@ -237,7 +247,9 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer): def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol): - for (org_name, org_param), (sharded_name, sharded_param) in zip(org_model.named_parameters(), sharded_model.named_parameters()): + for (org_name, org_param), (sharded_name, sharded_param) in zip( + org_model.named_parameters(), sharded_model.named_parameters() + ): if org_name in weight_layer_for_check: # print(f"org_name {org_name} shape {org_param.shape} {org_param}\n sharded_name {sharded_name} shape {sharded_param.shape} {sharded_param}\n") - assert_close(org_param, sharded_param, atol=atol, rtol=rtol) \ No newline at end of file + assert_close(org_param, sharded_param, atol=atol, rtol=rtol) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 9c406852da26..4f3359e9a971 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -1,5 +1,4 @@ import copy -import os import pytest import torch @@ -9,19 +8,19 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.logging import disable_existing_loggers +from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import ProcessGroupMesh -from colossalai.device.device_mesh import DeviceMesh +from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer.adafactor import Adafactor from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.shardformer.layer._operation import _gather +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor import ( distribute_tensor, + get_device_mesh, get_layout, get_sharding_spec, - get_device_mesh, is_distributed_tensor, shard_colwise, shard_rowwise, @@ -32,8 +31,7 @@ from colossalai.utils import set_seed from colossalai.zero import LowLevelZeroOptimizer from tests.kit.model_zoo import model_zoo -from tests.test_optimizer._utils import run_bert_test, check_dist_optim_state, check_dist_param, check_optim_states -from colossalai.shardformer.layer._operation import _gather +from tests.test_optimizer._utils import check_dist_optim_state, check_dist_param, check_optim_states from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, build_model_from_low_level_zero_plugin, @@ -42,8 +40,6 @@ run_forward_backward_with_low_level_zero_plugin, unwrap_model, ) -from colossalai.shardformer.layer.utils import Randomizer -from colossalai.accelerator import get_accelerator HEIGHT = 4 WIDTH = 4 @@ -144,10 +140,10 @@ def set_dist_grad( def set_master_param_to_shard_param(master_param_list) -> dict: - master_param_to_shard_param ={id(p):p for p in master_param_list} + master_param_to_shard_param = {id(p): p for 
p in master_param_list} return master_param_to_shard_param - - + + class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() @@ -159,6 +155,7 @@ def forward(self, x): x = self.linear2(x) return x + class TPModel(nn.Module): def __init__(self, linear1, linear2, tp_group=None): super().__init__() @@ -174,7 +171,7 @@ def forward(self, x): @parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16 -@parameterize("tp_zero_size", [(4, 1)]) +@parameterize("tp_zero_size", [(4, 1)]) def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size local_rank = dist.get_rank() @@ -192,7 +189,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): H, W = HEIGHT, WIDTH model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight weight, bias = model_col.weight, model_col.bias - + # ============================== # Col Parallel # ============================== @@ -212,7 +209,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): weight_row_shard.clone().flatten().requires_grad_(True) ) # flatten input(not dtensor) to optimizer bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True)) - + # base_param_group = setup_param_groups([weight, bias]) # cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten]) # rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten]) @@ -225,7 +222,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): optimizer_base = Adafactor([weight, bias]) cp_dist_optim = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten]) rp_dist_optim = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten]) - + shard_to_param_cp = set_master_param_to_shard_param([weight_col_shard_flatten, bias_col_flatten]) cp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, @@ -233,7 +230,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): shard_to_param=shard_to_param_cp, use_zero=use_zero, ) - + shard_to_param_rp = set_master_param_to_shard_param([weight_row_shard_flatten, bias_row_flatten]) rp_dist_optim.setup_distributed( tensor_parallel_group=tp_group, @@ -242,7 +239,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): use_zero=use_zero, ) - N_STEPS = 1 for _ in range(N_STEPS): # base step @@ -254,7 +250,9 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # col parallel step cp_dist_optim.zero_grad() weight_col_shard_flatten.grad = ( - distribute_tensor(weight.grad, get_device_mesh(weight_col_shard), weight_col_shard_shard_spec).clone().flatten() + distribute_tensor(weight.grad, get_device_mesh(weight_col_shard), weight_col_shard_shard_spec) + .clone() + .flatten() ) bias_col_flatten.grad = bias.grad.clone().flatten() cp_dist_optim.step() @@ -262,7 +260,9 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]): # row parallel step rp_dist_optim.zero_grad() weight_row_shard_flatten.grad = ( - distribute_tensor(weight.grad, get_device_mesh(weight_row_shard), weight_row_shard_shard_spec).clone().flatten() + distribute_tensor(weight.grad, get_device_mesh(weight_row_shard), weight_row_shard_shard_spec) + .clone() + .flatten() ) bias_row_flatten.grad = bias.grad.clone().flatten() rp_dist_optim.step() @@ -273,15 +273,13 @@ def exam_dist_adafactor_base(dtype: 
torch.dtype, tp_zero_size: tuple[int, int]): dim=-1, process_group=tp_group, ) # gather - weight_row_gather = _gather( - input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group - ).view( + weight_row_gather = _gather(input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group).view( -1, W ) # gather # verify - col_correct = correctness_verify(weight.data, weight_col_gather.data, dtype) - row_correct = correctness_verify(weight.data, weight_row_gather.data, dtype) + correctness_verify(weight.data, weight_col_gather.data, dtype) + correctness_verify(weight.data, weight_row_gather.data, dtype) print(f"Base Test Pass") @@ -292,7 +290,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False local_rank = dist.get_rank() - + clear_layout_converter() proc_mesh = ProcessGroupMesh(tp_size, zero_size) @@ -390,20 +388,20 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]): else: # No TP bias pass - correctness = correctness_verify(p.data, tp_p.data, dtype) + correctness_verify(p.data, tp_p.data, dtype) clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() print(f"Zero Test Pass") - - + + @parameterize("dtype", [torch.float16]) @parameterize("tp_zero_size", [(1, 4)]) def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]): tp_size, zero_size = tp_zero_size use_zero = True if zero_size > 1 else False local_rank = dist.get_rank() - + clear_layout_converter() proc_mesh = ProcessGroupMesh(tp_size, zero_size) @@ -464,14 +462,14 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int shard_to_param=shard_to_param, use_zero=use_zero, ) - + # ============================== # Booster Init # ============================== plugin = LowLevelZeroPlugin() booster = Booster(plugin=plugin) criterion = lambda x: x.mean() - + tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion) # ============================== @@ -512,11 +510,12 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int else: # No TP bias pass - correctness = correctness_verify(p.data, tp_p.data, dtype) + correctness_verify(p.data, tp_p.data, dtype) Randomizer.reset_index() - torch.cuda.empty_cache() - print(f"Booster Test Pass") - + torch.cuda.empty_cache() + print(f"Booster Test Pass") + + @parameterize( "test_config", [ @@ -529,13 +528,13 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int "precision": "bf16", }, ], -) +) def exam_bert_test_on_lowlevelzero_plugin(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") model_list = [ "transformers_bert", - "transformers_bert_for_pretraining", - "transformers_bert_lm_head_model", + "transformers_bert_for_pretraining", + "transformers_bert_lm_head_model", "transformers_bert_for_masked_lm", "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification", @@ -547,35 +546,44 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config): torch.set_default_dtype(torch.bfloat16) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if name in model_list: - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_low_level_zero_plugin( - model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor - ) - + ( + org_model, + org_optimizer, + sharded_model, + 
sharded_optimizer, + criterion, + booster, + ) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster ) - + # LowLevelZero not need warp # bert = unwrap_model(org_model, "BertModel", "bert") # sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") - weight_layer_for_check = ["bert.encoder.layer.0.output.dense.weight", "bert.encoder.layer.0.output.dense.weight"] - + weight_layer_for_check = [ + "bert.encoder.layer.0.output.dense.weight", + "bert.encoder.layer.0.output.dense.weight", + ] + org_optimizer.step() sharded_optimizer.step() - + # check weights if test_config["precision"] == "bf16": atol, rtol = 5e-4, 5e-4 else: atol, rtol = 5e-4, 5e-4 - + check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol) - check_optim_states(org_optimizer, sharded_optimizer.optim) + check_optim_states(org_optimizer, sharded_optimizer.optim) Randomizer.reset_index() torch.cuda.empty_cache() - print(f"Bert Model Zoo Test Pass") - + print(f"Bert Model Zoo Test Pass") + + @parameterize( "test_config", [ @@ -618,8 +626,8 @@ def exam_bert_test_on_hybrid_plugin(test_config): test_config["initial_scale"] = 2**16 # avoid overflow model_list = [ "transformers_bert", - "transformers_bert_for_pretraining", - "transformers_bert_lm_head_model", + "transformers_bert_for_pretraining", + "transformers_bert_lm_head_model", "transformers_bert_for_masked_lm", "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification", @@ -631,24 +639,29 @@ def exam_bert_test_on_hybrid_plugin(test_config): torch.set_default_dtype(torch.bfloat16) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if name in model_list: - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( - model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor - ) - + ( + org_model, + org_optimizer, + sharded_model, + sharded_optimizer, + criterion, + booster, + ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster ) - + stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group bert = unwrap_model(org_model, "BertModel", "bert") sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] - + org_optimizer.step() sharded_optimizer.step() - + # check weights if test_config["precision"] == "bf16": atol, rtol = 5e-4, 5e-4 @@ -657,12 +670,13 @@ def exam_bert_test_on_hybrid_plugin(test_config): if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) # check optim states - check_dist_optim_state(org_optimizer, sharded_optimizer.optim) + check_dist_optim_state(org_optimizer, sharded_optimizer.optim) Randomizer.reset_index() torch.cuda.empty_cache() - print(f"Bert Model Zoo Test Pass") - + print(f"Bert Model Zoo Test Pass") + + def run_dist(rank, world_size, port): disable_existing_loggers() config 
= {} @@ -673,6 +687,7 @@ def run_dist(rank, world_size, port): exam_dist_adafactor_zero() exam_dist_adafactor_booster() + @pytest.mark.dist @rerun_if_address_is_in_use() def test_dist_adafactor(): diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 2eb365c1a007..4c46e98f174e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -12,6 +12,7 @@ from torch.optim import Adam, Optimizer from torch.testing import assert_close +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule @@ -21,7 +22,6 @@ from colossalai.shardformer._utils import getattr_ from colossalai.shardformer.policies.auto_policy import Policy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.accelerator import get_accelerator def build_model( @@ -231,7 +231,7 @@ def run_forward_backward_with_low_level_zero_plugin( criterion: Callable, booster: Booster, ): - device = get_accelerator().get_current_device() + get_accelerator().get_current_device() org_model.cuda() sharded_model.cuda() @@ -246,7 +246,7 @@ def _criterion(outputs, inputs): # k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() # } data = {k: v.cuda() for k, v in data.items()} - + sharded_model.train() sharded_output = sharded_model(**data) sharded_loss = criterion(sharded_output) @@ -256,7 +256,7 @@ def _criterion(outputs, inputs): org_output = org_model(**data) org_loss = criterion(org_output) org_loss.backward() - + return org_loss, org_output, sharded_loss, sharded_output From 91310e9548d9ade918f357a4978768eea7d02045 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 18 Apr 2024 09:33:46 +0000 Subject: [PATCH 34/35] [fix] remove comments in testcase; --- tests/test_optimizer/test_dist_adafactor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index 4f3359e9a971..237851a90f6c 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -447,7 +447,6 @@ def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int verbose=True, ) shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened - # print(f"shard_to_param {shard_to_param}") dist_optim.optim.setup_distributed( tensor_parallel_group=tp_group, data_parallel_group=dp_group, From 3046daf4f68e2cb5182a9d7985cdb4375288d8c3 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 29 Apr 2024 10:58:45 +0000 Subject: [PATCH 35/35] [feature] add zh-Han Readme; --- .../zh-Hans/features/distributed_adafactor.md | 155 ++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 docs/source/zh-Hans/features/distributed_adafactor.md diff --git a/docs/source/zh-Hans/features/distributed_adafactor.md b/docs/source/zh-Hans/features/distributed_adafactor.md new file mode 100644 index 000000000000..19610a85c8c1 --- /dev/null +++ b/docs/source/zh-Hans/features/distributed_adafactor.md @@ -0,0 +1,155 @@ +# 分布式 Adafactor + +作者: + +**相关论文** +- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) + +## 简介 + +分布式 Adafactor 
是一种支持混合优化的优化器,包括 1D 张量并行和 ZerO。它通过合理的任务并行化充分利用了计算资源,提高了训练效率和速度,并减少了存储压力。它应用广泛,目前支持一系列基于 Transformer 的模型,详见 [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo). + +## API接口 + +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} + +## 实例演示 +现在我们演示如何使用 Booster API 启动分布式 Adafactor。 +### 步骤 1. 导入相关库 + +```python +import torch +from torch import nn +import torch.distributed as dist +from transformers import LlamaModel, LlamaConfig + +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor +from colossal_llama2.dataset.loader import load_tokenized_dataset +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +``` + +### 步骤 2. 初始化分布式环境和参数 +然后,我们需要初始化分布式环境。为了演示的目的,我们使用了 `colossalai.launch`。您可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) 获得其他的初始化方法。这里, 我们使用 "ProcessGroupMesh"来创建张量并行组和数据并行组。 + +```python +# Distributed Enviroment +config = {} +colossalai.launch(config=config, rank=rank, world_size=world_size,host="localhost", port=port, backend="nccl") +``` + +### 步骤 3.初始化模块和优化器 +Build our model. We created an MLP using two Linear Layer. + +```python +# Init Llama from huggingface +configuration = LlamaConfig() +model = LlamaModel(configuration) +dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") +dataloader = plugin.prepare_dataloader(dataset, batch_size=8) +criterion = lambda x: x.mean() +dist_optim = DistributedAdaFactor(model.parameters()) + +``` + +### 步骤 4.初始化Booster + +```python +plugin = LowLevelZeroPlugin() +booster = Booster(plugin=plugin) +model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) +``` +### 步骤 5.训练模型 +```python +for epoch in range(max_epochs): + for input_ids, attention_mask in dataloader: + outputs = model(input_ids.cuda(), attention_mask.cuda()) + loss = criterion(outputs.logits, input_ids) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() +``` + +## 支持信息 +模型/功能兼容性矩阵: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model/FeatureTransformers
Bert
Transformers Bert
For Pretraining
Transformers Bert
Lm Head Model
Transformers Bert
For Masked Lm
Transformers Bert
For Sequence Classification
Transformers Bert
For Token Classification
Transformers Bert
For Next Sentence
Transformers Bert
For Multiple-choice Question
Transformers Bert
For Question Answering
Hybrid Parallel
Plugin
✔️✔️✔️✔️✔️✔️✔️✔️✔️
Low Level Zero
Plugin
✔️✔️✔️✔️✔️✔️✔️✔️✔️
Torch DDP
Plugin
✔️✔️✔️✔️✔️✔️✔️✔️✔️
Gemini
Plugin
Moe Hybrid
Plugin
+ +
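As a practical note on the compatibility matrices above: DistributedAdaFactor is only expected to work with the hybrid-parallel, low-level ZeRO, and Torch DDP plugins, so plugin selection can be gated on the parallel layout. The sketch below is one illustrative way to express that choice for two of the supported plugins; the `tp_size`, `pp_size`, `zero_stage`, `stage`, and `precision` constructor arguments are assumed from the configurations used elsewhere in this patch and this helper is not part of the library API.

```python
# Minimal sketch (not library code): pick a plugin the matrices above mark as
# supported for DistributedAdaFactor. GeminiPlugin and MoeHybridPlugin are not supported.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin

def build_booster(tp_size: int = 1, zero_stage: int = 1, precision: str = "bf16") -> Booster:
    if tp_size > 1:
        # Tensor parallelism requires the hybrid parallel plugin.
        plugin = HybridParallelPlugin(tp_size=tp_size, pp_size=1, zero_stage=zero_stage, precision=precision)
    else:
        # Pure ZeRO training can use the lighter low-level zero plugin.
        plugin = LowLevelZeroPlugin(stage=zero_stage, precision=precision)
    return Booster(plugin=plugin)
```

With either choice, `booster.boost(model, DistributedAdaFactor(model.parameters()), ...)` is what triggers the optimizer's `setup_distributed` call, as the low-level zero plugin change in this patch shows.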