From 0a51319113d57f69555454facfe4e01b09915648 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 16 Aug 2024 10:13:07 +0800
Subject: [PATCH] [fp8] zero support fp8 linear. (#6006)

* fix

* fix

* fix

* zero fp8

* zero fp8

* Update requirements.txt
---
 .../booster/plugin/low_level_zero_plugin.py  | 19 ++++++++++++++++---
 examples/language/llama/benchmark.py         |  1 -
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 64f264f7eba1..63d46f6f8a2e 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -35,6 +35,7 @@
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
@@ -62,7 +63,9 @@ class OptimizerParamCheckState(enum.Enum):
 
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
+    def __init__(
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
+    ) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -74,11 +77,16 @@ def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool =
         module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
+        self.use_fp8 = use_fp8
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
@@ -335,6 +343,7 @@ def __init__(
         master_weights: bool = True,
         verbose: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -362,6 +371,7 @@ def __init__(
         )
         self.lora_enabled = False
         self.verbose = verbose
+        self.use_fp8 = use_fp8
 
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -476,7 +486,10 @@ def configure(
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+                model,
+                self.precision,
+                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                use_fp8=self.use_fp8,
             )
 
         # TODO: Support Galore + ZeRO
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 07583161b6fb..21d081145cd9 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -259,7 +259,6 @@ def empty_init():
         if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
         else nullcontext()
     )
-    init_kwargs = {}
     if config.model_type == "chatglm":
         init_kwargs["empty_init"] = False
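
Usage note: below is a minimal sketch of how the new use_fp8 flag might be enabled, assuming the standard Booster workflow from the ColossalAI examples, a torchrun launch, and FP8-capable GPUs. The toy model, optimizer, and tensors are placeholders for illustration and are not part of this patch.

# Hypothetical example (not part of this patch): ZeRO stage 2 with FP8 linear layers.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Initialize the distributed environment (expects a torchrun launch).
colossalai.launch_from_torch()

# use_fp8=True makes LowLevelZeroModel register FP8Hook as a param op hook,
# so linear-layer matmuls can run in FP8 while ZeRO stage 2 sharding and the
# bf16 training precision stay unchanged.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", use_fp8=True)
booster = Booster(plugin=plugin)

# Placeholder model, optimizer, and loss.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# boost() wraps the model in LowLevelZeroModel (which attaches the FP8 hook)
# and the optimizer in LowLevelZeroOptimizer.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# One toy training step with random data.
device = torch.cuda.current_device()
x = torch.randn(8, 1024, device=device, dtype=torch.bfloat16)
y = torch.randn(8, 1024, device=device, dtype=torch.bfloat16)
loss = criterion(model(x), y)
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()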