From d63479553caff2e69441733c840064e3df378e05 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Sun, 29 Sep 2024 08:33:55 +0000 Subject: [PATCH 1/5] [feat] zerobubble support moehybridplugin; --- .../naive_amp/mixed_precision_mixin/base.py | 2 +- .../naive_amp/mixed_precision_optimizer.py | 13 +- .../booster/mixed_precision/fp16_torch.py | 4 +- .../booster/plugin/hybrid_parallel_plugin.py | 63 ++-- .../plugin/moe_hybrid_parallel_plugin.py | 17 +- colossalai/pipeline/stage_manager.py | 6 +- colossalai/shardformer/policies/mixtral.py | 28 +- .../test_schedule/test_zerobubble_pp.py | 269 +++++++++++------- 8 files changed, 250 insertions(+), 152 deletions(-) diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py index fc7e0b74179a..b2ba47f6762d 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -43,7 +43,7 @@ def zero_grad(self): dtype: torch.dtype @abstractmethod - def pre_backward(self, loss: Tensor) -> Tensor: + def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor: """Called before backward. Args: diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 9e07bdebf8fa..8fb56aee4fce 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -85,13 +85,18 @@ def __init__( master_params.append(master_p) group["params"] = master_params - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): loss = self.mixed_precision.pre_backward(loss) - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) - tensor.backward(grad) + torch.autograd.backward( + tensors=tensor, + grad_tensors=grad, + inputs=inputs, + retain_graph=retain_graph, + ) def zero_grad(self, *args, **kwargs): for p in self.working_to_master_map.keys(): diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index c757a878d97a..a85d9f808546 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -46,9 +46,9 @@ def __init__( growth_interval=growth_interval, ) - def backward(self, loss: Tensor, *args, **kwargs) -> None: + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None: scaled_loss = self.scale_loss(loss) - scaled_loss.backward(*args, **kwargs) + scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def step(self, *args, **kwargs) -> Optional[float]: out = self.scaler.step(self.optim, *args, **kwargs) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1b3b765c2ff0..5d114ab9c315 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -28,7 +28,7 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule 
import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer @@ -288,7 +288,7 @@ def __init__( self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 super().__init__(optim) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -306,7 +306,7 @@ def backward(self, loss: Tensor, *args, **kwargs): """ # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -315,7 +315,7 @@ def backward(self, loss: Tensor, *args, **kwargs): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -332,7 +332,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -512,7 +512,7 @@ def __init__( max_norm=max_norm, ) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -529,7 +529,7 @@ def backward(self, loss: Tensor, *args, **kwargs): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -538,7 +538,7 @@ def backward(self, loss: Tensor, *args, **kwargs): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -554,7 +554,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): None """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. 
@@ -768,7 +768,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: else: return - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): """ Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -784,7 +784,7 @@ def backward(self, loss, retain_graph=False): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, retain_graph) + super().backward(loss, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -793,7 +793,7 @@ def backward(self, loss, retain_graph=False): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -809,7 +809,7 @@ def backward_by_grad(self, tensor, grad): None """ # Call the superclass backward_by_grad method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -1013,6 +1013,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + scheduler_nodes: List = None, num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, @@ -1029,6 +1030,9 @@ def __init__( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert ( + not pp_style == "zbv" or scheduler_nodes is not None + ), f"scheduler_nodes must not be None when using zero bubble pipeline." 
if enable_sequence_parallelism: self.sequence_parallelism_mode = ( sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" @@ -1088,29 +1092,39 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" + assert ( + pp_style in ["interleaved", "zbv"] or num_model_chunks == 1 + ), "num_model_chunks must be 1 when using 1f1b" + assert ( + pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2 + ), "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert ( self.zero_stage <= 1 ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + if pp_style == "zbv": + self.logger.warning( + """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""" + ) self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=(pp_style == "interleaved"), + enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"), + use_zbv=(pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, ) if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -1119,12 +1133,20 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, ) + elif pp_style == "zbv": + self.scheduler = ZeroBubbleVPipeScheduler( + stage_manager=self.stage_manager, + schedule=scheduler_nodes, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + ) else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": @@ -1236,7 +1258,6 @@ def configure( # Replace with distributed implementation if exists optimizer = cast_to_distributed(optimizer) - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: self.logger.warning( "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. 
Disabling ZeRO.", @@ -1352,7 +1373,7 @@ def execute_pipeline( ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() with ctx, model._wait_all_gather(): - outputs = self.schedule.forward_backward_step( + outputs = self.scheduler.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 56405ed47e00..23331c2819b6 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -280,14 +280,17 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" assert ( - pp_style == "interleaved" or pp_style == "zbv" - ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + pp_style in ["interleaved", "zbv"] or num_model_chunks == 1 + ), "num_model_chunks must be 1 when using 1f1b" + assert ( + pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2 + ), "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" @@ -300,11 +303,12 @@ def __init__( enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, + use_zbv=(pp_style == "zbv"), ) if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -313,14 +317,15 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, ) elif pp_style == "zbv": - self.schedule = ZeroBubbleVPipeScheduler( + assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZerbubbleV" + self.scheduler = ZeroBubbleVPipeScheduler( schedule=scheduler_nodes, stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 50cc965bb9c3..5cc32114daff 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -136,7 +136,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool: if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 else: - return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 + # use zero bubble pipeline + if self.use_zbv: + return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1 + else: + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: diff --git a/colossalai/shardformer/policies/mixtral.py 
b/colossalai/shardformer/policies/mixtral.py index e11edae9f5e3..053e751906e2 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -234,14 +234,28 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + stage_manager.stage_indices = stage_indices + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.norm) + elif stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) return held_layers diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 0f2d6c49c749..ba6cafe6bbd4 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -7,17 +7,28 @@ import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralModel import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper from colossalai.logging import disable_existing_loggers from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import assert_loose_close + +NUM_BATCH = 8 +NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +TOP_K = 1 class MlpModel(nn.Module): @@ -730,127 +741,165 @@ def criterion_base(x, *args, **kwargs): assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups) -# TODO:4) support Hybrid base 3) +# TODO:3) support booster & Hybrid base 2) def run_with_hybridplugin(test_config): pass -# TODO:5) support MoEHybrid base 3) -@parameterize( - 
"test_config", - [ - { - "pp_style": "zbv", - "tp_size": 1, - "ep_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunks": 2, - }, - ], -) -def run_with_moehybridplugin(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") - # test_config["use_lazy_init"] = False - test_config["initial_scale"] = 2**16 - model_list = [ - "transformers_bert", - ] - clear_layout_converter() - - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name in model_list: - # base param - model = model_fn() - data = data_gen_fn() - print(f"data {data}") - criterion = loss_fn - optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5) - - output = model(**data) - loss = criterion(output) - loss.backward() - optimizer.step() - print(f"output {output}") - - # # pp param - # model_pp = deepcopy(model) - # data_pp = deepcopy(data) - # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5)) - - # # init pipeline graph - # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024 - # mem_f = 34 * h + 5 * a * s - # mem_w = -32 * h - # mem_b = -mem_w - mem_f - # graph = PipelineGraph( - # n_stage=test_config["pp_size"], - # n_micro=test_config["num_microbatches"], - # f_cost=1, - # b_cost=1, - # w_cost=1, - # c_cost=1, - # f_mem=mem_f, - # b_mem=mem_b, - # w_mem=mem_w, - # # max_mem=mem_f * (p * 2 + m_offset), - # ) - - # zbv_schedule = graph.get_v_schedule() - - # test_config["scheduler_nodes"] = zbv_schedule - # plugin = MoeHybridParallelPlugin( - # **test_config - # ) - # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure( - # model = model_pp, - # optimizer = optimizer_pp, - # criterion = criterion, - # dataloader = data_pp, - # ) - - # output_pp = plugin.execute_pipeline( - # data_iter=iter(data), - # model=model, - # criterion=criterion, - # optimizer=optimizer, - # return_loss = True, - # return_outputs = True, - # ) - - -# TODO:6) support booster & Hybrid base 4) - - -# TODO:7) support booster & MoEHybrid base 4) +# TODO:4) support booster & MoEHybrid base 2) @parameterize( - "test_config", + "config", [ - { - "pp_style": "zbv", - "tp_size": 1, - "ep_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunks": 2, - }, + (0, 1, 4, 1, 1), + # (0, 2, 2, 1, 1), + # (0, 2, 1, 2, 1), + # (0, 2, 1, 1, 2), ], ) -def run_with_booster_moehybridplugin(test_config): - pass +def run_with_booster_moehybridplugin(config: Tuple[int, ...]): + stage, ep_size, pp_size, tp_size, sp_size = config + num_microbatches = pp_size + dist.get_world_size() + rank = dist.get_rank() + dtype, precision = torch.float16, "fp16" + torch.cuda.set_device(dist.get_rank()) + + ######## + # init base model + ######## + assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS" + config = MixtralConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + num_local_experts=NUM_EXPERTS, + num_experts_per_tok=TOP_K, + attn_implementation="flash_attention_2", + ) + + # init model with the same seed + seed_all(10086) + + torch_model = MixtralModel(config).to(dtype).cuda() + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + # init schedule + h, a, s = config.hidden_size, config.num_attention_heads, 1024 + 
mem_f = 34 * h + 5 * a * s + mem_w = -32 * h + mem_b = -mem_w - mem_f + graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + # max_mem=mem_f * (p * 2 + m_offset), + ) + + zbv_schedule = graph.get_v_schedule() + + # init MoeHybridPlugin + plugin = MoeHybridParallelPlugin( + pp_size=pp_size, + num_microbatches=pp_size, + tp_size=tp_size, + sp_size=sp_size, + ep_size=ep_size, + zero_stage=stage, + enable_sequence_parallelism=sp_size > 1, + sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, + overlap_communication=False, + initial_scale=1, + precision=precision, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, + ) + + dp_size = plugin.dp_size + + booster = Booster(plugin=plugin) + + ######## + # init pp model + ######## + + parallel_model = deepcopy(torch_model) + parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1) + parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer) + # create different input along dp axis + seed_all(1453 + rank) + + torch_model.train() + parallel_model.train() + for _ in range(2): + # gen random input + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check + + dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input + dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input + + # run the model with hybrid parallel + if booster.plugin.stage_manager is not None: + # for test with pp + data_iter = iter([{"inputs_embeds": input_embeddings}]) + sharded_output = booster.execute_pipeline( + data_iter, + parallel_model, + lambda x, y: x.last_hidden_state.mean(), + parallel_optimizer, + return_loss=True, + return_outputs=True, + ) + # stage 0 chunk 0 + parallel_output = None + if rank == dist.get_process_group_ranks(plugin.pp_group)[0]: + parallel_output = sharded_output["loss"] + + else: + # for test without pp + parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean() + parallel_optimizer.backward(parallel_output) + parallel_optimizer.step() + parallel_optimizer.zero_grad() + # dist.all_reduce(parallel_output, group=plugin.dp_group) + + # =================================================================================== + # run normal model with all dp(different) inputs + all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)] + dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) + torch_output_sum = 0 + for input_data_ in all_inputs: + torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() + torch_output.backward() + torch_output_sum += torch_output.detach() + # avg dp grads follows zero optimizer + for p in torch_model.parameters(): + if p.grad is not None: + p.grad /= dp_size + torch_optimizer.step() + torch_optimizer.zero_grad() + if rank == dist.get_process_group_ranks(plugin.pp_group)[0]: + assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # 
run_fwd_bwd_iter_input() - run_fwd_bwd_vschedule_with_optim() - # run_with_moehybridplugin() - # run_with_booster_moehybridplugin() + # run_fwd_bwd_vschedule_with_optim() + run_with_booster_moehybridplugin() @pytest.mark.dist From 5c8bbf63a8ac03e15b658dc9dbf69b1cdec31c33 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Sun, 29 Sep 2024 09:59:41 +0000 Subject: [PATCH 2/5] =?UTF-8?q?[feat]=20update=20optimizer=20bwd;=20=C3=A4?= =?UTF-8?q?=C2=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- colossalai/interface/optimizer.py | 4 +-- colossalai/zero/gemini/gemini_ddp.py | 2 +- colossalai/zero/gemini/gemini_optimizer.py | 6 +++-- colossalai/zero/low_level/low_level_optim.py | 13 ++++++--- .../test_schedule/test_zerobubble_pp.py | 27 ++++++++++++++----- 5 files changed, 36 insertions(+), 16 deletions(-) diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index a236434a55d6..c8cf3ec21360 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs): """ self.optim.zero_grad(*args, **kwargs) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): """ Performs a backward pass on the loss. """ - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 80b2c7961e29..d2754cbd965b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -373,7 +373,7 @@ def backward(self, loss: torch.Tensor): loss.backward() self._post_backward() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False): raise RuntimeError("Gemini is not compatible with pipeline. 
backward_by_grad shoudn't be called in Gemini.") @staticmethod diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index fdf2a497626f..ccd4634b5fe2 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -298,12 +298,14 @@ def backward(self, loss: torch.Tensor): loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) - def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): + def backward_by_grad( + self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False + ): # This function is called except the last stage of pipeline parallel # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - grad = self.mix_precision_mixin.pre_backward_by_grad(grad) + grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph) self.module.backward_by_grad(tensor, grad) def _maybe_move_fp32_params(self): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 51d7d1eaaa33..9cc44c7538dd 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -408,7 +408,7 @@ def _add_to_bucket(self, param, group_id): # torch.optim.Optimizer methods ################################ - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and no_sync are not compatible" @@ -416,7 +416,7 @@ def backward(self, loss, retain_graph=False): if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) - loss.backward(retain_graph=retain_graph) + loss.backward(inputs=inputs, retain_graph=retain_graph) if not self.require_grad_sync: return @@ -427,14 +427,19 @@ def backward(self, loss, retain_graph=False): if self._overlap_communication: get_accelerator().synchronize() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) - torch.autograd.backward(tensor, grad) + torch.autograd.backward( + tensor, + grad, + inputs=inputs, + retain_graph=retain_graph, + ) if not self.require_grad_sync: return diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index ba6cafe6bbd4..384ed649055c 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -19,6 +19,8 @@ from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from tests.test_moe.moe_utils import assert_loose_close @@ -751,12 
+753,13 @@ def run_with_hybridplugin(test_config): "config", [ (0, 1, 4, 1, 1), - # (0, 2, 2, 1, 1), - # (0, 2, 1, 2, 1), - # (0, 2, 1, 1, 2), + (1, 2, 2, 1, 1), + (1, 2, 1, 2, 1), + (1, 2, 1, 1, 2), ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): + test_config = config stage, ep_size, pp_size, tp_size, sp_size = config num_microbatches = pp_size dist.get_world_size() @@ -865,8 +868,15 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): ) # stage 0 chunk 0 parallel_output = None - if rank == dist.get_process_group_ranks(plugin.pp_group)[0]: + if ( + booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) + and rank == dist.get_process_group_ranks(plugin.pp_group)[0] + ): parallel_output = sharded_output["loss"] + else: + parallel_output = torch.tensor(12345.0, device="cuda") + # broadcast along pp axis + dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group) else: # for test without pp @@ -874,7 +884,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): parallel_optimizer.backward(parallel_output) parallel_optimizer.step() parallel_optimizer.zero_grad() - # dist.all_reduce(parallel_output, group=plugin.dp_group) + dist.all_reduce(parallel_output, group=plugin.dp_group) # =================================================================================== # run normal model with all dp(different) inputs @@ -891,8 +901,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): p.grad /= dp_size torch_optimizer.step() torch_optimizer.zero_grad() - if rank == dist.get_process_group_ranks(plugin.pp_group)[0]: - assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) + assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) + print(f"rank {dist.get_rank()} config {test_config} test passed") + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() def run_dist(rank, world_size, port): From 6975c50f781516ffa350ca1f53b020b9a9b25045 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 30 Sep 2024 02:34:54 +0000 Subject: [PATCH 3/5] [fix] fix build ci; --- .github/workflows/build_on_pr.yml | 2 +- .github/workflows/build_on_schedule.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 58cd8826809a..79d758c87976 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . pip install --no-cache-dir -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index fc688a71bd92..e7b5063279eb 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -55,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . 
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install --no-cache-dir -r requirements/requirements-test.txt From 292a504bea0ca7af22d2f21c3826ca0a4ea7b4ab Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 8 Oct 2024 09:25:11 +0000 Subject: [PATCH 4/5] [fix] fix mixtral policy; --- colossalai/shardformer/policies/mixtral.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 3a41b27995fa..c570badd6dab 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -268,9 +268,11 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): - held_layers.append(module.norm) - elif stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + stage_manager.is_last_stage(ignore_chunk=True) + ): + # for zbv, when is_first_stage (last fwd), we append norm + # for interleaved, when is_last_stage (last fwd), we also append norm held_layers.append(module.norm) else: layers_per_stage = stage_manager.distribute_layers(len(module.layers)) From cc500b3e25dc8d626829e0098a1cc54d6438f93b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 8 Oct 2024 09:34:09 +0000 Subject: [PATCH 5/5] [fix] fix mixtral policy; --- colossalai/shardformer/policies/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index c570badd6dab..af5b15ed5d20 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -269,7 +269,7 @@ def get_held_layers(self) -> List[Module]: for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( - stage_manager.is_last_stage(ignore_chunk=True) + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) ): # for zbv, when is_first_stage (last fwd), we append norm # for interleaved, when is_last_stage (last fwd), we also append norm
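
Usage sketch for the zbv plumbing added above. This condenses the flow exercised in tests/test_pipeline/test_schedule/test_zerobubble_pp.py: build a V schedule with PipelineGraph.get_v_schedule() and hand it to MoeHybridParallelPlugin through scheduler_nodes, together with pp_style="zbv" and num_model_chunks=2. It assumes colossalai.launch has already initialized 4 ranks with the NCCL backend; the model shape and the f/b/w cost and memory constants are the illustrative values used by that test, not profiled numbers.

import torch
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel

from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

pp_size = 4
num_microbatches = pp_size

# Small Mixtral for illustration (same shape constants as the test).
config = MixtralConfig(
    hidden_size=16,
    intermediate_size=32,
    num_hidden_layers=8,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_local_experts=4,
    num_experts_per_tok=1,
)
model = MixtralModel(config).to(torch.float16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Build the zero-bubble V schedule; the cost/memory terms mirror the test
# and are rough placeholders rather than profiled values.
h, a, s = config.hidden_size, config.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s   # forward activation footprint per block
mem_w = -32 * h              # released by the weight-gradient (W) pass
mem_b = -mem_w - mem_f       # released by the input-gradient (B) pass
graph = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,
    f_mem=mem_f, b_mem=mem_b, w_mem=mem_w,
)
scheduler_nodes = graph.get_v_schedule()

plugin = MoeHybridParallelPlugin(
    pp_style="zbv",                    # new style added by this series
    scheduler_nodes=scheduler_nodes,   # required when pp_style == "zbv"
    num_model_chunks=2,                # zbv runs two model chunks per stage (V shape)
    pp_size=pp_size,
    num_microbatches=num_microbatches,
    tp_size=1,
    ep_size=1,
    zero_stage=0,
    precision="fp16",
    initial_scale=1,
    overlap_communication=False,
)
booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# A training step then goes through the pipeline scheduler, as in the test:
#   booster.execute_pipeline(data_iter, model, criterion, optimizer,
#                            return_loss=True, return_outputs=True)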