
[optim] Distributed Adafactor #5484

Merged
merged 36 commits on Apr 30, 2024
Commits (36)
f77ffc2
[feature] solve conflict; update optimizer readme;
duanjunwen Apr 8, 2024
b75ac58
[feature] update optimize readme;
duanjunwen Apr 8, 2024
d5f72fe
[fix] fix testcase;
duanjunwen Apr 8, 2024
020ed54
[feature] Add transformer-bert to testcase;solve a bug related to ind…
duanjunwen Apr 9, 2024
ce58adf
[feature] Add transformers_bert model zoo in testcase;
duanjunwen Apr 9, 2024
efac2a1
[feature] add user documentation to docs/source/feature.
duanjunwen Apr 10, 2024
40a5528
[feature] add API Reference & Sample to optimizer Readme; add state c…
duanjunwen Apr 10, 2024
1c9bb93
[feature] modify user documentation;
duanjunwen Apr 10, 2024
1039f34
[fix] fix readme format issue;
duanjunwen Apr 10, 2024
2ffca49
[fix] add zero=0 in testcase; cached augment in dict;
duanjunwen Apr 10, 2024
0fd62a0
[fix] fix percision issue;
duanjunwen Apr 11, 2024
28c3a40
[feature] add distributed rms;
duanjunwen Apr 11, 2024
a9c5bf7
[feature] remove useless comment in testcase;
duanjunwen Apr 11, 2024
150ac19
[fix] Remove useless test; open zero test; remove fp16 test in bert e…
duanjunwen Apr 11, 2024
e783599
[feature] Extract distributed rms function;
duanjunwen Apr 11, 2024
9d33a34
[feature] add booster + lowlevelzeroPlugin in test;
duanjunwen Apr 11, 2024
419c1c0
[feature] add Start_with_booster_API case in md; add Supporting Infor…
duanjunwen Apr 12, 2024
c84fb52
[fix] Also remove state movement in base adafactor;
duanjunwen Apr 12, 2024
2eb069d
[feature] extract factor function;
duanjunwen Apr 12, 2024
6303291
[feature] add LowLevelZeroPlugin test;
duanjunwen Apr 12, 2024
60489ab
[fix] add tp=False and zero=True in logic;
duanjunwen Apr 12, 2024
02ea83e
[fix] fix use zero logic;
duanjunwen Apr 13, 2024
fb14125
[feature] add row residue logic in column parallel factor;
duanjunwen Apr 14, 2024
2dc0341
[feature] add check optim state func;
duanjunwen Apr 15, 2024
3168a59
[feature] Remove duplicate logic;
duanjunwen Apr 15, 2024
3bca491
[feature] update optim state check func and percision test bug;
duanjunwen Apr 15, 2024
1357dd1
[fix] update/fix optim state; Still exist percision issue;
duanjunwen Apr 16, 2024
1038b23
[fix] Add use_zero check in _rms; Add plugin support info in Readme; …
duanjunwen Apr 17, 2024
2c92350
Merge branch 'feature/dist-optim-upstream' into dist_adafactor
duanjunwen Apr 17, 2024
87746ec
[feature] removed print & comments in utils;
duanjunwen Apr 17, 2024
42b8cf5
[feature] uodate Readme;
duanjunwen Apr 17, 2024
510d4c0
[feature] add LowLevelZeroPlugin test with Bert model zoo;
duanjunwen Apr 18, 2024
0a7f682
[fix] fix logic in _rms;
duanjunwen Apr 18, 2024
40c5f51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2024
91310e9
[fix] remove comments in testcase;
duanjunwen Apr 18, 2024
3046daf
[feature] add zh-Han Readme;
duanjunwen Apr 29, 2024
5 changes: 3 additions & 2 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -6,6 +6,7 @@
from typing import Callable, Iterator, List, Optional, Tuple

import torch
import torch.distributed
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -328,8 +329,8 @@ def configure(
model.update_master_params = MethodType(optimizer.update_master_params, model)
# Setup optimizers that require global states
if isinstance(optimizer.optim, DistributedOptim):
- tp_group = self.tp_group
- dp_group = self.dp_group
+ tp_group = None
+ dp_group = torch.distributed.distributed_c10d._get_default_group()
shard_to_param = optimizer.get_master_to_working_map()
is_zero = True
optimizer.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero)
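With this change, when the wrapped optimizer is a DistributedOptim (such as the new DistributedAdaFactor), the plugin passes no tensor-parallel group, uses the default process group as the data-parallel group, and hands over the master-to-working parameter map with is_zero=True. A minimal usage sketch under ZeRO follows (hypothetical model, data, and hyperparameters; it assumes DistributedAdaFactor accepts the same constructor arguments as the Adafactor class added below, and that the script is launched with torchrun):

import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import DistributedAdaFactor

# Initialize torch.distributed; the config argument may differ between Colossal-AI versions.
colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024).cuda()
optimizer = DistributedAdaFactor(model.parameters())  # assumed to mirror Adafactor's arguments

# booster.boost() invokes the configure() hook above, which calls
# optimizer.optim.setup_distributed(None, default_group, shard_to_param, is_zero=True).
plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# One training step; ZeRO shards gradients and optimizer states across the default group.
x = torch.randn(8, 1024, device="cuda")
loss = model(x).sum()
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()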
4 changes: 4 additions & 0 deletions colossalai/nn/optimizer/__init__.py
@@ -1,5 +1,7 @@
from .adafactor import Adafactor
from .came import CAME
from .cpu_adam import CPUAdam
from .distributed_adafactor import DistributedAdaFactor
from .distributed_came import DistributedCAME
from .distributed_lamb import DistributedLamb
from .fused_adam import FusedAdam
@@ -20,4 +22,6 @@
"DistributedLamb",
"CAME",
"DistributedCAME",
"Adafactor",
"DistributedAdaFactor",
]
201 changes: 201 additions & 0 deletions colossalai/nn/optimizer/adafactor.py
@@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
from torch.optim import Optimizer

__all__ = ["Adafactor"]


# Adafactor
class Adafactor(Optimizer):
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
scale_parameter=True,
relative_step=True,
warmup_init=False,
):
lr = None  # NOTE: any user-provided lr is discarded here; the step size comes from _get_lr()
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
raise ValueError("`warmup_init=True` requires `relative_step=True`")

defaults = {
"lr": lr,
"eps": eps,
"clip_threshold": clip_threshold,
"decay_rate": decay_rate,
"beta1": beta1,
"weight_decay": weight_decay,
"scale_parameter": scale_parameter,
"relative_step": relative_step,
"warmup_init": warmup_init,
}
super().__init__(params, defaults)

@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
if param_group["relative_step"]:
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
param_scale = 1.0
if param_group["scale_parameter"]:
param_scale = max(param_group["eps"][1], param_state["RMS"])
return param_scale * rel_step_sz
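    # Worked example (hypothetical values): with the defaults relative_step=True,
    # warmup_init=False, scale_parameter=True, eps=(1e-30, 1e-3) and step=400,
    # rel_step_sz = min(1e-2, 1 / sqrt(400)) = 1e-2. Because state["RMS"] is left at 0
    # in step() below, param_scale = max(1e-3, 0) = 1e-3, so the returned lr is 1e-5.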

@staticmethod
def _get_options(param_group, param_shape):
factored = len(param_shape) >= 2
use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment

@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
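    # i.e. sqrt(mean(tensor ** 2)); used in step() to clip the update against clip_threshold.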

@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
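    # The factored estimate rebuilds an (n x m) second-moment matrix from an n-vector of
    # row statistics and an m-vector of column statistics:
    #     V_hat[i, j] ~= exp_avg_sq_row[i] * exp_avg_sq_col[j] / mean(exp_avg_sq_row),
    # and returns its elementwise rsqrt, so storage is O(n + m) instead of O(n * m).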

@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step

Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

"""
param_groups: Dict
{
"params":[weight, bias]
"lr"
"eps"
"clip_threshold"
"decay_rate"
"beta1"
"weight_decay"
"scale_parameter"
"relative_step"
"warmup_init"
}
"""

for group in self.param_groups:
# update weight & bias
for p in group["params"]:
if p.grad is None:
continue
"""
# grad shape is same as weigh / bias
"""
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")

"""
p is weight
state
{'step',
'exp_avg_sq_row',
'exp_avg_sq_col',
'RMS'
}

p is bias
state
{'step',
'exp_avg_sq',
'RMS'
}
"""

state = self.state[p]
grad_shape = grad.shape

factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)

state["RMS"] = 0
else:
    # State tensors already exist; the reassignments below leave them unchanged.
if use_first_moment:
state["exp_avg"] = state["exp_avg"]
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"]
state["exp_avg_sq_col"] = state["exp_avg_sq_col"]
else:
state["exp_avg_sq"] = state["exp_avg_sq"]

state["step"] += 1
# state["RMS"] = self._rms(p_data_fp32)
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
# Exponential moving average of the squared gradient, averaged over the last dim (one value per row)
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
# Exponential moving average of the squared gradient, averaged over dim -2 (one value per column)
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
# Clip the update: rescale so its root-mean-square does not exceed clip_threshold
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)

if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg

if group["weight_decay"] != 0:
p.add_(p, alpha=(-group["weight_decay"] * lr))
p.add_(-update)

return loss
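For reference, a minimal single-process sketch of the Adafactor class above (toy model and data, not part of this diff):

import torch
import torch.nn as nn
from colossalai.nn.optimizer import Adafactor

model = nn.Linear(128, 64)
# Defaults from the constructor above: relative step sizing, parameter scaling,
# no first moment (beta1=None), and no weight decay.
optimizer = Adafactor(model.parameters())

for _ in range(5):
    x = torch.randn(32, 128)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()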