diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f58a21df3614..f970d8ccc85d 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1097,13 +1097,19 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8 return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication) -def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False): +def gather_sp_output(hidden_states, shard_config, sp_dim=1): """ Gather the output of the last layer for cross entropy computation """ + sp_group = shard_config.sequence_parallel_process_group + sp_mode = shard_config.sequence_parallelism_mode + fp8_comm = shard_config.fp8_communication + if dist.get_world_size(sp_group) == 1: + return hidden_states + # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) hidden_states = gather_forward_split_backward( - hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication + hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm ) return hidden_states diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 5d1a30d8a4b6..bf4fa77c6c23 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -433,7 +433,6 @@ def get_double_ring_groups(sp_group, inner_ring_size=None): assert ( sp_size % inner_ring_size == 0 ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" - logger = get_dist_logger() logger.info( f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", @@ -898,6 +897,7 @@ def backward(ctx, dout, _): local_sp_rank = dist.get_rank(sp_group) sp_size = dist.get_world_size(sp_group) + # Using separate streams (pg) for concurrent kv and dkv comm may # cause NCCL "software caused connection abort" here... local_kv_comm = RingComm(local_kv_group) @@ -1119,9 +1119,14 @@ def prepare_varlen_batch( the batch dim to a packed 1d sequence. Contingent on model forward shape definitions. Returns: - inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...]. - mask_info: A dictionary of mask info. - position_ids: Packed position ids of shape [..., Sq // sp_size]. + torch.Tensor: + Packed input embeddings of shape [B, Sq // sp_size, ...]. + + Dict[str, Any]: + A dictionary containing mask info. + + torch.Tensor: + Packed position ids of shape [..., Sq // sp_size]. """ _load_varlen_helpers() diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 12df824d1c0c..0e2241af9fc9 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -153,7 +153,6 @@ def dist_cross_entropy( labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] logits: torch.Tensor, # [B, S, Vocab_size] shard_config: ShardConfig, - out_features: int, vocab_size: int, dtype: torch.dtype, seq_dim: int = 1, @@ -226,13 +225,13 @@ def dist_cross_entropy( logits, labels, process_group=shard_config.tensor_parallel_process_group, - vocab_size=out_features, + vocab_size=vocab_size, dtype=dtype, mode="sum", ) else: # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D - logits = logits.view(-1, vocab_size) + logits = logits.view(-1, logits.size(-1)) loss = loss_fct(logits, labels) # Reduce loss instead of gathering logits over seq dim for savings diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index f9a41a467300..6fd689908af0 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -313,19 +313,19 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - - if self.seq_parallel_mode is None: - # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication) - output_parallel = matmul_with_async_comm( + if self.seq_parallel_mode == "split_gather": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, - self.async_communication, + True, + 1, + self.overlap, fp8_communication=self.fp8_communication, ) - elif self.seq_parallel_mode == "split_gather": + elif self.seq_parallel_mode == "ring": input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, @@ -335,13 +335,22 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: True, 1, self.overlap, + True, fp8_communication=self.fp8_communication, ) - elif self.seq_parallel_mode == "ring": - input_parallel = input_ - output_parallel = matmul_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True + elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + output_parallel = matmul_with_async_comm( + input_parallel, + self.weight, + bias, + self.process_group, + self.async_communication, + fp8_communication=self.fp8_communication, ) + else: + raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") if self.gather_output: # All-gather across the partitions. @@ -553,7 +562,7 @@ def forward(self, input_: Tensor) -> Tensor: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - if self.seq_parallel_mode is None: + if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) elif self.seq_parallel_mode == "split_gather": @@ -567,8 +576,12 @@ def forward(self, input_: Tensor) -> Tensor: elif self.seq_parallel_mode == "ring": output_parallel = torch.matmul(input_, self.weight) output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, 1, self.fp8_communication + output_parallel, + self.process_group, + 1, ) + else: + raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index c1a73ce05c97..4512e0c680f3 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -309,6 +309,9 @@ def split_batch_zigzag( """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) + if sp_size == 1: + return batch + if isinstance(batch, torch.Tensor): batch = [batch] seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1 @@ -364,6 +367,9 @@ def split_varlen_zigzag( """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) + if sp_size == 1: + return batch + if is_2d: assert max_seqlen > 0, "max_seqlen must be provided for 2D input" diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index f8fd4665f1bb..7e8e50d9bbd0 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -365,14 +365,15 @@ def bloom_for_causal_lm_forward( hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states).contiguous() - loss = dist_cross_entropy( - labels, - lm_logits, - shard_config, - self.lm_head.out_features, - self.config.vocab_size, - self.transformer.dtype, - ) + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, + lm_logits, + shard_config, + self.lm_head.out_features, + self.transformer.dtype, + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -1036,9 +1037,11 @@ def forward( hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index a761968af009..a9be5c74dba8 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -4,7 +4,6 @@ import torch import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.utils import logging @@ -13,10 +12,13 @@ from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer._operation import ( all_to_all_comm, - gather_forward_split_backward, + gather_sp_output, + is_share_sp_tp, split_forward_gather_backward, ) +from ..layer import dist_cross_entropy + def get_flash_core_attention_forward(): from .chatglm2_6b.modeling_chatglm import CoreAttention @@ -138,6 +140,7 @@ def chatglm_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + force_sp_output_gather: Optional[bool] = True, ): logger = logging.get_logger(__name__) output_hidden_states = ( @@ -180,6 +183,15 @@ def chatglm_model_forward( if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Support SP + PP + sp_size = shard_config.sequence_parallel_size + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + # For generating full positions ids (the states will be gathered along the seq dim before attention fwd). + if sp_mode != "ring_attn" and not stage_manager.is_first_stage(): + seq_length *= sp_size + # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) if position_ids is not None: @@ -200,29 +212,23 @@ def chatglm_model_forward( all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=0, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=1 / shard_config.sequence_parallel_size, - fp8_communication=shard_config.fp8_communication, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=0, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=1 / shard_config.sequence_parallel_size, - ) + # Keep the input split across all PP stages + if stage_manager.is_first_stage(): + if shard_config.enable_sequence_parallelism: + if sp_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=0, + process_group=sp_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=0, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=1 / shard_config.sequence_parallel_size, + ) + for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -248,35 +254,19 @@ def chatglm_model_forward( if use_cache: presents = presents + (kv_cache,) - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=shard_config.sequence_parallel_size, - fp8_communication=shard_config.fp8_communication, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=shard_config.sequence_parallel_size, - ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): # final layer_norm if self.encoder.post_layer_norm: hidden_states = self.encoder.final_layernorm(hidden_states) + + # Gather seq-wise in the final output stage + if shard_config.enable_sequence_parallelism: + sp_mode = shard_config.sequence_parallelism_mode + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0) + if not return_dict: return tuple( v @@ -333,6 +323,7 @@ def chatglm_for_conditional_generation_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -340,17 +331,21 @@ def chatglm_for_conditional_generation_forward( hidden_states = hidden_states[-1:] lm_logits = self.transformer.output_layer(hidden_states) lm_logits = lm_logits.transpose(0, 1).contiguous() + loss = None if labels is not None: - lm_logits = lm_logits.to(torch.float32) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) + # ChatGLM doesn't have lm_head split + enable_tp = shard_config.enable_tensor_parallelism + shard_config.enable_tensor_parallelism = False + loss = dist_cross_entropy( + labels, + lm_logits, + shard_config, + self.transformer.output_layer.out_features, + lm_logits.dtype, + ) + shard_config.enable_tensor_parallelism = enable_tp + if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output @@ -379,6 +374,7 @@ def forward( use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + force_sp_output_gather: Optional[bool] = True, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -456,22 +452,9 @@ def forward( use_cache=use_cache, output_hidden_states=output_hidden_states, ) - - if sp_mode in ["split_gather"]: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=sp_group, - grad_scale=sp_size, - fp8_communication=shard_config.fp8_communication, - ) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0) if not return_dict: return tuple( diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 5303383945bc..ea811acdf21a 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -17,14 +17,13 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer._operation import ( - all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, -) +from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, dist_cross_entropy +from ..layer._operation import gather_sp_output, is_share_sp_tp + +_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"] _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] @@ -52,6 +51,7 @@ def command_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + force_sp_output_gather: bool = True, ): logger = logging.get_logger(__name__) @@ -93,10 +93,16 @@ def command_model_forward( if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() + + # NOTE: For generating full positions ids + # (the states will be gathered along the seq dim before attention fwd). + if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage(): + seq_length *= shard_config.sequence_parallel_size + if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device) seq_length_with_past = seq_length + past_seen_tokens @@ -136,7 +142,7 @@ def command_model_forward( ) use_cache = False - if shard_config and shard_config.enable_sequence_parallelism: + if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: hidden_states = split_forward_gather_backward( hidden_states, @@ -208,23 +214,10 @@ def command_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=shard_config.sequence_parallel_size, - fp8_communication=shard_config.fp8_communication, - ) + sp_mode = shard_config.sequence_parallelism_mode + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config) # add hidden states from the last decoder layer if output_hidden_states: @@ -327,6 +320,7 @@ def command_for_causal_lm_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) past_key_values = None @@ -335,9 +329,10 @@ def command_for_causal_lm_forward( logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] @@ -482,6 +477,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -584,14 +580,10 @@ def forward( hidden_states = self.norm(hidden_states) - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward( - hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication - ) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication - ) + # Cases that don't support parallelizing cross entropy computation along sequence + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: + hidden_states = gather_sp_output(hidden_states, shard_config) # add hidden states from the last decoder layer if output_hidden_states: @@ -676,6 +668,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + force_sp_output_gather=False, ) hidden_states = outputs[0] @@ -683,14 +676,16 @@ def forward( logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - loss = dist_cross_entropy( - labels, - logits, - shard_config, - self.lm_head.out_features, - self.config.vocab_size, - self.model.dtype, - ) + + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.model.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 97544e1105d6..798fca88fb4f 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -21,8 +21,9 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer import ColoAttention -from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.layer import ColoAttention, RingAttention +from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward +from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig from ..layer import dist_cross_entropy @@ -39,10 +40,16 @@ def _get_attention_mask( encoder_hidden_states: Optional[torch.Tensor], encoder_attention_mask: Optional[torch.FloatTensor], ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: - batch_size, seq_len = hidden_states.shape[:2] + # Received input is already split for non-first pipeline stages, + # but attn mask isn't + batch_size = hidden_states.size(0) + seq_len = attention_mask.size(-1) + + sp_mode = shard_config.sequence_parallelism_mode # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.add_cross_attention and encoder_hidden_states is not None: + assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only." encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() if shard_config.enable_flash_attention: encoder_attention_mask = ColoAttention.prepare_attn_kwargs( @@ -62,6 +69,7 @@ def _get_attention_mask( encoder_attention_mask = {"attention_mask": None} else: encoder_attention_mask = None + # GPT2Attention mask. past_key_values_length = 0 if past_key_values is not None and past_key_values[0] is not None: @@ -69,6 +77,7 @@ def _get_attention_mask( if shard_config.enable_flash_attention: if attention_mask is not None: attention_mask = attention_mask.view(batch_size, -1) + attention_mask = ColoAttention.prepare_attn_kwargs( (batch_size, 1, seq_len, seq_len + past_key_values_length), hidden_states.dtype, @@ -123,6 +132,7 @@ def gpt2_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + force_sp_gather: Optional[bool] = True, ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. @@ -146,16 +156,15 @@ def gpt2_model_forward( logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - if stage_manager.is_first_stage(): + disable_pp = stage_manager is None + if disable_pp or stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -176,7 +185,7 @@ def gpt2_model_forward( # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if stage_manager.is_first_stage(): + if disable_pp or stage_manager.is_first_stage(): if position_ids is None: position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) @@ -190,9 +199,7 @@ def gpt2_model_forward( hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) - - attention_mask, encoder_attention_mask = _get_attention_mask( + attn_kwargs, encoder_attention_mask = _get_attention_mask( self, shard_config, hidden_states, @@ -215,23 +222,43 @@ def gpt2_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + if disable_pp or stage_manager.is_first_stage(): + # Ring Attention's special zigzag batch processing + if sp_mode == "ring_attn": + assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." + if not attention_mask.bool().all(): + hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( + attention_mask, sp_group, hidden_states, position_ids + ) + else: + hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) + # Other sp modes + else: + if sp_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + elif sp_mode == "ring_attn": + # Later stages already received split hidden states + _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) + del attention_mask # Going through held blocks. - start_idx, end_idx = stage_index[0], stage_index[1] + if disable_pp: + start_idx, end_idx = 0, len(self.h) + else: + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): block = self.h[i] torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if torch.is_tensor(attention_mask): - attention_mask = attention_mask.to(hidden_states.device) + if torch.is_tensor(attn_kwargs): + attn_kwargs = attn_kwargs.to(hidden_states.device) if isinstance(head_mask, torch.Tensor): head_mask = head_mask.to(hidden_states.device) if output_hidden_states: @@ -242,7 +269,7 @@ def gpt2_model_forward( block.__call__, hidden_states, None, - attention_mask, + attn_kwargs, head_mask[i], encoder_hidden_states, encoder_attention_mask, @@ -253,7 +280,7 @@ def gpt2_model_forward( outputs = block( hidden_states, layer_past=None, - attention_mask=attention_mask, + attention_mask=attn_kwargs, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, @@ -270,26 +297,25 @@ def gpt2_model_forward( if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) + # When sequence parallelism is done, gather the output tensor in forward and split it in backward + gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode) + if disable_pp or stage_manager.is_last_stage(): + if gather_output: + hidden_states = gather_sp_output(hidden_states, shard_config) - if stage_manager.is_last_stage(): - hidden_states = self.ln_f(hidden_states) + # gather_sp_output could've changed seq length. + input_shape = (*input_shape[:-1], hidden_states.size(-2)) + output_shape = input_shape + (hidden_states.size(-1),) + if disable_pp or stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if stage_manager.is_last_stage(): + if disable_pp or stage_manager.is_last_stage(): if not return_dict: return tuple( v @@ -366,17 +392,29 @@ def gpt2_lmhead_model_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_gather=False, ) # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): + disable_pp = stage_manager is None + if (not disable_pp) and (not stage_manager.is_last_stage()): return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype - ) + if shard_config.sequence_parallelism_mode == "ring_attn": + # Split labels in a zigzag fashion too + sp_group = shard_config.sequence_parallel_process_group + if not attention_mask.bool().all(): + # [B, max_seqlen // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + else: + labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) + + if labels is not None: + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -770,7 +808,7 @@ def gpt2_for_sequence_classification_forward( ) -def get_gpt2_flash_attention_forward(): +def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention def forward( @@ -817,7 +855,22 @@ def forward( if self.scale_attn_by_inverse_layer_idx: scale /= float(self.layer_idx + 1) dropout_p = self.attn_dropout.p if self.training else 0.0 - attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) + + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + if sp_mode == "ring_attn": + attn_output = RingAttention.attention( + query, + key, + value, + sp_group, + **attention_mask, + dropout_p=dropout_p, + scale=scale, + inner_ring_size=shard_config.inner_ring_size, + ) + else: + attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -828,466 +881,6 @@ def forward( return forward -def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): - def forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - - attention_mask, encoder_attention_mask = _get_attention_mask( - self, - shard_config, - hidden_states, - past_key_values, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if torch.is_tensor(attention_mask): - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward - - -def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - attention_mask, encoder_attention_mask = _get_attention_mask( - self, - shard_config, - hidden_states, - past_key_values, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger = logging.get_logger(__name__) - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if torch.is_tensor(attention_mask): - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward - - -def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): - from transformers import GPT2LMHeadModel - - def forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - return forward - - def get_jit_fused_gpt2_mlp_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2MLP diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 4e8f67407bc6..47c17e7494f2 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -25,7 +25,6 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig @@ -58,10 +57,7 @@ def llama_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, - # Split output only when computing cross entropy using llama_for_causal_lm_forward - # or get_lm_forward_with_dist_cross_entropy - # Default to True to avoid bug when calling classification forward from huggingface - force_sp_output_gather: bool = True, + force_sp_gather: bool = True, # Set to false only when computing cross entropy ): logger = logging.get_logger(__name__) @@ -78,8 +74,9 @@ def llama_model_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict + disable_pp = stage_manager is None # retrieve input_ids and inputs_embeds - if stage_manager.is_first_stage(): + if disable_pp or stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -88,10 +85,10 @@ def llama_model_forward( batch_size, seq_length, _ = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds + device = hidden_states.device else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape @@ -101,8 +98,8 @@ def llama_model_forward( sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group sp_size = shard_config.sequence_parallel_size - if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): - # For generating full positions ids, as the states will be gather along the seq dim in the attention layer later. + # Generating full positions ids for modes that gather sequence before attn + if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()): seq_length *= sp_size past_seen_tokens = 0 @@ -117,7 +114,6 @@ def llama_model_forward( seq_length_with_past = seq_length + past_seen_tokens - # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False @@ -130,14 +126,13 @@ def llama_model_forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - # embed positions, for the first stage, hidden_states is the input embeddings, - # for the other stages, hidden_states is the output of the previous stage - if not stage_manager.is_first_stage() and sp_mode == "ring_attn": + + no_split_input = disable_pp or not stage_manager.is_first_stage() + if no_split_input and sp_mode == "ring_attn": _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) elif shard_config.enable_flash_attention: - # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attn_kwargs = ColoAttention.prepare_attn_kwargs( + attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, hidden_states.device, @@ -146,15 +141,15 @@ def llama_model_forward( invert=(sp_mode != "ring_attn"), ) else: - attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position) - # Support SP + PP - # TODO: support padded casual cu_seqlens across stages - if stage_manager.is_first_stage(): + # Support SP + PP. Later stages have already received the split input. + split_input = disable_pp or stage_manager.is_first_stage() + if split_input: # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." - if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + if not attention_mask.bool().all(): hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, hidden_states, position_ids ) @@ -181,8 +176,8 @@ def llama_model_forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1]) - start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: num_ckpt_layers = end_idx - start_idx @@ -228,18 +223,16 @@ def llama_model_forward( if output_attentions: all_self_attns += (layer_outputs[1],) - if stage_manager.is_last_stage(): + if disable_pp or stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output( - hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication - ) + if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode): # noqa + hidden_states = gather_sp_output(hidden_states, shard_config) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - if stage_manager.is_last_stage(): + if disable_pp or stage_manager.is_last_stage(): if not return_dict: return tuple( v @@ -257,7 +250,7 @@ def llama_model_forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - # always return dict for imediate stage + # always return dict for intermediate stage return {"hidden_states": hidden_states} @staticmethod @@ -323,7 +316,7 @@ def llama_for_causal_lm_forward( # Split labels in a zigzag fashion too sp_group = shard_config.sequence_parallel_process_group if attention_mask.bool().all(): - labels = split_batch_zigzag(labels, sp_group, seq_dim=1) + labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) else: # [B, max_seqlen // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) @@ -345,16 +338,17 @@ def llama_for_causal_lm_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, - force_sp_output_gather=False, + force_sp_gather=False, ) past_key_values = None - if stage_manager.is_last_stage(): + disable_pp = stage_manager is None + if disable_pp or stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] @@ -629,263 +623,3 @@ def forward( return attn_output, attn_weights, past_key_value return forward - - -def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): - logger = logging.get_logger(__name__) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - # Split output only when computing cross entropy using llama_for_causal_lm_forward - # or get_lm_forward_with_dist_cross_entropy - # Default to True to avoid bug when calling classification forward from huggingface - force_sp_output_gather: bool = True, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - past_seen_tokens = 0 - seq_len = inputs_embeds.shape[1] - batch_size = inputs_embeds.shape[0] - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - - if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - if shard_config.enable_flash_attention: - mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len) - attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( - mask_shape, - inputs_embeds.dtype, - inputs_embeds.device, - q_padding_mask=attention_mask, - is_causal=True, - invert=(sp_mode != "ring_attn"), - ) - - else: - attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) - - # Ring Attention zigzag batch processing - if sp_mode == "ring_attn": - assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." - if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: - inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( - attention_mask, sp_group, inputs_embeds, position_ids - ) - else: - inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) - attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors - - elif is_share_sp_tp(sp_mode): - inputs_embeds = split_forward_gather_backward( - inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication - ) - elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward( - inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication - ) - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attn_kwargs, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attn_kwargs, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - # Cases that don't support parallelizing cross entropy computation along sequence - if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: - hidden_states = gather_sp_output( - hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication - ) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - return forward - - -def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): - from transformers import LlamaForCausalLM - - def forward( - self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: - # Special processing: Split labels in a zigzag fashion too - sp_group = shard_config.sequence_parallel_process_group - if attention_mask.bool().all(): - labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) - else: - # [B, max_seq_len // sp_size] - labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - force_sp_output_gather=False, - ) - - hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - return forward diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index ec1a8a00a58a..7fc6a1062037 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -274,10 +274,9 @@ def mistral_for_causal_lm_forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] @@ -687,10 +686,9 @@ def forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 636b46cc461d..3ea4db9e2f70 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -330,14 +330,15 @@ def opt_for_causal_lm_forward( ) if stage_manager.is_last_stage(): logits = self.lm_head(outputs[0]).contiguous() - loss = dist_cross_entropy( - labels, - logits, - shard_config, - self.lm_head.out_features, - self.config.vocab_size, - self.model.decoder.dtype, - ) + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.model.decoder.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] @@ -955,9 +956,9 @@ def forward( ) logits = self.lm_head(outputs[0]).contiguous() - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.decoder.dtype) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index a61360d17570..569fc4a459c5 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -32,14 +32,12 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer._operation import ( - all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, -) +from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, dist_cross_entropy +from ..layer._operation import gather_sp_output +from ..layer.utils import is_share_sp_tp class Qwen2PipelineForwards: @@ -64,6 +62,7 @@ def qwen2_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: logger = logging.get_logger(__name__) @@ -115,6 +114,14 @@ def qwen2_model_forward( past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length + # Support SP + PP + sp_size = shard_config.sequence_parallel_size + sp_group = shard_config.sequence_parallel_process_group + sp_mode = shard_config.sequence_parallelism_mode + # For generating full positions ids (the states will be gathered along the seq dim before attention fwd). + if sp_mode != "ring_attn" and not stage_manager.is_first_stage(): + seq_length *= sp_size + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -151,7 +158,6 @@ def qwen2_model_forward( elif self._attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), @@ -160,7 +166,6 @@ def qwen2_model_forward( ) else: # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), @@ -169,22 +174,21 @@ def qwen2_model_forward( sliding_window=self.config.sliding_window, ) - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=1 / shard_config.sequence_parallel_size, - fp8_communication=shard_config.fp8_communication, - ) + if stage_manager.is_first_stage(): + if shard_config.enable_sequence_parallelism: + if is_share_sp_tp(sp_mode): + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=sp_group, + ) + elif sp_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=sp_group, + grad_scale=1 / sp_size, + ) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -241,23 +245,10 @@ def qwen2_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config) - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=shard_config.sequence_parallel_size, - fp8_communication=shard_config.fp8_communication, - ) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -351,15 +342,18 @@ def qwen2_for_causal_lm_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) past_key_values = None if stage_manager.is_last_stage(): hidden_states = outputs[0] + if hidden_states.shape[1] == 2: + pass logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) if not return_dict: output = (logits,) + outputs[1:] @@ -541,7 +535,6 @@ def forward( # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -635,6 +628,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -750,14 +744,9 @@ def forward( hidden_states = self.norm(hidden_states) - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward( - hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication - ) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication - ) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config) # add hidden states from the last decoder layer if output_hidden_states: @@ -834,14 +823,15 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + force_sp_output_gather=False, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 16c13085ade1..1b7d2db85991 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -64,7 +64,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if sp_mode == "ring": warnings.warn( - f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather" ) sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 0a1949d85c74..d9233be9a822 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -6,14 +6,7 @@ import colossalai.shardformer.layer as col_nn -from ..modeling.gpt2 import ( - GPT2PipelineForwards, - get_gpt2_flash_attention_forward, - get_gpt_model_forward_for_flash_attn, - get_jit_fused_gpt2_mlp_forward, - get_lm_forward_with_dist_cross_entropy, - gpt2_sequence_parallel_forward_fn, -) +from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, get_jit_fused_gpt2_mlp_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -71,18 +64,10 @@ def module_policy(self): warnings.warn( f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" ) - sp_mode = "split_gather" + self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention - # todo: currently sp cannot be used with flashattention - if sp_mode in ["split_gather", "ring", "all_to_all"]: - if use_flash_attention: - warnings.warn( - f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically." - ) - self.shard_config.enable_flash_attention = False - use_flash_attention = False if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -211,18 +196,16 @@ def module_policy(self): if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_gpt2_flash_attention_forward(), + "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config), }, policy=policy, target_key=attn_cls, ) - if not self.shard_config.pipeline_stage_manager: - policy[GPT2Model].method_replacement = { - "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) - } - if sp_mode is not None: - policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism: + policy[GPT2Model].method_replacement = { + "forward": partial(GPT2PipelineForwards.gpt2_model_forward, shard_config=self.shard_config) + } return policy @@ -328,40 +311,39 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel module_policy = super().module_policy() - + module_policy[GPT2LMHeadModel] = ModulePolicyDescription() if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2LMHeadModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.VocabParallelLMHead1D, - kwargs={ - "gather_output": False, - "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, - "fp8_communication": self.shard_config.fp8_communication, - }, - ) - ], - ) - } - if self.shard_config.parallel_output: - addon_module[GPT2LMHeadModel].method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) - } + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": False, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=module_policy, + target_key=GPT2LMHeadModel, + ) else: - addon_module = { - GPT2LMHeadModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.PaddingLMHead, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, - ) - ] - ) - } - module_policy.update(addon_module) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=module_policy, + target_key=GPT2LMHeadModel, + ) + + if self.shard_config.parallel_output: + self.append_or_create_method_replacement( + description={ + "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config) + }, + policy=module_policy, + target_key=GPT2LMHeadModel, + ) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 78ac5eaaeb79..ec517da4966f 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -16,12 +16,7 @@ VocabParallelLMHead1D, ) -from ..modeling.llama import ( - LlamaPipelineForwards, - get_llama_flash_attention_forward, - get_llama_flash_attention_model_forward, - get_lm_forward_with_dist_cross_entropy, -) +from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] @@ -99,11 +94,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.pipeline_stage_manager is None: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_model_forward( - self.shard_config, - sp_mode=sp_mode, - sp_size=sp_size, - sp_group=sp_group, + "forward": partial( + LlamaPipelineForwards.llama_model_forward, + shard_config=self.shard_config, ), }, policy=policy, @@ -351,7 +344,8 @@ def module_policy(self): elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism: # Compute loss distributedly along the sequence dimension new_item[LlamaForCausalLM].method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + # "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + "forward": partial(LlamaPipelineForwards.llama_for_causal_lm_forward, shard_config=self.shard_config) } return policy diff --git a/examples/language/gpt/hybridparallelism/benchmark.py b/examples/language/gpt/hybridparallelism/benchmark.py index 8c236b524c26..91b9e6c04950 100644 --- a/examples/language/gpt/hybridparallelism/benchmark.py +++ b/examples/language/gpt/hybridparallelism/benchmark.py @@ -28,7 +28,7 @@ "118M": GPT2Config(activation_function="gelu"), "338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"), "738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"), - "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"), + "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=32768, activation_function="gelu"), } @@ -60,6 +60,8 @@ def main(): parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size") + parser.add_argument("--sp_mode", type=str, default="ring_attn", help="Sequence parallel mode") parser.add_argument("--mbs", type=int, default=1) parser.add_argument("--zero", type=int, default=0) parser.add_argument("--pp_style", type=str, default="1f1b") @@ -129,6 +131,9 @@ def empty_init(): tp_size=args.tp, pp_size=args.pp, pp_style=args.pp_style, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + enable_sequence_parallelism=True, zero_stage=args.zero, num_model_chunks=args.num_model_chunks, enable_all_optimization=True, @@ -214,6 +219,8 @@ def empty_init(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] + del outputs + booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index f5ad1d23d2a7..65c7e49a2f03 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -6,7 +6,6 @@ from torch import Tensor from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler -from colossalai.accelerator import get_accelerator from colossalai.cluster import DistCoordinator @@ -22,8 +21,11 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - tensor = torch.tensor([x], device=get_accelerator().get_current_device()) - dist.all_reduce(tensor) + + # Use CPU tensor to avoid OOM/weird NCCl error + gloo_group = dist.new_group(backend="gloo") + tensor = torch.tensor([x], device="cpu") + dist.all_reduce(tensor, group=gloo_group) tensor = tensor / world_size return tensor.item() diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index f71776b6b4e0..f2b139beca83 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -27,7 +27,16 @@ def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data["labels"] = data["input_ids"].clone() + + # Test padded sequence for Ring Attention + padding = torch.zeros(1, data["input_ids"].shape[1] // 2, dtype=torch.long) + data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1) + data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1) + + ignore_idx = -100 + labels = data["input_ids"].clone() + labels[~data["attention_mask"].bool()] = ignore_idx + data["labels"] = labels return data diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 9ad84341ac9e..0f9ec601387b 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin( sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) criterion = loss_fn - plugin = pluggin_cls(**test_config) booster = Booster(plugin=plugin) @@ -323,7 +322,6 @@ def check_output_hidden_state( sp_size = shard_config.sequence_parallel_size if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] - assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 92c077950ecc..17a8bf318976 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -136,26 +136,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { # Ulysess + Flash attention - "tp_size": 1, + { + "tp_size": 2, "pp_size": 2, - "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", + "sequence_parallelism_mode": "split_gather", "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 2, + { # Ulysess + Flash attention + "tp_size": 1, "pp_size": 2, "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", + "sequence_parallelism_mode": "all_to_all", "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, @@ -174,17 +173,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp32", - "initial_scale": 1, - }, { "tp_size": 4, "pp_size": 1, @@ -248,7 +236,11 @@ def run_chatglm_test(test_config): loss_fn, _, ) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Test config failed for model {name}: {test_config}") + raise e clear_layout_converter() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index efe5cee2a2b6..9435ef84bfa8 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -125,7 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "CohereModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_output_hidden_state( + org_output, + sharded_output, + stage_manager, + atol=atol, + rtol=rtol, + shard_config=booster.plugin.shard_config, + ) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -274,7 +281,11 @@ def run_command_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed test config: {test_config}") + raise e clear_layout_converter() Randomizer.reset_index() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index f9e368c0ebf3..393f7ffca7d3 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -100,7 +100,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "GPT2Model": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_output_hidden_state( + org_output, + sharded_output, + stage_manager, + atol=atol, + rtol=rtol, + shard_config=booster.plugin.shard_config, + ) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -132,14 +139,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "test_config", [ { - "tp_size": 4, + "sp_size": 2, + "tp_size": 1, + "pp_size": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "sp_size": 2, + "tp_size": 2, "pp_size": 1, - "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": False, + "sequence_parallelism_mode": "ring_attn", + "num_microbatches": 1, + "enable_all_optimization": True, "use_lazy_init": True, - "precision": "fp32", + "precision": "fp16", "initial_scale": 1, }, { @@ -148,7 +168,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, @@ -156,7 +176,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 2, "pp_size": 2, - "num_microbatches": 4, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, "enable_all_optimization": True, "use_lazy_init": True, "precision": "fp16", @@ -185,7 +216,16 @@ def run_gpt2_test(test_config): loss_fn, _, ) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and name != "transformers_gpt_lm": + # Only wrote zigzag splitting for cross entropy loss + continue + + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config} for model {name}") + raise (e) clear_layout_converter() torch.cuda.empty_cache() @@ -226,7 +266,11 @@ def run_gpt2_3d_test(test_config): loss_fn, _, ) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config} for model {name}") + raise (e) clear_layout_converter() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3c66f609787a..e8f7916972a5 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -165,7 +165,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 0, "precision": "fp16", "initial_scale": 1, - "inner_ring_size": 2, }, # Ring Attention + PP { @@ -215,18 +214,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "sequence_parallelism_mode": "all_to_all", "enable_all_optimization": True, "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, + "zero_stage": 1, "precision": "fp16", "initial_scale": 1, }, @@ -294,6 +282,7 @@ def run_llama_test(test_config): except Exception as e: print(f"Failed config: {test_config}, model name: {name}") raise e + clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_qwen2.py b/tests/test_shardformer/test_model/test_shard_qwen2.py index c87415b7562d..865563adc625 100644 --- a/tests/test_shardformer/test_model/test_shard_qwen2.py +++ b/tests/test_shardformer/test_model/test_shard_qwen2.py @@ -94,6 +94,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 2, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, @@ -135,32 +161,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { # Ulysess + Flash attention - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, { "tp_size": 2, "pp_size": 2,