From c7a19e2e4e75717d39dc197f714250cb1452bf5d Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 28 Jun 2024 07:42:56 +0000 Subject: [PATCH 01/71] halfway --- colossalai/shardformer/layer/_operation.py | 4 ++++ colossalai/shardformer/layer/attn.py | 20 ++++++-------------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 25983e0a93a6..a9060345d29a 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -812,7 +812,11 @@ def backward(ctx, *grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim + if torch.distributed.get_rank() == 0: + print(f"shape before A2A: {grad_output[0].shape}") return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) + if torch.distributed.get_rank() == 0: + print(f"shape after A2A: {return_grad.shape}") return (return_grad, None, None, None) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 6dab17ec069f..50f45ef0dbfc 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -2,7 +2,6 @@ from typing import Callable, Dict, Optional, Tuple import torch -import torch.distributed import torch.distributed as dist import torch.nn.functional as F from einops import rearrange @@ -250,7 +249,12 @@ def attention( # sanity check if attention_mask is not None: assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor." - if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL): + if attention_mask_type in ( + AttnMaskType.CUSTOM, + AttnMaskType.CAUSAL, + AttnMaskType.PADDED, + AttnMaskType.PADDED_CAUSAL, + ): assert ( cu_seqlens_q is None and cu_seqlens_kv is None @@ -261,18 +265,6 @@ def attention( ) if attention_mask_type == AttnMaskType.CUSTOM: assert not torch.all(attention_mask != 0, dim=-1).any() - elif attention_mask_type in ( - AttnMaskType.PADDED, - AttnMaskType.PADDED_CAUSAL, - ): - assert ( - cu_seqlens_q is not None - and cu_seqlens_kv is not None - and max_seqlen_q is not None - and max_seqlen_kv is not None - and q_indices is not None - and kv_indices is not None - ) else: # if attention_mask is None, attention_mask_type should be the default value assert attention_mask_type == AttnMaskType.CUSTOM From 92b489103139b291ca8389f83edf2bf5212e5f97 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 28 Jun 2024 13:36:43 +0000 Subject: [PATCH 02/71] fix cross-PP-stage position id length diff bug --- tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3c66f609787a..a38a4f64e2c2 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -59,7 +59,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if ( booster.plugin.zero_stage in [1, 2] and booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.pipeline_stage_manager is None + and booster.plugin.shard_config.pp_size == 1 and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() From a9ed834f274026cfffbc33ee19292520229889b1 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 29 Jun 2024 02:34:57 +0000 Subject: [PATCH 03/71] fix typo --- tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a38a4f64e2c2..e42b03b5e18e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -59,7 +59,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if ( booster.plugin.zero_stage in [1, 2] and booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.pp_size == 1 + and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() From 20f2a73fc8a68e7899db2d706d78dcb93098975a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 29 Jun 2024 07:40:57 +0000 Subject: [PATCH 04/71] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index e42b03b5e18e..3c66f609787a 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -59,7 +59,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if ( booster.plugin.zero_stage in [1, 2] and booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.pipeline_stage_manager is None + and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() From edb90431adda083fb72808d80851090e658d9e61 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 2 Jul 2024 09:35:42 +0000 Subject: [PATCH 05/71] unified cross entropy func for all shardformer models --- examples/language/opt/opt_benchmark.py | 1 + tests/test_shardformer/test_model/test_shard_llama.py | 9 ++------- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index ca9b63d1a14a..90f41fe1f767 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -135,4 +135,5 @@ def main(): if __name__ == "__main__": + print("--------------------------------------") main() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3c66f609787a..f4e2140c1f3c 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -63,9 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() - for (name, p1), p2 in zip( - llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0] - ): + for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( @@ -75,10 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - try: - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) - except Exception as e: - raise RuntimeError(f"Failed to check grad for {name}") from e + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} From bad530f29493ae6e2dee25f31841c3fc0a9a6790 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 2 Jul 2024 11:10:15 +0000 Subject: [PATCH 06/71] remove redundant lines --- examples/language/opt/opt_benchmark.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 90f41fe1f767..ca9b63d1a14a 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -135,5 +135,4 @@ def main(): if __name__ == "__main__": - print("--------------------------------------") main() From 57746c0d0a8fa20421c042974996bd70be335cf7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 8 Jul 2024 02:03:40 +0000 Subject: [PATCH 07/71] add basic ring attn; debug cross entropy --- colossalai/pipeline/schedule/one_f_one_b.py | 3 ++ colossalai/shardformer/layer/attn.py | 52 +++++++++++++++++++ colossalai/shardformer/layer/loss.py | 1 + colossalai/shardformer/layer/utils.py | 4 ++ colossalai/shardformer/shard/shard_config.py | 2 + tests/test_shardformer/test_model/_utils.py | 6 +-- .../test_model/test_shard_llama.py | 23 ++++++-- 7 files changed, 84 insertions(+), 7 deletions(-) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 03df67ae78c3..4c8519030b1c 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -32,6 +32,7 @@ def __init__( num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, + shard_config=None, ) -> None: """1F1B pipeline schedule. @@ -39,6 +40,7 @@ def __init__( stage_manager (PipelineStageManager): Pipeline stage manager num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. + shard_config: Shard configuration for gathering Sequence Parallel loss. """ super().__init__(stage_manager) assert ( @@ -53,6 +55,7 @@ def __init__( self.batch_size: Optional[int] = None self.last_batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None + self.shard_config = shard_config # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 50f45ef0dbfc..b42bd71706d3 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -365,6 +365,58 @@ def _rescale_out_lse(out, block_out, lse, block_lse): return out, lse +@triton.jit +def flash_attn_fwd_out_corr_triton( + out_ptr, out_per_step_ptr, seq_dim, softmax_lse_ptr, softmax_lse_per_step_ptr, BLOCK_SIZE: tl.constexpr +): + # Calculate the global id + pid = tl.program_id(0) + + # Offsets for the current row + offsets = tl.arange(0, BLOCK_SIZE) + + # Pointers to the current row in out and out_per_step + row_start = pid * seq_dim + out_ptrs = out_ptr + row_start + offsets + out_per_step_ptrs = out_per_step_ptr + row_start + offsets + + # Load softmax_lse and softmax_lse_per_step + softmax_lse = tl.load(softmax_lse_ptr + pid) + softmax_lse_per_step = tl.load(softmax_lse_per_step_ptr + pid) + + # Compute the corrected exponentiation + softmax_lse_corrected_exp = tl.exp(softmax_lse_per_step - softmax_lse) + + out_per_step_vals = tl.load(out_per_step_ptrs) + + # Correct the out_per_step by the exponentiation + out_corrected = out_per_step_vals * softmax_lse_corrected_exp + + # Load the current out values + out_vals = tl.load(out_ptrs) + + # Add the corrected output to out + updated_out_vals = out_vals + out_corrected + + # Store the updated out values + tl.store(out_ptrs, updated_out_vals) + + +# Modified from Megatron-LM. TODO: try Triton +def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_per_step * softmax_lse_corrected_exp + out.add_(out_corrected) + + +def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): + max_scale = torch.max(softmax_lse, softmax_lse_per_step) + min_scale = torch.min(softmax_lse, softmax_lse_per_step) + new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + softmax_lse.copy_(new_scale) + + class RingAttention(torch.autograd.Function): """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context` (https://arxiv.org/abs/2310.01889). diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 12df824d1c0c..0583bcd9375d 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index c1a73ce05c97..00f3458364e8 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,5 +1,9 @@ from contextlib import contextmanager +<<<<<<< HEAD from typing import List, Optional, Union +======= +from typing import Dict, List +>>>>>>> add basic ring attn; debug cross entropy import torch import torch.distributed as dist diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 70eb271c9b69..589ed730ec79 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, Optional +import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -54,6 +55,7 @@ class ShardConfig: # for moe related moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None + sp_stream: Optional[torch.cuda.Stream] = None # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 9ad84341ac9e..cd26021a9c8b 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -316,13 +316,11 @@ def check_output_hidden_state( else: sharded_hidden_state = sharded_output.last_hidden_state - # Check if the output sequence is gathered before cross entropy - if shard_config is not None: + if shard_config and shard_config.parallel_output and shard_config.enable_sequence_parallelism: seq_dim = 1 sp_group = shard_config.sequence_parallel_process_group sp_size = shard_config.sequence_parallel_size - if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: - org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] + org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index f4e2140c1f3c..f7f9f6f05d60 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -1,4 +1,5 @@ import os +from copy import deepcopy import pytest import torch @@ -63,7 +64,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() - for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + for (name, p1), p2 in zip( + llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0] + ): working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( @@ -73,7 +76,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + if name == "embed_tokens.weight": + continue + try: + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + except Exception as e: + raise RuntimeError(f"Failed to check grad for {name}") from e # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -174,6 +182,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, + "parallel_output": False, }, # Ring Attention + TP { @@ -187,6 +196,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 2, "precision": "fp16", "initial_scale": 1, + "parallel_output": False, }, { # Ulysess + TP "tp_size": 2, @@ -213,6 +223,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 0, "precision": "fp16", "initial_scale": 1, + "parallel_output": False, }, { "tp_size": 4, @@ -237,6 +248,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 2, "precision": "fp16", "initial_scale": 1, + "parallel_output": False, }, { "tp_size": 2, @@ -285,7 +297,12 @@ def run_llama_test(test_config): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue try: - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + config = test_config + if name == "transformers_llama_for_casual_lm": + # Test the cross entropy loss distributed along sequence + config = deepcopy(test_config) + config["parallel_output"] = True + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, config) except Exception as e: print(f"Failed config: {test_config}, model name: {name}") raise e From 1607ea0f95c9670e5eb3e166b065132ff82dd825 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 13 Jul 2024 16:10:02 +0000 Subject: [PATCH 08/71] fwd bwd logic complete --- .../hybrid_parallel_checkpoint_io.py | 1 + colossalai/shardformer/layer/attn.py | 15 ++++++ colossalai/shardformer/layer/utils.py | 1 + examples/language/opt/opt_benchmark.py | 2 +- tests/kit/model_zoo/transformers/llama.py | 8 +++ tests/test_shardformer/test_model/_utils.py | 9 ++-- .../test_model/test_shard_llama.py | 51 +++++++++++++++---- 7 files changed, 71 insertions(+), 16 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 043e5c2b0618..6edc89313097 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -656,6 +656,7 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. state_dict_list = [None for _ in range(self.pp_size)] dist.barrier(self.pp_group) + # torch.cuda.set_device(os.environ["LOCAL_RANK"]) dist.all_gather_object(state_dict_list, state_dict, self.pp_group) # Only the master rank do the saving. if self.coordinator.is_master(): diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index b42bd71706d3..7e3be96cd6f2 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -402,6 +402,17 @@ def flash_attn_fwd_out_corr_triton( tl.store(out_ptrs, updated_out_vals) +def flash_attn_out_lse_rescale(out, out_per_step, lse, lse_step): + """ + out: (B, Sq, H, D) + out_per_step: (B, Sq, H, D) + lse: (B, H, Sq, 1) + """ + new_lse = lse + torch.log(1 + torch.exp(lse_step - lse)) + out.copy_(torch.exp(lse - new_lse) * out + torch.exp(lse_step - new_lse) * out_per_step) + lse.copy_(new_lse) + + # Modified from Megatron-LM. TODO: try Triton def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) @@ -411,6 +422,10 @@ def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_l def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): + """ + softmax_lse: (B, H, Sq) + softmax_lse_per_step: (B, H, Sq) + """ max_scale = torch.max(softmax_lse, softmax_lse_per_step) min_scale = torch.min(softmax_lse, softmax_lse_per_step) new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 00f3458364e8..e45e573a8072 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -12,6 +12,7 @@ from torch.distributed import ProcessGroup, get_world_size from colossalai.accelerator import get_accelerator +from colossalai.shardformer.layer.attn import get_pad_info class SeqParallelUtils: diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index ca9b63d1a14a..7b30f1939cf0 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -96,7 +96,7 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, _, _ = booster.boost(model, optimizer) - + booster.save_model(model, "model.pt") SEQ_LEN = 1024 VOCAB_SIZE = 50257 diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 05ac9d8d24ed..ac729cb1a3e2 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -90,6 +90,14 @@ def data_gen_for_causal_lm(): loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True), ) + model_zoo.register( + name="transformers_llama", + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) model_zoo.register( name="transformers_llama_for_sequence_classification", model_fn=lambda: transformers.LlamaForSequenceClassification(config), diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index cd26021a9c8b..64b56eec1068 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -316,10 +316,11 @@ def check_output_hidden_state( else: sharded_hidden_state = sharded_output.last_hidden_state - if shard_config and shard_config.parallel_output and shard_config.enable_sequence_parallelism: - seq_dim = 1 - sp_group = shard_config.sequence_parallel_process_group - sp_size = shard_config.sequence_parallel_size + # Check if the output sequence is gathered before cross entropy + seq_dim = 1 + sp_group = shard_config.sequence_parallel_process_group + sp_size = shard_config.sequence_parallel_size + if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index f7f9f6f05d60..85001f9b3f24 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -1,5 +1,4 @@ import os -from copy import deepcopy import pytest import torch @@ -76,8 +75,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - if name == "embed_tokens.weight": - continue try: assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) except Exception as e: @@ -156,19 +153,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ +<<<<<<< HEAD # Double Ring Attention +======= +>>>>>>> fwd bwd logic complete { "tp_size": 1, "pp_size": 1, "sp_size": 4, "num_microbatches": 1, "enable_sequence_parallelism": True, +<<<<<<< HEAD "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, "zero_stage": 0, "precision": "fp16", "initial_scale": 1, "inner_ring_size": 2, +======= + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + "parallel_output": True, +>>>>>>> fwd bwd logic complete }, # Ring Attention + PP { @@ -182,13 +191,30 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, - "parallel_output": False, + "parallel_output": True, }, +<<<<<<< HEAD # Ring Attention + TP { +======= + # { + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # "parallel_output": True, + # }, + { # Test ring + Flash attention +>>>>>>> fwd bwd logic complete "tp_size": 2, "pp_size": 1, - "sp_size": 2, + "sp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring_attn", @@ -196,6 +222,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 2, "precision": "fp16", "initial_scale": 1, +<<<<<<< HEAD "parallel_output": False, }, { # Ulysess + TP @@ -224,6 +251,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "parallel_output": False, +======= + "parallel_output": True, +>>>>>>> fwd bwd logic complete }, { "tp_size": 4, @@ -235,6 +265,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, +<<<<<<< HEAD }, { "tp_size": 2, @@ -249,6 +280,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "parallel_output": False, +======= + "parallel_output": True, +>>>>>>> fwd bwd logic complete }, { "tp_size": 2, @@ -297,12 +331,7 @@ def run_llama_test(test_config): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue try: - config = test_config - if name == "transformers_llama_for_casual_lm": - # Test the cross entropy loss distributed along sequence - config = deepcopy(test_config) - config["parallel_output"] = True - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, config) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: print(f"Failed config: {test_config}, model name: {name}") raise e From 1796b806b2f77ca19acd16959f1c210dd8e8b9c4 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 14 Jul 2024 14:18:12 +0000 Subject: [PATCH 09/71] fwd bwd logic complete; add experimental triton rescale --- .../hybrid_parallel_checkpoint_io.py | 1 - colossalai/pipeline/schedule/one_f_one_b.py | 3 - colossalai/shardformer/layer/attn.py | 150 ++++++++++++------ colossalai/shardformer/layer/utils.py | 5 - colossalai/shardformer/modeling/llama.py | 7 + colossalai/shardformer/policies/llama.py | 5 + .../test_model/test_shard_llama.py | 52 ++---- 7 files changed, 128 insertions(+), 95 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 6edc89313097..043e5c2b0618 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -656,7 +656,6 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. state_dict_list = [None for _ in range(self.pp_size)] dist.barrier(self.pp_group) - # torch.cuda.set_device(os.environ["LOCAL_RANK"]) dist.all_gather_object(state_dict_list, state_dict, self.pp_group) # Only the master rank do the saving. if self.coordinator.is_master(): diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 4c8519030b1c..03df67ae78c3 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -32,7 +32,6 @@ def __init__( num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, - shard_config=None, ) -> None: """1F1B pipeline schedule. @@ -40,7 +39,6 @@ def __init__( stage_manager (PipelineStageManager): Pipeline stage manager num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. - shard_config: Shard configuration for gathering Sequence Parallel loss. """ super().__init__(stage_manager) assert ( @@ -55,7 +53,6 @@ def __init__( self.batch_size: Optional[int] = None self.last_batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None - self.shard_config = shard_config # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 7e3be96cd6f2..efbd5cb33d1b 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -46,9 +46,13 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor: # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +<<<<<<< HEAD def get_pad_info( padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True ) -> Tuple[int, torch.Tensor, torch.Tensor]: +======= +def get_pad_info(padding_mask: torch.Tensor, invert: Optional[bool] = False) -> Tuple[int, torch.Tensor, torch.Tensor]: +>>>>>>> fwd bwd logic complete; add experimental triton rescale """Get padding information from padding mask. Args: @@ -366,43 +370,87 @@ def _rescale_out_lse(out, block_out, lse, block_lse): @triton.jit -def flash_attn_fwd_out_corr_triton( - out_ptr, out_per_step_ptr, seq_dim, softmax_lse_ptr, softmax_lse_per_step_ptr, BLOCK_SIZE: tl.constexpr +def flash_attn_out_lse_rescale_kernel( + out_ptr, + out_per_step_ptr, + lse_ptr, + lse_step_ptr, + B, + Sq, + H, + D, + stride_out_0, + stride_out_1, + stride_out_2, + stride_out_3, + stride_out_per_step_0, + stride_out_per_step_1, + stride_out_per_step_2, + stride_out_per_step_3, + stride_lse_0, + stride_lse_1, + stride_lse_2, + stride_lse_3, ): - # Calculate the global id - pid = tl.program_id(0) - - # Offsets for the current row - offsets = tl.arange(0, BLOCK_SIZE) - - # Pointers to the current row in out and out_per_step - row_start = pid * seq_dim - out_ptrs = out_ptr + row_start + offsets - out_per_step_ptrs = out_per_step_ptr + row_start + offsets - - # Load softmax_lse and softmax_lse_per_step - softmax_lse = tl.load(softmax_lse_ptr + pid) - softmax_lse_per_step = tl.load(softmax_lse_per_step_ptr + pid) - - # Compute the corrected exponentiation - softmax_lse_corrected_exp = tl.exp(softmax_lse_per_step - softmax_lse) - - out_per_step_vals = tl.load(out_per_step_ptrs) - - # Correct the out_per_step by the exponentiation - out_corrected = out_per_step_vals * softmax_lse_corrected_exp - - # Load the current out values - out_vals = tl.load(out_ptrs) - - # Add the corrected output to out - updated_out_vals = out_vals + out_corrected - - # Store the updated out values - tl.store(out_ptrs, updated_out_vals) - - -def flash_attn_out_lse_rescale(out, out_per_step, lse, lse_step): + batch_id = tl.program_id(0) + sq_id = tl.program_id(1) + h_id = tl.program_id(2) + d_id = tl.arange(0, D) + + out_idx = batch_id * stride_out_0 + sq_id * stride_out_1 + h_id * stride_out_2 + d_id * stride_out_3 + out_per_step_idx = ( + batch_id * stride_out_per_step_0 + + sq_id * stride_out_per_step_1 + + h_id * stride_out_per_step_2 + + d_id * stride_out_per_step_3 + ) + lse_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 + lse_step_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 + + out = tl.load(out_ptr + out_idx) + out_per_step = tl.load(out_per_step_ptr + out_per_step_idx) + lse = tl.load(lse_ptr + lse_idx) + lse_step = tl.load(lse_step_ptr + lse_step_idx) + + new_lse = lse + tl.log(1 + tl.exp(lse_step - lse)) + out = tl.exp(lse - new_lse) * out + tl.exp(lse_step - new_lse) * out_per_step + + tl.store(out_ptr + out_idx, out) + tl.store(lse_ptr + lse_idx, new_lse) + + +def rescale_out_lse_triton(out, out_per_step, lse, lse_step): + B, Sq, H, D = out.shape + + assert out.is_contiguous() and out_per_step.is_contiguous() and lse.is_contiguous() and lse_step.is_contiguous() + + grid = (B, Sq, H) + + flash_attn_out_lse_rescale_kernel[grid]( + out, + out_per_step, + lse, + lse_step, + B, + Sq, + H, + D, + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + out_per_step.stride(0), + out_per_step.stride(1), + out_per_step.stride(2), + out_per_step.stride(3), + lse.stride(0), + lse.stride(1), + lse.stride(2), + lse.stride(3), + ) + + +def rescale_out_lse(out, out_per_step, lse, lse_step): """ out: (B, Sq, H, D) out_per_step: (B, Sq, H, D) @@ -413,23 +461,23 @@ def flash_attn_out_lse_rescale(out, out_per_step, lse, lse_step): lse.copy_(new_lse) -# Modified from Megatron-LM. TODO: try Triton -def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): - softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) - softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) - out_corrected = out_per_step * softmax_lse_corrected_exp - out.add_(out_corrected) +# From Megatron-LM. TODO: try Triton +# def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): +# softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) +# softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) +# out_corrected = out_per_step * softmax_lse_corrected_exp +# out.add_(out_corrected) -def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): - """ - softmax_lse: (B, H, Sq) - softmax_lse_per_step: (B, H, Sq) - """ - max_scale = torch.max(softmax_lse, softmax_lse_per_step) - min_scale = torch.min(softmax_lse, softmax_lse_per_step) - new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) - softmax_lse.copy_(new_scale) +# def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): +# """ +# softmax_lse: (B, H, Sq) +# softmax_lse_per_step: (B, H, Sq) +# """ +# max_scale = torch.max(softmax_lse, softmax_lse_per_step) +# min_scale = torch.min(softmax_lse, softmax_lse_per_step) +# new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) +# softmax_lse.copy_(new_scale) class RingAttention(torch.autograd.Function): diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index e45e573a8072..c1a73ce05c97 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,9 +1,5 @@ from contextlib import contextmanager -<<<<<<< HEAD from typing import List, Optional, Union -======= -from typing import Dict, List ->>>>>>> add basic ring attn; debug cross entropy import torch import torch.distributed as dist @@ -12,7 +8,6 @@ from torch.distributed import ProcessGroup, get_world_size from colossalai.accelerator import get_accelerator -from colossalai.shardformer.layer.attn import get_pad_info class SeqParallelUtils: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index af610500a8eb..4b75d5261225 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -27,10 +27,17 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward +<<<<<<< HEAD from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, RingAttention, dist_cross_entropy +======= +from colossalai.shardformer.layer.utils import is_share_sp_tp, zigzag_split_batch +from colossalai.shardformer.shard import ShardConfig + +from ..layer import ColoAttention, RingAttention, dist_cross_entropy, get_pad_info +>>>>>>> fwd bwd logic complete; add experimental triton rescale _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index f72a72df0b1b..76b824d8dd14 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -77,6 +77,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: num_q_heads = self.model.config.num_attention_heads num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) + tp_size = self.shard_config.tensor_parallel_size + # Modified by SP and TP + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) + if sp_mode == "all_to_all": num_q_heads //= sp_size decoder_attribute_replacement = {"num_heads": num_q_heads} diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 85001f9b3f24..73360c011272 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -153,65 +153,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ -<<<<<<< HEAD # Double Ring Attention -======= ->>>>>>> fwd bwd logic complete + # Zigzag Ring Attention { - "tp_size": 1, + "tp_size": 2, "pp_size": 1, "sp_size": 4, "num_microbatches": 1, "enable_sequence_parallelism": True, -<<<<<<< HEAD "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, "zero_stage": 0, "precision": "fp16", "initial_scale": 1, "inner_ring_size": 2, -======= + }, + # Ring Attention + PP + { + { # Ulysess + TP + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, "use_lazy_init": True, - "zero_stage": 1, + "zero_stage": 0, "precision": "fp16", "initial_scale": 1, "parallel_output": True, ->>>>>>> fwd bwd logic complete }, - # Ring Attention + PP - { + { # Ulysess + PP "tp_size": 1, "pp_size": 2, "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring_attn", + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, "parallel_output": True, }, -<<<<<<< HEAD # Ring Attention + TP { -======= - # { - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # "parallel_output": True, - # }, - { # Test ring + Flash attention ->>>>>>> fwd bwd logic complete + { "tp_size": 2, "pp_size": 1, "sp_size": 1, @@ -222,7 +212,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 2, "precision": "fp16", "initial_scale": 1, -<<<<<<< HEAD "parallel_output": False, }, { # Ulysess + TP @@ -251,9 +240,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "parallel_output": False, -======= - "parallel_output": True, ->>>>>>> fwd bwd logic complete }, { "tp_size": 4, @@ -265,7 +251,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, -<<<<<<< HEAD }, { "tp_size": 2, @@ -280,9 +265,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "parallel_output": False, -======= - "parallel_output": True, ->>>>>>> fwd bwd logic complete }, { "tp_size": 2, From 500454b0d09ff66b6faeca5b205848a554b21c06 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 18 Jul 2024 07:17:08 +0000 Subject: [PATCH 10/71] precision tests passed --- colossalai/shardformer/layer/_operation.py | 4 - colossalai/shardformer/layer/attn.py | 73 ++++++++++++------- colossalai/shardformer/modeling/llama.py | 9 +-- colossalai/shardformer/policies/llama.py | 3 + examples/language/opt/opt_benchmark.py | 1 - .../test_model/test_shard_llama.py | 33 +++++++-- 6 files changed, 77 insertions(+), 46 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index a9060345d29a..25983e0a93a6 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -812,11 +812,7 @@ def backward(ctx, *grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim - if torch.distributed.get_rank() == 0: - print(f"shape before A2A: {grad_output[0].shape}") return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) - if torch.distributed.get_rank() == 0: - print(f"shape after A2A: {return_grad.shape}") return (return_grad, None, None, None) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index efbd5cb33d1b..d05b3dc7f904 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -36,7 +36,7 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor: """Invert the mask tensor. Args: - mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv] + mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Sq] Returns: torch.Tensor: Inverted mask tensor. @@ -46,13 +46,9 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor: # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py -<<<<<<< HEAD def get_pad_info( padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True ) -> Tuple[int, torch.Tensor, torch.Tensor]: -======= -def get_pad_info(padding_mask: torch.Tensor, invert: Optional[bool] = False) -> Tuple[int, torch.Tensor, torch.Tensor]: ->>>>>>> fwd bwd logic complete; add experimental triton rescale """Get padding information from padding mask. Args: @@ -369,16 +365,17 @@ def _rescale_out_lse(out, block_out, lse, block_lse): return out, lse +def _not_nan(x): + return not (x.isnan().any() or x.isinf().any()) + + @triton.jit -def flash_attn_out_lse_rescale_kernel( +def _rescale_out_lse_kernel( out_ptr, out_per_step_ptr, lse_ptr, lse_step_ptr, - B, - Sq, - H, - D, + D, # Each thread handles D elements stride_out_0, stride_out_1, stride_out_2, @@ -391,6 +388,7 @@ def flash_attn_out_lse_rescale_kernel( stride_lse_1, stride_lse_2, stride_lse_3, + BLOCK_M: tl.constexpr, ): batch_id = tl.program_id(0) sq_id = tl.program_id(1) @@ -407,11 +405,13 @@ def flash_attn_out_lse_rescale_kernel( lse_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 lse_step_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 + # Load inputs out = tl.load(out_ptr + out_idx) out_per_step = tl.load(out_per_step_ptr + out_per_step_idx) lse = tl.load(lse_ptr + lse_idx) lse_step = tl.load(lse_step_ptr + lse_step_idx) + # Element-wise rescale new_lse = lse + tl.log(1 + tl.exp(lse_step - lse)) out = tl.exp(lse - new_lse) * out + tl.exp(lse_step - new_lse) * out_per_step @@ -419,18 +419,18 @@ def flash_attn_out_lse_rescale_kernel( tl.store(lse_ptr + lse_idx, new_lse) -def rescale_out_lse_triton(out, out_per_step, lse, lse_step): +def _rescale_out_lse_triton(out, block_out, lse, block_lse): B, Sq, H, D = out.shape - assert out.is_contiguous() and out_per_step.is_contiguous() and lse.is_contiguous() and lse_step.is_contiguous() + assert out.is_contiguous() and block_out.is_contiguous() and lse.is_contiguous() and block_lse.is_contiguous() - grid = (B, Sq, H) - - flash_attn_out_lse_rescale_kernel[grid]( + # TODO: use 1d kernel? + grid = lambda META: (triton.cdiv(Sq, META["BLOCK_M"]), B, H) + _rescale_out_lse_kernel[grid]( out, - out_per_step, + block_out, lse, - lse_step, + block_lse, B, Sq, H, @@ -439,10 +439,10 @@ def rescale_out_lse_triton(out, out_per_step, lse, lse_step): out.stride(1), out.stride(2), out.stride(3), - out_per_step.stride(0), - out_per_step.stride(1), - out_per_step.stride(2), - out_per_step.stride(3), + block_out.stride(0), + block_out.stride(1), + block_out.stride(2), + block_out.stride(3), lse.stride(0), lse.stride(1), lse.stride(2), @@ -450,16 +450,35 @@ def rescale_out_lse_triton(out, out_per_step, lse, lse_step): ) -def rescale_out_lse(out, out_per_step, lse, lse_step): +def _rescale_out_lse(out, block_out, lse, block_lse): """ - out: (B, Sq, H, D) - out_per_step: (B, Sq, H, D) - lse: (B, H, Sq, 1) + Compute the new attention denominator: + exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1) + Args: + out: (B, Sq, H, D) + block_out: (B, Sq, H, D) + lse: (B, H, Sq, 1) + block_lse: (B, H, Sq, 1) """ - new_lse = lse + torch.log(1 + torch.exp(lse_step - lse)) - out.copy_(torch.exp(lse - new_lse) * out + torch.exp(lse_step - new_lse) * out_per_step) + + # min_scale = torch.min(lse, block_lse) + # max_scale = torch.max(lse, block_lse) + # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + new_block_lse = torch.exp(block_lse - new_lse) + assert _not_nan(new_lse), new_lse + # dist.barrier() + assert _not_nan(new_block_lse), new_block_lse + + out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) lse.copy_(new_lse) + # block_out = block_out.float() + # out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out)) + # lse.copy_(lse - F.logsigmoid(lse - block_lse)) + # assert not lse.isnan().any(), lse + # assert not out.isnan().any(), out + # From Megatron-LM. TODO: try Triton # def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 4b75d5261225..ff033e5597a8 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -27,17 +27,10 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward -<<<<<<< HEAD from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, RingAttention, dist_cross_entropy -======= -from colossalai.shardformer.layer.utils import is_share_sp_tp, zigzag_split_batch -from colossalai.shardformer.shard import ShardConfig - -from ..layer import ColoAttention, RingAttention, dist_cross_entropy, get_pad_info ->>>>>>> fwd bwd logic complete; add experimental triton rescale _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] @@ -833,6 +826,8 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if shard_config.sequence_parallelism_mode == "ring_attn": + labels = zigzag_split_batch([labels], shard_config.sequence_parallel_process_group)[0] if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: # Special processing: Split labels in a zigzag fashion too diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 76b824d8dd14..ce1e0e3de35d 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -71,11 +71,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_partial_derived = sp_mode in ["split_gather", "ring"] if sp_mode == "ring_attn" and not self.is_causal: raise ValueError("Ring attention is only meant for causal language modeling.") +<<<<<<< HEAD tp_size = self.shard_config.tensor_parallel_size # Modified by SP and TP num_q_heads = self.model.config.num_attention_heads num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) +======= +>>>>>>> precision tests passed tp_size = self.shard_config.tensor_parallel_size # Modified by SP and TP diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 7b30f1939cf0..5e5971d9f560 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -96,7 +96,6 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, _, _ = booster.boost(model, optimizer) - booster.save_model(model, "model.pt") SEQ_LEN = 1024 VOCAB_SIZE = 50257 diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 73360c011272..e69283a3b18d 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -163,10 +163,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, +<<<<<<< HEAD "zero_stage": 0, "precision": "fp16", "initial_scale": 1, "inner_ring_size": 2, +======= + "zero_stage": 1, + "precision": "bf16", + "initial_scale": 1, +>>>>>>> precision tests passed }, # Ring Attention + PP { @@ -182,7 +188,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 0, "precision": "fp16", "initial_scale": 1, - "parallel_output": True, }, { # Ulysess + PP "tp_size": 1, @@ -197,21 +202,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, - "parallel_output": True, }, # Ring Attention + TP { { - "tp_size": 2, + "tp_size": 4, "pp_size": 1, - "sp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, +<<<<<<< HEAD "sequence_parallelism_mode": "ring_attn", +======= + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, +>>>>>>> precision tests passed "use_lazy_init": True, - "zero_stage": 2, "precision": "fp16", "initial_scale": 1, +<<<<<<< HEAD "parallel_output": False, }, { # Ulysess + TP @@ -240,17 +248,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "parallel_output": False, +======= +>>>>>>> precision tests passed }, { - "tp_size": 4, + "tp_size": 2, "pp_size": 1, + "sp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", + "sequence_parallelism_mode": "ring", "enable_flash_attention": True, "use_lazy_init": True, + "zero_stage": 2, "precision": "fp16", "initial_scale": 1, +<<<<<<< HEAD }, { "tp_size": 2, @@ -265,6 +278,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, "parallel_output": False, +======= +>>>>>>> precision tests passed }, { "tp_size": 2, @@ -312,6 +327,10 @@ def run_llama_test(test_config): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue +<<<<<<< HEAD +======= + +>>>>>>> precision tests passed try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: From c1ea3ba8fe0f27f4e6f85f4877603baa15550e6e Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 21 Jul 2024 14:32:49 +0000 Subject: [PATCH 11/71] precision tests passed --- .../pipeline/schedule/interleaved_pp.py | 2 + colossalai/shardformer/layer/attn.py | 136 +----- colossalai/shardformer/layer/utils.py | 2 +- colossalai/shardformer/modeling/llama.py | 4 +- .../benchmark/benchmark_qkvpacked_func.py | 87 ++++ .../benchmark_varlen_qkvpacked_func.py | 91 ++++ .../ring_flash_attn/__init__.py | 16 + .../ring_flash_attn/ring_flash_attn.py | 281 +++++++++++ .../ring_flash_attn/ring_flash_attn_varlen.py | 318 +++++++++++++ .../ring_flash_attn/stripe_flash_attn.py | 325 +++++++++++++ .../ring_flash_attn/triton_utils.py | 137 ++++++ ring-flash-attention/ring_flash_attn/utils.py | 110 +++++ .../ring_flash_attn/zigzag_ring_flash_attn.py | 327 +++++++++++++ .../zigzag_ring_flash_attn_varlen.py | 441 ++++++++++++++++++ ring-flash-attention/setup.py | 9 + .../test/test_ring_flash_attn_func.py | 124 +++++ .../test/test_ring_flash_attn_varlen_func.py | 157 +++++++ .../test/test_stripe_flash_attn_func.py | 130 ++++++ .../test/test_triton_kernels.py | 30 ++ .../test/test_zigzag_ring_flash_attn_func.py | 150 ++++++ ...test_zigzag_ring_flash_attn_varlen_func.py | 163 +++++++ tests/test_shardformer/test_model/_utils.py | 11 +- .../test_model/test_shard_llama.py | 24 +- 23 files changed, 2932 insertions(+), 143 deletions(-) create mode 100644 ring-flash-attention/benchmark/benchmark_qkvpacked_func.py create mode 100644 ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py create mode 100644 ring-flash-attention/ring_flash_attn/__init__.py create mode 100644 ring-flash-attention/ring_flash_attn/ring_flash_attn.py create mode 100644 ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py create mode 100644 ring-flash-attention/ring_flash_attn/stripe_flash_attn.py create mode 100644 ring-flash-attention/ring_flash_attn/triton_utils.py create mode 100644 ring-flash-attention/ring_flash_attn/utils.py create mode 100644 ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py create mode 100644 ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py create mode 100644 ring-flash-attention/setup.py create mode 100644 ring-flash-attention/test/test_ring_flash_attn_func.py create mode 100644 ring-flash-attention/test/test_ring_flash_attn_varlen_func.py create mode 100644 ring-flash-attention/test/test_stripe_flash_attn_func.py create mode 100644 ring-flash-attention/test/test_triton_kernels.py create mode 100644 ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py create mode 100644 ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 412f3896fb80..8f26f8cb5bb5 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -283,6 +283,8 @@ def forward_step( # Load input ids, attention mask and labels micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + if input_obj is not None: + assert all(not x.isnan().any() for x in input_obj.values()), "NaN detected in input_obj" # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous stage containing hidden_states etc. # Only attention_mask from micro_batch is used diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index d05b3dc7f904..17a965bfac03 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -16,6 +16,8 @@ from .utils import RingComm, get_half_index, split_varlen_zigzag +from .utils import RingComm + __all__ = [ "AttnMaskType", "ColoAttention", @@ -365,140 +367,6 @@ def _rescale_out_lse(out, block_out, lse, block_lse): return out, lse -def _not_nan(x): - return not (x.isnan().any() or x.isinf().any()) - - -@triton.jit -def _rescale_out_lse_kernel( - out_ptr, - out_per_step_ptr, - lse_ptr, - lse_step_ptr, - D, # Each thread handles D elements - stride_out_0, - stride_out_1, - stride_out_2, - stride_out_3, - stride_out_per_step_0, - stride_out_per_step_1, - stride_out_per_step_2, - stride_out_per_step_3, - stride_lse_0, - stride_lse_1, - stride_lse_2, - stride_lse_3, - BLOCK_M: tl.constexpr, -): - batch_id = tl.program_id(0) - sq_id = tl.program_id(1) - h_id = tl.program_id(2) - d_id = tl.arange(0, D) - - out_idx = batch_id * stride_out_0 + sq_id * stride_out_1 + h_id * stride_out_2 + d_id * stride_out_3 - out_per_step_idx = ( - batch_id * stride_out_per_step_0 - + sq_id * stride_out_per_step_1 - + h_id * stride_out_per_step_2 - + d_id * stride_out_per_step_3 - ) - lse_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 - lse_step_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 - - # Load inputs - out = tl.load(out_ptr + out_idx) - out_per_step = tl.load(out_per_step_ptr + out_per_step_idx) - lse = tl.load(lse_ptr + lse_idx) - lse_step = tl.load(lse_step_ptr + lse_step_idx) - - # Element-wise rescale - new_lse = lse + tl.log(1 + tl.exp(lse_step - lse)) - out = tl.exp(lse - new_lse) * out + tl.exp(lse_step - new_lse) * out_per_step - - tl.store(out_ptr + out_idx, out) - tl.store(lse_ptr + lse_idx, new_lse) - - -def _rescale_out_lse_triton(out, block_out, lse, block_lse): - B, Sq, H, D = out.shape - - assert out.is_contiguous() and block_out.is_contiguous() and lse.is_contiguous() and block_lse.is_contiguous() - - # TODO: use 1d kernel? - grid = lambda META: (triton.cdiv(Sq, META["BLOCK_M"]), B, H) - _rescale_out_lse_kernel[grid]( - out, - block_out, - lse, - block_lse, - B, - Sq, - H, - D, - out.stride(0), - out.stride(1), - out.stride(2), - out.stride(3), - block_out.stride(0), - block_out.stride(1), - block_out.stride(2), - block_out.stride(3), - lse.stride(0), - lse.stride(1), - lse.stride(2), - lse.stride(3), - ) - - -def _rescale_out_lse(out, block_out, lse, block_lse): - """ - Compute the new attention denominator: - exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1) - Args: - out: (B, Sq, H, D) - block_out: (B, Sq, H, D) - lse: (B, H, Sq, 1) - block_lse: (B, H, Sq, 1) - """ - - # min_scale = torch.min(lse, block_lse) - # max_scale = torch.max(lse, block_lse) - # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) - new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - new_block_lse = torch.exp(block_lse - new_lse) - assert _not_nan(new_lse), new_lse - # dist.barrier() - assert _not_nan(new_block_lse), new_block_lse - - out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) - lse.copy_(new_lse) - - # block_out = block_out.float() - # out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out)) - # lse.copy_(lse - F.logsigmoid(lse - block_lse)) - # assert not lse.isnan().any(), lse - # assert not out.isnan().any(), out - - -# From Megatron-LM. TODO: try Triton -# def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): -# softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) -# softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) -# out_corrected = out_per_step * softmax_lse_corrected_exp -# out.add_(out_corrected) - - -# def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): -# """ -# softmax_lse: (B, H, Sq) -# softmax_lse_per_step: (B, H, Sq) -# """ -# max_scale = torch.max(softmax_lse, softmax_lse_per_step) -# min_scale = torch.min(softmax_lse, softmax_lse_per_step) -# new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) -# softmax_lse.copy_(new_scale) - - class RingAttention(torch.autograd.Function): """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context` (https://arxiv.org/abs/2310.01889). diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index c1a73ce05c97..53d9576894f6 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -331,7 +331,7 @@ def split_batch_zigzag( indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device) tensor = tensor.index_select(seq_dim, indices).contiguous() # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...) - batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]) + batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]).contiguous() if len(batch) == 1: return batch[0] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index ff033e5597a8..258ba3051f8b 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -562,7 +562,9 @@ def forward( # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + assert not self.q_proj.weight.isnan().any(), self.q_proj.weight + assert not query_states.isnan().any(), query_states if sp_mode == "ring_attn": attn_output = RingAttention.attention( query_states, @@ -827,7 +829,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if shard_config.sequence_parallelism_mode == "ring_attn": - labels = zigzag_split_batch([labels], shard_config.sequence_parallel_process_group)[0] + labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: # Special processing: Split labels in a zigzag fashion too diff --git a/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py new file mode 100644 index 000000000000..a6742e04a696 --- /dev/null +++ b/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py @@ -0,0 +1,87 @@ +import torch +import torch.cuda +import torch.distributed as dist +from flash_attn import flash_attn_qkvpacked_func +from ring_flash_attn import ( + ring_flash_attn_qkvpacked_func, + stripe_flash_attn_qkvpacked_func, + zigzag_ring_flash_attn_qkvpacked_func, +) + + +def benchmark(f, num_iter=100, forward_only=True, log=True): + dtype = torch.bfloat16 + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + batch_size = 1 + seqlen = 1024 * 8 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert seqlen % (2 * world_size) == 0 + assert d % 8 == 0 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + + begin = torch.cuda.Event(enable_timing=True) + begin.record() + + if forward_only: + with torch.no_grad(): + for _ in range(num_iter): + _ = f( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + ) + + else: + for _ in range(num_iter): + qkv.grad = None + out = f( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + ) + out.backward(dout) + end = torch.cuda.Event(enable_timing=True) + end.record() + torch.cuda.synchronize(device=device) + time = begin.elapsed_time(end) / 1000.0 + + if rank == 0 and log: + print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec") + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + + forward_only = False + + for f in [ + flash_attn_qkvpacked_func, + ring_flash_attn_qkvpacked_func, + zigzag_ring_flash_attn_qkvpacked_func, + stripe_flash_attn_qkvpacked_func, + ]: + torch.cuda.empty_cache() + if rank == 0: + print(f"# {f.__name__}") + benchmark(f, forward_only=forward_only, log=False) + benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py new file mode 100644 index 000000000000..18c8cafc0078 --- /dev/null +++ b/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py @@ -0,0 +1,91 @@ +import torch +import torch.cuda +import torch.distributed as dist +from flash_attn import flash_attn_varlen_qkvpacked_func +from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func, zigzag_ring_flash_attn_varlen_qkvpacked_func + + +def benchmark(f, num_iter=100, forward_only=True, log=True): + dtype = torch.bfloat16 + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + seqlen = 1024 * 8 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert seqlen % (2 * world_size) == 0 + assert d % 8 == 0 + + qkv = torch.randn(seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dout = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) + + cu_seqlens_list = [ + torch.tensor([0, 8192], device=device, dtype=torch.int32), + torch.tensor([0, 256, 7648, 8192], device=device, dtype=torch.int32), + torch.tensor([0, 4096, 8192], device=device, dtype=torch.int32), + torch.tensor([0, 3104, 6304, 7904, 8064, 8192], device=device, dtype=torch.int32), + ] + max_seqlen_list = [(cu_seqlens[1:] - cu_seqlens[:1]).max().item() for cu_seqlens in cu_seqlens_list] + + begin = torch.cuda.Event(enable_timing=True) + begin.record() + if forward_only: + with torch.no_grad(): + for i in range(num_iter): + _ = f( + qkv, + cu_seqlens_list[i % len(cu_seqlens_list)], + max_seqlen_list[i % len(max_seqlen_list)], + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + ) + else: + for i in range(num_iter): + qkv.grad = None + out = f( + qkv, + cu_seqlens_list[i % len(cu_seqlens_list)], + max_seqlen_list[i % len(max_seqlen_list)], + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + ) + out.backward(dout) + end = torch.cuda.Event(enable_timing=True) + end.record() + torch.cuda.synchronize(device=device) + time = begin.elapsed_time(end) / 1000.0 + + if rank == 0 and log: + print(f"{num_iter / time} iter/s, {time} sec") + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + + forward_only = False + + for f in [ + flash_attn_varlen_qkvpacked_func, + ring_flash_attn_varlen_qkvpacked_func, + zigzag_ring_flash_attn_varlen_qkvpacked_func, + ]: + torch.cuda.empty_cache() + if rank == 0: + print(f"# {f.__name__}") + benchmark(f, forward_only=forward_only, log=False) + benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/ring_flash_attn/__init__.py b/ring-flash-attention/ring_flash_attn/__init__.py new file mode 100644 index 000000000000..01d5ec36218c --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/__init__.py @@ -0,0 +1,16 @@ +from .ring_flash_attn import ring_flash_attn_func, ring_flash_attn_kvpacked_func, ring_flash_attn_qkvpacked_func +from .ring_flash_attn_varlen import ( + ring_flash_attn_varlen_func, + ring_flash_attn_varlen_kvpacked_func, + ring_flash_attn_varlen_qkvpacked_func, +) +from .stripe_flash_attn import stripe_flash_attn_func, stripe_flash_attn_kvpacked_func, stripe_flash_attn_qkvpacked_func +from .zigzag_ring_flash_attn import ( + zigzag_ring_flash_attn_func, + zigzag_ring_flash_attn_kvpacked_func, + zigzag_ring_flash_attn_qkvpacked_func, +) +from .zigzag_ring_flash_attn_varlen import ( + zigzag_ring_flash_attn_varlen_func, + zigzag_ring_flash_attn_varlen_qkvpacked_func, +) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn.py new file mode 100644 index 000000000000..b36484dbd145 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/ring_flash_attn.py @@ -0,0 +1,281 @@ +import torch +from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward + +from .utils import RingComm, update_out_and_lse + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if not causal or step <= comm.rank: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal and step == 0, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + if step <= kv_comm.rank or not causal: + bwd_causal = causal and step == 0 + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer, + block_dk_buffer, + block_dv_buffer, + dropout_p, + softmax_scale, + bwd_causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk = next_dk + dv = next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk) + next_dv = d_kv_comm.send_recv(dv) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class RingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py new file mode 100644 index 000000000000..118bdea4c7d0 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py @@ -0,0 +1,318 @@ +import torch +from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward + +from .utils import RingComm, update_out_and_lse + +try: + from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse +except: + from .utils import flatten_varlen_lse, unflatten_varlen_lse + + +def ring_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens, + max_seqlen, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + comm = RingComm(process_group) + + out = None + lse = None + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + if not causal or step <= comm.rank: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + causal=causal and step == 0, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + block_lse = flatten_varlen_lse( + block_lse, + cu_seqlens=cu_seqlens, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) + return out, lse + + +def ring_flash_attn_varlen_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + max_seqlen, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + if step <= kv_comm.rank or not causal: + bwd_causal = causal and step == 0 + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer, + block_dk_buffer, + block_dv_buffer, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + bwd_causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk = next_dk + dv = next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk) + next_dv = d_kv_comm.send_recv(dv) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class RingFlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = ring_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens, + max_seqlen, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) + ctx.max_seqlen = max_seqlen + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors + dq, dk, dv = ring_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + ctx.max_seqlen, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + + +def ring_flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnVarlenFunc.apply( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnVarlenFunc.apply( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py b/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py new file mode 100644 index 000000000000..ca426920f4ed --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py @@ -0,0 +1,325 @@ +import torch +from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward + +from .utils import RingComm, update_out_and_lse + + +def stripe_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step <= comm.rank: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q[:, 1:], + k[:, :-1], + v[:, :-1], + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse, slice_=(slice(None), slice(1, None))) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def stripe_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal, "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + shift_causal = step > kv_comm.rank + softmax_lse_1 = None + if not shift_causal: + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer, + block_dk_buffer, + block_dv_buffer, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + else: + if softmax_lse_1 is None: + # lazy init, since the last rank does not need softmax_lse_1 + softmax_lse_1 = softmax_lse[:, :, 1:].contiguous() + _flash_attn_backward( + dout[:, 1:], + q[:, 1:], + k[:, :-1], + v[:, :-1], + out[:, 1:], + softmax_lse_1, + block_dq_buffer[:, 1:], + block_dk_buffer[:, :-1], + block_dv_buffer[:, :-1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + if not shift_causal: + dq += block_dq_buffer + else: + dq[:, 1:] += block_dq_buffer[:, 1:] + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk = next_dk + dv = next_dv + + if not shift_causal: + dk = block_dk_buffer + dk + dv = block_dv_buffer + dv + else: + dk[:, :-1] += block_dk_buffer[:, :-1] + dv[:, :-1] += block_dv_buffer[:, :-1] + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class StripeFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = stripe_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = stripe_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def stripe_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def stripe_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def stripe_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/ring_flash_attn/triton_utils.py b/ring-flash-attention/ring_flash_attn/triton_utils.py new file mode 100644 index 000000000000..66e362d93d68 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/triton_utils.py @@ -0,0 +1,137 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def flatten_kernel( + # pointers to matrices + OUT, + LSE, + CU_SEQLENS, + # strides + stride_out_nheads, + stride_out_seqlen, + stride_lse_batch, + stride_lse_nheads, + stride_lse_seqlen, + # meta-parameters + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads + OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + LSE = LSE + rm[:, None] * stride_lse_seqlen + x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) + + OUT = OUT + rm[:, None] * stride_out_seqlen + tl.store(OUT, x, mask=rm[:, None] < seqlen) + + +def flatten_varlen_lse(lse, cu_seqlens): + """ + Arguments: + lse: (batch_size, nheads, max_seqlen) + cu_seqlens: (batch_size + 1,) + Return: + flatten_lse: (nheads, total_seqlen) + """ + total_seqlen = cu_seqlens[-1] + batch_size, nheads, max_seqlen = lse.shape + output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) + + grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) + BLOCK_M = 4 + + with torch.cuda.device(lse.device.index): + flatten_kernel[grid]( + output, + lse, + cu_seqlens, + # strides + output.stride(0), + output.stride(1), + lse.stride(0), + lse.stride(1), + lse.stride(2), + BLOCK_M, + ) + return output + + +@triton.jit +def unflatten_kernel( + # pointers to matrices + OUT, + LSE, + CU_SEQLENS, + # strides + stride_out_batch, + stride_out_nheads, + stride_out_seqlen, + stride_lse_seqlen, + stride_lse_nheads, + # meta-parameters + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + LSE = LSE + rm[:, None] * stride_lse_seqlen + x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) + + OUT = OUT + rm[:, None] * stride_out_seqlen + tl.store(OUT, x, mask=rm[:, None] < seqlen) + + +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + """ + Arguments: + lse: (total_seqlen, nheads, 1) + cu_seqlens: (batch_size + 1,) + max_seqlen: int + Return: + unflatten_lse: (batch_size, nheads, max_seqlen) + """ + lse = lse.unsqueeze(dim=-1) + batch_size = len(cu_seqlens) - 1 + nheads = lse.shape[1] + output = torch.empty( + (batch_size, nheads, max_seqlen), + dtype=lse.dtype, + device=lse.device, + ) + + grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) + BLOCK_M = 4 + + with torch.cuda.device(lse.device.index): + unflatten_kernel[grid]( + output, + lse, + cu_seqlens, + # strides + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_M, + ) + return output diff --git a/ring-flash-attention/ring_flash_attn/utils.py b/ring-flash-attention/ring_flash_attn/utils.py new file mode 100644 index 000000000000..787732af8135 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/utils.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +__all__ = ["update_out_and_lse", "RingComm"] + + +@torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +@torch.jit.script +def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +@torch.jit.script +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py new file mode 100644 index 000000000000..d3e2821c5d4d --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py @@ -0,0 +1,327 @@ +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward + +from .utils import RingComm, update_out_and_lse + + +def zigzag_ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group) + + block_seq_len = q.shape[1] // 2 + q1 = q[:, block_seq_len:] + + out = None + lse = None + next_k, next_v = None, None + + def forward(q, k, v, causal): + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + return block_out, block_lse + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step == 0: + block_out, block_lse = forward(q, k, v, causal=True) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + block_out, block_lse = forward(q, k0, v0, causal=False) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + out, lse = update_out_and_lse( + out, + lse, + block_out, + block_lse, + slice_=(slice(None), slice(block_seq_len, None)), + ) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def zigzag_ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=1)[1] + q1 = q.chunk(2, dim=1)[1] + out1 = out.chunk(2, dim=1)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() + block_seq_len = q.shape[1] // 2 + + # repeatly allocating buffer may be slow... + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout, q, k, v, out, softmax_lse, causal): + seqlen_q = q.shape[1] + seqlen_kv = k.shape[1] + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq_buffer[:, :seqlen_q], + dk_buffer[:, :seqlen_kv], + dv_buffer[:, :seqlen_kv], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + backward(dout, q, k0, v0, out, softmax_lse, causal=False) + dq += dq_buffer + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) + # always use the first half in dq_buffer. + dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.rank: + dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] + dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] + else: + dk += dk_buffer + dv += dv_buffer + if dist.get_rank() == 0: + torch.save(torch.stack((dk, dv)), f"step_{step}.pt") + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class ZigZagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = zigzag_ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = zigzag_ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py new file mode 100644 index 000000000000..5d4a8dd2daf0 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py @@ -0,0 +1,441 @@ +import torch +from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward + +from .utils import RingComm, update_out_and_lse + +try: + from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse +except: + from .utils import flatten_varlen_lse, unflatten_varlen_lse + + +def get_half_index(cu_seqlens, *, front: bool): + if len(cu_seqlens) == 2: + if front: + return slice(None, cu_seqlens[-1] // 2) + else: + return slice(cu_seqlens[-1] // 2, None) + + index = torch.zeros((cu_seqlens[-1],), dtype=bool) + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + if front: + end = (start + end) // 2 + else: + start = (start + end) // 2 + index[start:end] = True + return index + + +@torch.jit.script +def get_half_lse(lse, cu_seqlens, *, front: bool): + new_lse = torch.empty( + (lse.shape[0], lse.shape[1], lse.shape[2] // 2), + dtype=lse.dtype, + device=lse.device, + ) + for i in range(len(cu_seqlens) - 1): + seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item() + if front: + start, end = 0, seqlen // 2 + else: + start, end = seqlen // 2, seqlen + new_lse[i, :, : seqlen // 2] = lse[i, :, start:end] + return new_lse + + +def zigzag_ring_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group) + + block_seq_len = q.shape[0] // 2 + q1 = q[half_index1] + + out = None + lse = None + next_k, next_v = None, None + half_cu_seqlens = cu_seqlens // 2 + half_max_seqlen = max_seqlen // 2 + + def forward(q, k, v, causal): + seqlen_q = q.shape[0] + seqlen_kv = k.shape[0] + cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens + max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen + cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens + max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( + q, + k, + v, + # the first half and the second half are the same + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + return block_out, block_lse + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step == 0: + block_out, block_lse = forward(q, k, v, causal=True) + block_lse = flatten_varlen_lse( + block_lse, + cu_seqlens=cu_seqlens, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.rank: + k0 = k[half_index0] + v0 = v[half_index0] + block_out, block_lse = forward(q, k0, v0, causal=False) + block_lse = flatten_varlen_lse( + block_lse, + cu_seqlens=cu_seqlens, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + block_lse = flatten_varlen_lse( + block_lse, + cu_seqlens=half_cu_seqlens, + ) + out[half_index1], lse[half_index1] = update_out_and_lse( + out[half_index1], lse[half_index1], block_out, block_lse + ) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) + return out, lse + + +def zigzag_ring_flash_attn_varlen_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout[half_index1] + q1 = q[half_index1] + out1 = out[half_index1] + softmax_lse1 = get_half_lse(softmax_lse, cu_seqlens, front=False) + block_seq_len = q.shape[0] // 2 + + half_cu_seqlens = cu_seqlens // 2 + half_max_seqlen = max_seqlen // 2 + + # repeatly allocating buffer may be slow... + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout, q, k, v, out, softmax_lse, causal): + seqlen_q = q.shape[0] + seqlen_kv = k.shape[0] + cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens + max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen + cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens + max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq_buffer[:seqlen_q], + dk_buffer[:seqlen_kv], + dv_buffer[:seqlen_kv], + # the first half and the second half are the same + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.rank: + k0 = k[half_index0] + v0 = v[half_index0] + backward(dout, q, k0, v0, out, softmax_lse, causal=False) + dq += dq_buffer + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) + dq[half_index1] += dq_buffer[:block_seq_len] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.rank: + dk[half_index0] += dk_buffer[:block_seq_len] + dv[half_index0] += dv_buffer[:block_seq_len] + else: + dk += dk_buffer + dv += dv_buffer + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class ZigZagRingFlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + half_index0 = get_half_index(cu_seqlens, front=True) + half_index1 = get_half_index(cu_seqlens, front=False) + out, softmax_lse = zigzag_ring_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + is_half_index_tensor = isinstance(half_index0, torch.Tensor) + ctx.is_half_index_tensor = is_half_index_tensor + if is_half_index_tensor: + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) + else: + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) + ctx.half_index0 = half_index0 + ctx.half_index1 = half_index1 + ctx.max_seqlen = max_seqlen + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + if ctx.is_half_index_tensor: + (q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) = ctx.saved_tensors + else: + q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors + half_index0 = ctx.half_index0 + half_index1 = ctx.half_index1 + dq, dk, dv = zigzag_ring_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + ctx.max_seqlen, + half_index0, + half_index1, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnVarlenFunc.apply( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnVarlenFunc.apply( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/setup.py b/ring-flash-attention/setup.py new file mode 100644 index 000000000000..58413e1b54f3 --- /dev/null +++ b/ring-flash-attention/setup.py @@ -0,0 +1,9 @@ +from setuptools import find_packages, setup + +setup( + name="ring_flash_attn", + version="0.1", + author="zhuzilin", + url="https://github.com/zhuzilin/ring-flash-attention", + packages=find_packages(), +) diff --git a/ring-flash-attention/test/test_ring_flash_attn_func.py b/ring-flash-attention/test/test_ring_flash_attn_func.py new file mode 100644 index 000000000000..50edd03bef4e --- /dev/null +++ b/ring-flash-attention/test/test_ring_flash_attn_func.py @@ -0,0 +1,124 @@ +import random + +import torch +import torch.distributed as dist +from flash_attn import flash_attn_qkvpacked_func +from ring_flash_attn import ring_flash_attn_qkvpacked_func + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + set_seed(rank) + world_size = dist.get_world_size() + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + seqlen = 3816 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert seqlen % world_size == 0 + assert d % 8 == 0 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone() + local_qkv.requires_grad = True + local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_qkvpacked_func( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = out.chunk(world_size, dim=1)[rank] + local_lse = lse.chunk(world_size, dim=-1)[rank] + + fn = ring_flash_attn_qkvpacked_func + + ring_out, ring_lse, _ = fn( + local_qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + log("out", out, rank0_only=True) + log("lse", lse, rank0_only=True) + log("out diff", local_out - ring_out) + log("lse diff", local_lse - ring_lse) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + out.backward(dout) + dqkv = qkv.grad + local_dqkv = dqkv.chunk(world_size, dim=1)[rank] + + ring_out.backward(local_dout) + ring_dqkv = local_qkv.grad + + log("local_dq", local_dqkv[:, :, 0, :]) + log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) + + log("local_dk", local_dqkv[:, :, 1, :]) + log("dk diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) + + log("local_dv", local_dqkv[:, :, 2, :]) + log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py new file mode 100644 index 000000000000..51bb1ec5d67d --- /dev/null +++ b/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py @@ -0,0 +1,157 @@ +import random + +import torch +import torch.distributed as dist +from flash_attn import flash_attn_varlen_qkvpacked_func +from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def extract_local(value, cu_seqlens, rank, world_size): + local_values = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + local_value = value[start:end].chunk(world_size, dim=0)[rank].detach().clone() + local_values.append(local_value) + return torch.cat(local_values, dim=0).contiguous() + + +def extract_lse(lse, cu_seqlens): + values = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + value = lse[i, :, : end - start] + values.append(value) + return values + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + set_seed(rank) + world_size = dist.get_world_size() + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + cu_seqlens = [0, 120, 1248, 4232] + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() + total_length = cu_seqlens[-1] + num_seq = len(cu_seqlens) - 1 + + assert torch.all(cu_seqlens_tensor % world_size == 0) + assert d % 8 == 0 + + qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_cu_seqlens_tensor = cu_seqlens_tensor // world_size + local_max_seqlen = max_seqlen // world_size + + local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) + local_qkv.requires_grad = True + local_dout = extract_local(dout, cu_seqlens, rank, world_size) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens_tensor, + max_seqlen, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = extract_local(out, cu_seqlens, rank, world_size) + lse_list = extract_lse(lse, cu_seqlens) + + ring_out, ring_lse, _ = ring_flash_attn_varlen_qkvpacked_func( + local_qkv, + local_cu_seqlens_tensor, + local_max_seqlen, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) + + log("out", out, rank0_only=True) + log("out diff", local_out - ring_out) + + for lse, ring_lse in zip(lse_list, ring_lse_list): + local_lse = lse.chunk(world_size, dim=-1)[rank] + log("lse", lse, rank0_only=True) + log("lse diff", local_lse - ring_lse) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + out.backward(dout) + dqkv = qkv.grad + local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) + + ring_out.backward(local_dout) + ring_dqkv = local_qkv.grad + + log("local_dq", local_dqkv[:, 0]) + log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0]) + + log("local_dk", local_dqkv[:, 1]) + log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) + + log("local_dv", local_dqkv[:, 2]) + log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/ring-flash-attention/test/test_stripe_flash_attn_func.py b/ring-flash-attention/test/test_stripe_flash_attn_func.py new file mode 100644 index 000000000000..dc9f5248d69d --- /dev/null +++ b/ring-flash-attention/test/test_stripe_flash_attn_func.py @@ -0,0 +1,130 @@ +import random + +import torch +import torch.distributed as dist +from flash_attn import flash_attn_qkvpacked_func +from ring_flash_attn import stripe_flash_attn_qkvpacked_func + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def extract_local(value, rank, world_size, dim=1): + value = torch.stack(value.split(world_size, dim=dim), dim=dim).transpose(dim, dim + 1) + slicer = [rank if i == dim else slice(None) for i in range(len(value.shape))] + return value[slicer].contiguous() + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + set_seed(rank) + world_size = dist.get_world_size() + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + seqlen = 3824 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert causal + assert seqlen % (2 * world_size) == 0 + assert d % 8 == 0 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_qkv = extract_local(qkv, rank, world_size).detach().clone() + local_qkv.requires_grad = True + local_dout = extract_local(dout, rank, world_size).detach().clone() + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_qkvpacked_func( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = extract_local(out, rank, world_size) + local_lse = extract_local(lse, rank, world_size, dim=2) + + ring_out, ring_lse, _ = stripe_flash_attn_qkvpacked_func( + local_qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + log("out", out, rank0_only=True) + log("lse", lse, rank0_only=True) + log("out diff", local_out - ring_out) + log("lse diff", local_lse - ring_lse) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + out.backward(dout) + dqkv = qkv.grad + + local_dqkv = extract_local(dqkv, rank, world_size) + + ring_out.backward(local_dout) + ring_dqkv = local_qkv.grad + + log("local_dq", local_dqkv[:, :, 0, :]) + log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) + + log("local_dk", local_dqkv[:, :, 1, :]) + log("dk0 diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) + + log("local_dv", local_dqkv[:, :, 2, :]) + log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_triton_kernels.py b/ring-flash-attention/test/test_triton_kernels.py new file mode 100644 index 000000000000..aa1c1fdcd338 --- /dev/null +++ b/ring-flash-attention/test/test_triton_kernels.py @@ -0,0 +1,30 @@ +import torch +from ring_flash_attn.triton_utils import flatten_varlen_lse as triton_flatten_varlen_lse +from ring_flash_attn.triton_utils import unflatten_varlen_lse as triton_unflatten_varlen_lse +from ring_flash_attn.utils import flatten_varlen_lse, unflatten_varlen_lse + +if __name__ == "__main__": + device = torch.device("cuda:0") + + cu_seqlens = [0, 15, 156, 529] + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + batch_size = len(cu_seqlens) - 1 + max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() + n_head = 5 + + lse = torch.randn((batch_size, n_head, max_seqlen), dtype=torch.float32, device=device) + flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor) + triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor) + assert torch.all(flatten_lse == triton_flatten_lse) + + flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) + triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) + + unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen) + triton_unflatten_lse = triton_unflatten_varlen_lse(triton_flatten_lse, cu_seqlens_tensor, max_seqlen) + + for i in range(batch_size): + seqlen = cu_seqlens[i + 1] - cu_seqlens[i] + assert torch.all( + unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen] + ), f"{unflatten_lse[i, :seqlen]} vs {triton_unflatten_lse[i, :seqlen]}" diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py new file mode 100644 index 000000000000..5f84bc58cf10 --- /dev/null +++ b/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py @@ -0,0 +1,150 @@ +import os +import random + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from flash_attn import flash_attn_qkvpacked_func +from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func + +from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def extract_local(value, rank, world_size, dim=1): + value_chunks = value.chunk(2 * world_size, dim=dim) + local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim) + return local_value.contiguous() + + +def run_test(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" # or the IP of the master node + os.environ["MASTER_PORT"] = "8125" # make sure this port is free + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + set_seed(rank) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + seqlen = 3824 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert causal + assert seqlen % (2 * world_size) == 0 + assert d % 8 == 0 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_qkv = extract_local(qkv, rank, world_size).detach().clone() + local_qkv.requires_grad = True + extract_local(dout, rank, world_size).detach().clone() + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_qkvpacked_func( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = extract_local(out, rank, world_size) + # local_lse = extract_local(lse, rank, world_size, dim=2) + q, k, v = local_qkv.chunk(3, dim=2) + q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] + q.requires_grad = k.requires_grad = v.requires_grad = True + sp_stream = torch.cuda.Stream() + sp_group = dist.new_group() + colo_out = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL) + + ring_out, ring_lse, _ = zigzag_ring_flash_attn_qkvpacked_func( + local_qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + log("colo_out", colo_out, rank0_only=True) + log("ring_out", ring_out, rank0_only=True) + # log("lse", lse, rank0_only=True) + log("colo_out - ring_out", colo_out - ring_out) + # log("lse diff", local_lse - ring_lse) + log("ring_out - local_out", ring_out - local_out) + log("colo_out - local_out", colo_out - local_out) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + colo_out.sum().backward() + qkv.grad + # q, k, v = [x.transpose(1, 2) for x in (q, k, v)] + colo_dq, colo_dk, colo_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] + + ring_out.sum().backward() + ring_dqkv = local_qkv.grad + out.sum().backward() + dqkv = extract_local(qkv.grad, rank, world_size) + + # log("colo_dq", colo_dq) + log("dq diff", colo_dq - ring_dqkv[:, :, 0, :]) + + # log("colo_dk", colo_dk) + log("dk diff", colo_dk - ring_dqkv[:, :, 1, :]) + + # log("colo_dv", colo_dv) + log("dv diff", colo_dv - ring_dqkv[:, :, 2, :]) + log("colo_dv - local_dv", colo_dv - dqkv[:, :, 2, :]) + + +if __name__ == "__main__": + world_size = 4 + mp.spawn(run_test, args=(world_size,), nprocs=world_size, join=True) diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py new file mode 100644 index 000000000000..7f6eced6e57b --- /dev/null +++ b/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py @@ -0,0 +1,163 @@ +import random + +import torch +import torch.distributed as dist +from flash_attn import flash_attn_varlen_qkvpacked_func +from ring_flash_attn import zigzag_ring_flash_attn_varlen_qkvpacked_func + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def extract_local(value, cu_seqlens, rank, world_size): + local_values = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + local_value = value[start:end].chunk(2 * world_size, dim=0) + local_values.extend( + [ + local_value[rank].detach().clone(), + local_value[2 * world_size - 1 - rank].detach().clone(), + ] + ) + return torch.cat(local_values, dim=0).contiguous() + + +def extract_lse(lse, cu_seqlens): + values = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + value = lse[i, :, : end - start] + values.append(value) + return values + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + set_seed(rank) + world_size = dist.get_world_size() + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + cu_seqlens = [0, 128, 1248, 4240] + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() + total_length = cu_seqlens[-1] + num_seq = len(cu_seqlens) - 1 + + assert torch.all(cu_seqlens_tensor % (2 * world_size) == 0) + assert d % 8 == 0 + + qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_cu_seqlens_tensor = cu_seqlens_tensor // world_size + local_max_seqlen = max_seqlen // world_size + + local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) + local_qkv.requires_grad = True + local_dout = extract_local(dout, cu_seqlens, rank, world_size) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens_tensor, + max_seqlen, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = extract_local(out, cu_seqlens, rank, world_size) + lse_list = extract_lse(lse, cu_seqlens) + + ring_out, ring_lse, _ = zigzag_ring_flash_attn_varlen_qkvpacked_func( + local_qkv, + local_cu_seqlens_tensor, + local_max_seqlen, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) + + log("out", out, rank0_only=True) + log("out diff", local_out - ring_out) + + for i, (lse, ring_lse) in enumerate(zip(lse_list, ring_lse_list)): + local_lse = lse.chunk(2 * world_size, dim=-1) + local_lse = torch.cat([local_lse[rank], local_lse[2 * world_size - 1 - rank]], dim=-1) + log(f"lse {i}", lse, rank0_only=True) + log(f"lse diff {i}", local_lse - ring_lse) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + out.backward(dout) + dqkv = qkv.grad + local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) + + ring_out.backward(local_dout) + ring_dqkv = local_qkv.grad + + log("local_dq", local_dqkv[:, 0]) + log("dq diff", local_dqkv - ring_dqkv) + + log("local_dk", local_dqkv[:, 1]) + log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) + + log("local_dv", local_dqkv[:, 2]) + log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 64b56eec1068..9ad84341ac9e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -317,11 +317,12 @@ def check_output_hidden_state( sharded_hidden_state = sharded_output.last_hidden_state # Check if the output sequence is gathered before cross entropy - seq_dim = 1 - sp_group = shard_config.sequence_parallel_process_group - sp_size = shard_config.sequence_parallel_size - if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: - org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] + if shard_config is not None: + seq_dim = 1 + sp_group = shard_config.sequence_parallel_process_group + sp_size = shard_config.sequence_parallel_size + if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: + org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index e69283a3b18d..1c7b9321583f 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -153,8 +153,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ +<<<<<<< HEAD # Double Ring Attention # Zigzag Ring Attention +======= + # Zigzag Ring Attention + PP + { + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "bf16", + "initial_scale": 1, + }, + # Ring Attention + TP +>>>>>>> precision tests passed { "tp_size": 2, "pp_size": 1, @@ -180,7 +197,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "sp_size": 2, - "num_microbatches": 2, + "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", "enable_all_optimization": True, @@ -328,8 +345,11 @@ def run_llama_test(test_config): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue <<<<<<< HEAD +<<<<<<< HEAD ======= +>>>>>>> precision tests passed +======= >>>>>>> precision tests passed try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -423,4 +443,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - test_llama_3d() + # test_llama_3d() From 4b8a4125fb2d7aa6c7cfc07d763d3173a6b578d0 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 22 Jul 2024 03:39:19 +0000 Subject: [PATCH 12/71] fix typos and remove misc files --- colossalai/shardformer/layer/attn.py | 7 +- .../benchmark/benchmark_qkvpacked_func.py | 87 ---- .../benchmark_varlen_qkvpacked_func.py | 91 ---- .../ring_flash_attn/__init__.py | 16 - .../ring_flash_attn/ring_flash_attn.py | 281 ----------- .../ring_flash_attn/ring_flash_attn_varlen.py | 318 ------------- .../ring_flash_attn/stripe_flash_attn.py | 325 ------------- .../ring_flash_attn/triton_utils.py | 137 ------ ring-flash-attention/ring_flash_attn/utils.py | 110 ----- .../ring_flash_attn/zigzag_ring_flash_attn.py | 327 ------------- .../zigzag_ring_flash_attn_varlen.py | 441 ------------------ ring-flash-attention/setup.py | 9 - .../test/test_ring_flash_attn_func.py | 124 ----- .../test/test_ring_flash_attn_varlen_func.py | 157 ------- .../test/test_stripe_flash_attn_func.py | 130 ------ .../test/test_triton_kernels.py | 30 -- .../test/test_zigzag_ring_flash_attn_func.py | 150 ------ ...test_zigzag_ring_flash_attn_varlen_func.py | 163 ------- 18 files changed, 1 insertion(+), 2902 deletions(-) delete mode 100644 ring-flash-attention/benchmark/benchmark_qkvpacked_func.py delete mode 100644 ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py delete mode 100644 ring-flash-attention/ring_flash_attn/__init__.py delete mode 100644 ring-flash-attention/ring_flash_attn/ring_flash_attn.py delete mode 100644 ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py delete mode 100644 ring-flash-attention/ring_flash_attn/stripe_flash_attn.py delete mode 100644 ring-flash-attention/ring_flash_attn/triton_utils.py delete mode 100644 ring-flash-attention/ring_flash_attn/utils.py delete mode 100644 ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py delete mode 100644 ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py delete mode 100644 ring-flash-attention/setup.py delete mode 100644 ring-flash-attention/test/test_ring_flash_attn_func.py delete mode 100644 ring-flash-attention/test/test_ring_flash_attn_varlen_func.py delete mode 100644 ring-flash-attention/test/test_stripe_flash_attn_func.py delete mode 100644 ring-flash-attention/test/test_triton_kernels.py delete mode 100644 ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py delete mode 100644 ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 17a965bfac03..e6fa5f65b52b 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -251,12 +251,7 @@ def attention( # sanity check if attention_mask is not None: assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor." - if attention_mask_type in ( - AttnMaskType.CUSTOM, - AttnMaskType.CAUSAL, - AttnMaskType.PADDED, - AttnMaskType.PADDED_CAUSAL, - ): + if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL): assert ( cu_seqlens_q is None and cu_seqlens_kv is None diff --git a/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py deleted file mode 100644 index a6742e04a696..000000000000 --- a/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import torch.cuda -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import ( - ring_flash_attn_qkvpacked_func, - stripe_flash_attn_qkvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, -) - - -def benchmark(f, num_iter=100, forward_only=True, log=True): - dtype = torch.bfloat16 - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - - batch_size = 1 - seqlen = 1024 * 8 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - - begin = torch.cuda.Event(enable_timing=True) - begin.record() - - if forward_only: - with torch.no_grad(): - for _ in range(num_iter): - _ = f( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - - else: - for _ in range(num_iter): - qkv.grad = None - out = f( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - out.backward(dout) - end = torch.cuda.Event(enable_timing=True) - end.record() - torch.cuda.synchronize(device=device) - time = begin.elapsed_time(end) / 1000.0 - - if rank == 0 and log: - print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec") - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - - forward_only = False - - for f in [ - flash_attn_qkvpacked_func, - ring_flash_attn_qkvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, - stripe_flash_attn_qkvpacked_func, - ]: - torch.cuda.empty_cache() - if rank == 0: - print(f"# {f.__name__}") - benchmark(f, forward_only=forward_only, log=False) - benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py deleted file mode 100644 index 18c8cafc0078..000000000000 --- a/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.cuda -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func, zigzag_ring_flash_attn_varlen_qkvpacked_func - - -def benchmark(f, num_iter=100, forward_only=True, log=True): - dtype = torch.bfloat16 - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - - seqlen = 1024 * 8 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dout = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) - - cu_seqlens_list = [ - torch.tensor([0, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 256, 7648, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 4096, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 3104, 6304, 7904, 8064, 8192], device=device, dtype=torch.int32), - ] - max_seqlen_list = [(cu_seqlens[1:] - cu_seqlens[:1]).max().item() for cu_seqlens in cu_seqlens_list] - - begin = torch.cuda.Event(enable_timing=True) - begin.record() - if forward_only: - with torch.no_grad(): - for i in range(num_iter): - _ = f( - qkv, - cu_seqlens_list[i % len(cu_seqlens_list)], - max_seqlen_list[i % len(max_seqlen_list)], - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - else: - for i in range(num_iter): - qkv.grad = None - out = f( - qkv, - cu_seqlens_list[i % len(cu_seqlens_list)], - max_seqlen_list[i % len(max_seqlen_list)], - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - out.backward(dout) - end = torch.cuda.Event(enable_timing=True) - end.record() - torch.cuda.synchronize(device=device) - time = begin.elapsed_time(end) / 1000.0 - - if rank == 0 and log: - print(f"{num_iter / time} iter/s, {time} sec") - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - - forward_only = False - - for f in [ - flash_attn_varlen_qkvpacked_func, - ring_flash_attn_varlen_qkvpacked_func, - zigzag_ring_flash_attn_varlen_qkvpacked_func, - ]: - torch.cuda.empty_cache() - if rank == 0: - print(f"# {f.__name__}") - benchmark(f, forward_only=forward_only, log=False) - benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/ring_flash_attn/__init__.py b/ring-flash-attention/ring_flash_attn/__init__.py deleted file mode 100644 index 01d5ec36218c..000000000000 --- a/ring-flash-attention/ring_flash_attn/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .ring_flash_attn import ring_flash_attn_func, ring_flash_attn_kvpacked_func, ring_flash_attn_qkvpacked_func -from .ring_flash_attn_varlen import ( - ring_flash_attn_varlen_func, - ring_flash_attn_varlen_kvpacked_func, - ring_flash_attn_varlen_qkvpacked_func, -) -from .stripe_flash_attn import stripe_flash_attn_func, stripe_flash_attn_kvpacked_func, stripe_flash_attn_qkvpacked_func -from .zigzag_ring_flash_attn import ( - zigzag_ring_flash_attn_func, - zigzag_ring_flash_attn_kvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, -) -from .zigzag_ring_flash_attn_varlen import ( - zigzag_ring_flash_attn_varlen_func, - zigzag_ring_flash_attn_varlen_qkvpacked_func, -) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn.py deleted file mode 100644 index b36484dbd145..000000000000 --- a/ring-flash-attention/ring_flash_attn/ring_flash_attn.py +++ /dev/null @@ -1,281 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def ring_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - comm = RingComm(process_group) - - out = None - lse = None - - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if not causal or step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal and step == 0, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def ring_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - next_dk, next_dv = None, None - next_k, next_v = None, None - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - if step <= kv_comm.rank or not causal: - bwd_causal = causal and step == 0 - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - dropout_p, - softmax_scale, - bwd_causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - dq += block_dq_buffer - d_kv_comm.wait() - dk = block_dk_buffer + next_dk - dv = block_dv_buffer + next_dv - elif step != 0: - d_kv_comm.wait() - dk = next_dk - dv = next_dv - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk) - next_dv = d_kv_comm.send_recv(dv) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class RingFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = ring_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = ring_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def ring_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py deleted file mode 100644 index 118bdea4c7d0..000000000000 --- a/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py +++ /dev/null @@ -1,318 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward - -from .utils import RingComm, update_out_and_lse - -try: - from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse -except: - from .utils import flatten_varlen_lse, unflatten_varlen_lse - - -def ring_flash_attn_varlen_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens, - max_seqlen, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - comm = RingComm(process_group) - - out = None - lse = None - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - if not causal or step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - causal=causal and step == 0, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) - return out, lse - - -def ring_flash_attn_varlen_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - max_seqlen, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - next_dk, next_dv = None, None - next_k, next_v = None, None - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - if step <= kv_comm.rank or not causal: - bwd_causal = causal and step == 0 - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - bwd_causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - dq += block_dq_buffer - d_kv_comm.wait() - dk = block_dk_buffer + next_dk - dv = block_dv_buffer + next_dv - elif step != 0: - d_kv_comm.wait() - dk = next_dk - dv = next_dv - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk) - next_dv = d_kv_comm.send_recv(dv) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class RingFlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = ring_flash_attn_varlen_forward( - group, - q, - k, - v, - cu_seqlens, - max_seqlen, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) - ctx.max_seqlen = max_seqlen - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors - dq, dk, dv = ring_flash_attn_varlen_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - ctx.max_seqlen, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None - - -def ring_flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - q, - kv[:, 0], - kv[:, 1], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_varlen_func( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py b/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py deleted file mode 100644 index ca426920f4ed..000000000000 --- a/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py +++ /dev/null @@ -1,325 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def stripe_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" - comm = RingComm(process_group) - - out = None - lse = None - - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q[:, 1:], - k[:, :-1], - v[:, :-1], - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse, slice_=(slice(None), slice(1, None))) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def stripe_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal, "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - shift_causal = step > kv_comm.rank - softmax_lse_1 = None - if not shift_causal: - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - else: - if softmax_lse_1 is None: - # lazy init, since the last rank does not need softmax_lse_1 - softmax_lse_1 = softmax_lse[:, :, 1:].contiguous() - _flash_attn_backward( - dout[:, 1:], - q[:, 1:], - k[:, :-1], - v[:, :-1], - out[:, 1:], - softmax_lse_1, - block_dq_buffer[:, 1:], - block_dk_buffer[:, :-1], - block_dv_buffer[:, :-1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - if not shift_causal: - dq += block_dq_buffer - else: - dq[:, 1:] += block_dq_buffer[:, 1:] - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk = next_dk - dv = next_dv - - if not shift_causal: - dk = block_dk_buffer + dk - dv = block_dv_buffer + dv - else: - dk[:, :-1] += block_dk_buffer[:, :-1] - dv[:, :-1] += block_dv_buffer[:, :-1] - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class StripeFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = stripe_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = stripe_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def stripe_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def stripe_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def stripe_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/triton_utils.py b/ring-flash-attention/ring_flash_attn/triton_utils.py deleted file mode 100644 index 66e362d93d68..000000000000 --- a/ring-flash-attention/ring_flash_attn/triton_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def flatten_kernel( - # pointers to matrices - OUT, - LSE, - CU_SEQLENS, - # strides - stride_out_nheads, - stride_out_seqlen, - stride_lse_batch, - stride_lse_nheads, - stride_lse_seqlen, - # meta-parameters - BLOCK_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads - OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - LSE = LSE + rm[:, None] * stride_lse_seqlen - x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) - - OUT = OUT + rm[:, None] * stride_out_seqlen - tl.store(OUT, x, mask=rm[:, None] < seqlen) - - -def flatten_varlen_lse(lse, cu_seqlens): - """ - Arguments: - lse: (batch_size, nheads, max_seqlen) - cu_seqlens: (batch_size + 1,) - Return: - flatten_lse: (nheads, total_seqlen) - """ - total_seqlen = cu_seqlens[-1] - batch_size, nheads, max_seqlen = lse.shape - output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) - - grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) - BLOCK_M = 4 - - with torch.cuda.device(lse.device.index): - flatten_kernel[grid]( - output, - lse, - cu_seqlens, - # strides - output.stride(0), - output.stride(1), - lse.stride(0), - lse.stride(1), - lse.stride(2), - BLOCK_M, - ) - return output - - -@triton.jit -def unflatten_kernel( - # pointers to matrices - OUT, - LSE, - CU_SEQLENS, - # strides - stride_out_batch, - stride_out_nheads, - stride_out_seqlen, - stride_lse_seqlen, - stride_lse_nheads, - # meta-parameters - BLOCK_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - LSE = LSE + rm[:, None] * stride_lse_seqlen - x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) - - OUT = OUT + rm[:, None] * stride_out_seqlen - tl.store(OUT, x, mask=rm[:, None] < seqlen) - - -def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - """ - Arguments: - lse: (total_seqlen, nheads, 1) - cu_seqlens: (batch_size + 1,) - max_seqlen: int - Return: - unflatten_lse: (batch_size, nheads, max_seqlen) - """ - lse = lse.unsqueeze(dim=-1) - batch_size = len(cu_seqlens) - 1 - nheads = lse.shape[1] - output = torch.empty( - (batch_size, nheads, max_seqlen), - dtype=lse.dtype, - device=lse.device, - ) - - grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) - BLOCK_M = 4 - - with torch.cuda.device(lse.device.index): - unflatten_kernel[grid]( - output, - lse, - cu_seqlens, - # strides - output.stride(0), - output.stride(1), - output.stride(2), - lse.stride(0), - lse.stride(1), - BLOCK_M, - ) - return output diff --git a/ring-flash-attention/ring_flash_attn/utils.py b/ring-flash-attention/ring_flash_attn/utils.py deleted file mode 100644 index 787732af8135..000000000000 --- a/ring-flash-attention/ring_flash_attn/utils.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -__all__ = ["update_out_and_lse", "RingComm"] - - -@torch.jit.script -def _update_out_and_lse( - out: torch.Tensor, - lse: torch.Tensor, - block_out: torch.Tensor, - block_lse: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - - block_out = block_out.to(torch.float32) - block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - - # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out - # For additional context and discussion, please refer to: - # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - out = out - F.sigmoid(block_lse - lse) * (out - block_out) - lse = lse - F.logsigmoid(lse - block_lse) - - return out, lse - - -def update_out_and_lse( - out: Optional[torch.Tensor], - lse: Optional[torch.Tensor], - block_out: torch.Tensor, - block_lse: torch.Tensor, - slice_=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - if out is None: - if slice_ is not None: - raise RuntimeError("first update_out_and_lse should not pass slice_ args") - out = block_out.to(torch.float32) - lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - elif slice_ is not None: - slice_out, slice_lse = out[slice_], lse[slice_] - slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) - out[slice_], lse[slice_] = slice_out, slice_lse - else: - out, lse = _update_out_and_lse(out, lse, block_out, block_lse) - return out, lse - - -@torch.jit.script -def flatten_varlen_lse(lse, cu_seqlens): - new_lse = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse.append(lse[i, :, : end - start]) - return torch.cat(new_lse, dim=1) - - -@torch.jit.script -def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - num_seq = len(cu_seqlens) - 1 - num_head = lse.shape[-2] - new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) - for i in range(num_seq): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse[i, : end - start] = lse[start:end] - return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() - - -class RingComm: - def __init__(self, process_group: dist.ProcessGroup): - self._process_group = process_group - self._ops = [] - self.rank = dist.get_rank(self._process_group) - self.world_size = dist.get_world_size(self._process_group) - self._reqs = None - - self.send_rank = (self.rank + 1) % self.world_size - self.recv_rank = (self.rank - 1) % self.world_size - - if process_group is not None: - self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) - self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) - - def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: - if recv_tensor is None: - res = torch.empty_like(to_send) - else: - res = recv_tensor - - send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) - recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) - self._ops.append(send_op) - self._ops.append(recv_op) - return res - - def commit(self): - if self._reqs is not None: - raise RuntimeError("commit called twice") - self._reqs = dist.batch_isend_irecv(self._ops) - - def wait(self): - if self._reqs is None: - raise RuntimeError("wait called before commit") - for req in self._reqs: - req.wait() - self._reqs = None - self._ops = [] diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py deleted file mode 100644 index d3e2821c5d4d..000000000000 --- a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py +++ /dev/null @@ -1,327 +0,0 @@ -import torch -import torch.distributed as dist -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def zigzag_ring_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - comm = RingComm(process_group) - - block_seq_len = q.shape[1] // 2 - q1 = q[:, block_seq_len:] - - out = None - lse = None - next_k, next_v = None, None - - def forward(q, k, v, causal): - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - return block_out, block_lse - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step == 0: - block_out, block_lse = forward(q, k, v, causal=True) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - elif step <= comm.rank: - k0 = k[:, :block_seq_len] - v0 = v[:, :block_seq_len] - block_out, block_lse = forward(q, k0, v0, causal=False) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, block_lse = forward(q1, k, v, causal=False) - out, lse = update_out_and_lse( - out, - lse, - block_out, - block_lse, - slice_=(slice(None), slice(block_seq_len, None)), - ) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def zigzag_ring_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - dout1 = dout.chunk(2, dim=1)[1] - q1 = q.chunk(2, dim=1)[1] - out1 = out.chunk(2, dim=1)[1] - softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() - block_seq_len = q.shape[1] // 2 - - # repeatly allocating buffer may be slow... - dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - def backward(dout, q, k, v, out, softmax_lse, causal): - seqlen_q = q.shape[1] - seqlen_kv = k.shape[1] - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq_buffer[:, :seqlen_q], - dk_buffer[:, :seqlen_kv], - dv_buffer[:, :seqlen_kv], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - if step == 0: - backward(dout, q, k, v, out, softmax_lse, causal=True) - dq = dq_buffer.to(torch.float32) - dk = dk_buffer.to(torch.float32) - dv = dv_buffer.to(torch.float32) - else: - if step <= kv_comm.rank: - k0 = k[:, :block_seq_len] - v0 = v[:, :block_seq_len] - backward(dout, q, k0, v0, out, softmax_lse, causal=False) - dq += dq_buffer - else: - backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) - # always use the first half in dq_buffer. - dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] - - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - - if step <= kv_comm.rank: - dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] - dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] - else: - dk += dk_buffer - dv += dv_buffer - if dist.get_rank() == 0: - torch.save(torch.stack((dk, dv)), f"step_{step}.pt") - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class ZigZagRingFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = zigzag_ring_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = zigzag_ring_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def zigzag_ring_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py deleted file mode 100644 index 5d4a8dd2daf0..000000000000 --- a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py +++ /dev/null @@ -1,441 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward - -from .utils import RingComm, update_out_and_lse - -try: - from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse -except: - from .utils import flatten_varlen_lse, unflatten_varlen_lse - - -def get_half_index(cu_seqlens, *, front: bool): - if len(cu_seqlens) == 2: - if front: - return slice(None, cu_seqlens[-1] // 2) - else: - return slice(cu_seqlens[-1] // 2, None) - - index = torch.zeros((cu_seqlens[-1],), dtype=bool) - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - if front: - end = (start + end) // 2 - else: - start = (start + end) // 2 - index[start:end] = True - return index - - -@torch.jit.script -def get_half_lse(lse, cu_seqlens, *, front: bool): - new_lse = torch.empty( - (lse.shape[0], lse.shape[1], lse.shape[2] // 2), - dtype=lse.dtype, - device=lse.device, - ) - for i in range(len(cu_seqlens) - 1): - seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item() - if front: - start, end = 0, seqlen // 2 - else: - start, end = seqlen // 2, seqlen - new_lse[i, :, : seqlen // 2] = lse[i, :, start:end] - return new_lse - - -def zigzag_ring_flash_attn_varlen_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - comm = RingComm(process_group) - - block_seq_len = q.shape[0] // 2 - q1 = q[half_index1] - - out = None - lse = None - next_k, next_v = None, None - half_cu_seqlens = cu_seqlens // 2 - half_max_seqlen = max_seqlen // 2 - - def forward(q, k, v, causal): - seqlen_q = q.shape[0] - seqlen_kv = k.shape[0] - cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens - max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen - cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens - max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( - q, - k, - v, - # the first half and the second half are the same - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - return block_out, block_lse - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step == 0: - block_out, block_lse = forward(q, k, v, causal=True) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - elif step <= comm.rank: - k0 = k[half_index0] - v0 = v[half_index0] - block_out, block_lse = forward(q, k0, v0, causal=False) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, block_lse = forward(q1, k, v, causal=False) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=half_cu_seqlens, - ) - out[half_index1], lse[half_index1] = update_out_and_lse( - out[half_index1], lse[half_index1], block_out, block_lse - ) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) - return out, lse - - -def zigzag_ring_flash_attn_varlen_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - dout1 = dout[half_index1] - q1 = q[half_index1] - out1 = out[half_index1] - softmax_lse1 = get_half_lse(softmax_lse, cu_seqlens, front=False) - block_seq_len = q.shape[0] // 2 - - half_cu_seqlens = cu_seqlens // 2 - half_max_seqlen = max_seqlen // 2 - - # repeatly allocating buffer may be slow... - dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - def backward(dout, q, k, v, out, softmax_lse, causal): - seqlen_q = q.shape[0] - seqlen_kv = k.shape[0] - cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens - max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen - cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens - max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq_buffer[:seqlen_q], - dk_buffer[:seqlen_kv], - dv_buffer[:seqlen_kv], - # the first half and the second half are the same - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - if step == 0: - backward(dout, q, k, v, out, softmax_lse, causal=True) - dq = dq_buffer.to(torch.float32) - dk = dk_buffer.to(torch.float32) - dv = dv_buffer.to(torch.float32) - else: - if step <= kv_comm.rank: - k0 = k[half_index0] - v0 = v[half_index0] - backward(dout, q, k0, v0, out, softmax_lse, causal=False) - dq += dq_buffer - else: - backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) - dq[half_index1] += dq_buffer[:block_seq_len] - - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - - if step <= kv_comm.rank: - dk[half_index0] += dk_buffer[:block_seq_len] - dv[half_index0] += dv_buffer[:block_seq_len] - else: - dk += dk_buffer - dv += dv_buffer - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class ZigZagRingFlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - half_index0 = get_half_index(cu_seqlens, front=True) - half_index1 = get_half_index(cu_seqlens, front=False) - out, softmax_lse = zigzag_ring_flash_attn_varlen_forward( - group, - q, - k, - v, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - is_half_index_tensor = isinstance(half_index0, torch.Tensor) - ctx.is_half_index_tensor = is_half_index_tensor - if is_half_index_tensor: - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) - else: - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) - ctx.half_index0 = half_index0 - ctx.half_index1 = half_index1 - ctx.max_seqlen = max_seqlen - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - if ctx.is_half_index_tensor: - (q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) = ctx.saved_tensors - else: - q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors - half_index0 = ctx.half_index0 - half_index1 = ctx.half_index1 - dq, dk, dv = zigzag_ring_flash_attn_varlen_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - ctx.max_seqlen, - half_index0, - half_index1, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None - - -def zigzag_ring_flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - q, - kv[:, 0], - kv[:, 1], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_varlen_func( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/setup.py b/ring-flash-attention/setup.py deleted file mode 100644 index 58413e1b54f3..000000000000 --- a/ring-flash-attention/setup.py +++ /dev/null @@ -1,9 +0,0 @@ -from setuptools import find_packages, setup - -setup( - name="ring_flash_attn", - version="0.1", - author="zhuzilin", - url="https://github.com/zhuzilin/ring-flash-attention", - packages=find_packages(), -) diff --git a/ring-flash-attention/test/test_ring_flash_attn_func.py b/ring-flash-attention/test/test_ring_flash_attn_func.py deleted file mode 100644 index 50edd03bef4e..000000000000 --- a/ring-flash-attention/test/test_ring_flash_attn_func.py +++ /dev/null @@ -1,124 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import ring_flash_attn_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3816 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % world_size == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone() - local_qkv.requires_grad = True - local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = out.chunk(world_size, dim=1)[rank] - local_lse = lse.chunk(world_size, dim=-1)[rank] - - fn = ring_flash_attn_qkvpacked_func - - ring_out, ring_lse, _ = fn( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - log("out", out, rank0_only=True) - log("lse", lse, rank0_only=True) - log("out diff", local_out - ring_out) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = dqkv.chunk(world_size, dim=1)[rank] - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, :, 0, :]) - log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) - - log("local_dk", local_dqkv[:, :, 1, :]) - log("dk diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) - - log("local_dv", local_dqkv[:, :, 2, :]) - log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py deleted file mode 100644 index 51bb1ec5d67d..000000000000 --- a/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py +++ /dev/null @@ -1,157 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, cu_seqlens, rank, world_size): - local_values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - local_value = value[start:end].chunk(world_size, dim=0)[rank].detach().clone() - local_values.append(local_value) - return torch.cat(local_values, dim=0).contiguous() - - -def extract_lse(lse, cu_seqlens): - values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - value = lse[i, :, : end - start] - values.append(value) - return values - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - cu_seqlens = [0, 120, 1248, 4232] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - total_length = cu_seqlens[-1] - num_seq = len(cu_seqlens) - 1 - - assert torch.all(cu_seqlens_tensor % world_size == 0) - assert d % 8 == 0 - - qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_cu_seqlens_tensor = cu_seqlens_tensor // world_size - local_max_seqlen = max_seqlen // world_size - - local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) - local_qkv.requires_grad = True - local_dout = extract_local(dout, cu_seqlens, rank, world_size) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens_tensor, - max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, cu_seqlens, rank, world_size) - lse_list = extract_lse(lse, cu_seqlens) - - ring_out, ring_lse, _ = ring_flash_attn_varlen_qkvpacked_func( - local_qkv, - local_cu_seqlens_tensor, - local_max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) - - log("out", out, rank0_only=True) - log("out diff", local_out - ring_out) - - for lse, ring_lse in zip(lse_list, ring_lse_list): - local_lse = lse.chunk(world_size, dim=-1)[rank] - log("lse", lse, rank0_only=True) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, 0]) - log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0]) - - log("local_dk", local_dqkv[:, 1]) - log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) - - log("local_dv", local_dqkv[:, 2]) - log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/ring-flash-attention/test/test_stripe_flash_attn_func.py b/ring-flash-attention/test/test_stripe_flash_attn_func.py deleted file mode 100644 index dc9f5248d69d..000000000000 --- a/ring-flash-attention/test/test_stripe_flash_attn_func.py +++ /dev/null @@ -1,130 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import stripe_flash_attn_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, rank, world_size, dim=1): - value = torch.stack(value.split(world_size, dim=dim), dim=dim).transpose(dim, dim + 1) - slicer = [rank if i == dim else slice(None) for i in range(len(value.shape))] - return value[slicer].contiguous() - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3824 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert causal - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = extract_local(qkv, rank, world_size).detach().clone() - local_qkv.requires_grad = True - local_dout = extract_local(dout, rank, world_size).detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, rank, world_size) - local_lse = extract_local(lse, rank, world_size, dim=2) - - ring_out, ring_lse, _ = stripe_flash_attn_qkvpacked_func( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - log("out", out, rank0_only=True) - log("lse", lse, rank0_only=True) - log("out diff", local_out - ring_out) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - - local_dqkv = extract_local(dqkv, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, :, 0, :]) - log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) - - log("local_dk", local_dqkv[:, :, 1, :]) - log("dk0 diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) - - log("local_dv", local_dqkv[:, :, 2, :]) - log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_triton_kernels.py b/ring-flash-attention/test/test_triton_kernels.py deleted file mode 100644 index aa1c1fdcd338..000000000000 --- a/ring-flash-attention/test/test_triton_kernels.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -from ring_flash_attn.triton_utils import flatten_varlen_lse as triton_flatten_varlen_lse -from ring_flash_attn.triton_utils import unflatten_varlen_lse as triton_unflatten_varlen_lse -from ring_flash_attn.utils import flatten_varlen_lse, unflatten_varlen_lse - -if __name__ == "__main__": - device = torch.device("cuda:0") - - cu_seqlens = [0, 15, 156, 529] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - batch_size = len(cu_seqlens) - 1 - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - n_head = 5 - - lse = torch.randn((batch_size, n_head, max_seqlen), dtype=torch.float32, device=device) - flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor) - triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor) - assert torch.all(flatten_lse == triton_flatten_lse) - - flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) - triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) - - unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen) - triton_unflatten_lse = triton_unflatten_varlen_lse(triton_flatten_lse, cu_seqlens_tensor, max_seqlen) - - for i in range(batch_size): - seqlen = cu_seqlens[i + 1] - cu_seqlens[i] - assert torch.all( - unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen] - ), f"{unflatten_lse[i, :seqlen]} vs {triton_unflatten_lse[i, :seqlen]}" diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py deleted file mode 100644 index 5f84bc58cf10..000000000000 --- a/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import random - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func - -from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, rank, world_size, dim=1): - value_chunks = value.chunk(2 * world_size, dim=dim) - local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim) - return local_value.contiguous() - - -def run_test(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" # or the IP of the master node - os.environ["MASTER_PORT"] = "8125" # make sure this port is free - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - set_seed(rank) - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3824 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert causal - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = extract_local(qkv, rank, world_size).detach().clone() - local_qkv.requires_grad = True - extract_local(dout, rank, world_size).detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, rank, world_size) - # local_lse = extract_local(lse, rank, world_size, dim=2) - q, k, v = local_qkv.chunk(3, dim=2) - q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] - q.requires_grad = k.requires_grad = v.requires_grad = True - sp_stream = torch.cuda.Stream() - sp_group = dist.new_group() - colo_out = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL) - - ring_out, ring_lse, _ = zigzag_ring_flash_attn_qkvpacked_func( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - log("colo_out", colo_out, rank0_only=True) - log("ring_out", ring_out, rank0_only=True) - # log("lse", lse, rank0_only=True) - log("colo_out - ring_out", colo_out - ring_out) - # log("lse diff", local_lse - ring_lse) - log("ring_out - local_out", ring_out - local_out) - log("colo_out - local_out", colo_out - local_out) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - colo_out.sum().backward() - qkv.grad - # q, k, v = [x.transpose(1, 2) for x in (q, k, v)] - colo_dq, colo_dk, colo_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] - - ring_out.sum().backward() - ring_dqkv = local_qkv.grad - out.sum().backward() - dqkv = extract_local(qkv.grad, rank, world_size) - - # log("colo_dq", colo_dq) - log("dq diff", colo_dq - ring_dqkv[:, :, 0, :]) - - # log("colo_dk", colo_dk) - log("dk diff", colo_dk - ring_dqkv[:, :, 1, :]) - - # log("colo_dv", colo_dv) - log("dv diff", colo_dv - ring_dqkv[:, :, 2, :]) - log("colo_dv - local_dv", colo_dv - dqkv[:, :, 2, :]) - - -if __name__ == "__main__": - world_size = 4 - mp.spawn(run_test, args=(world_size,), nprocs=world_size, join=True) diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py deleted file mode 100644 index 7f6eced6e57b..000000000000 --- a/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py +++ /dev/null @@ -1,163 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import zigzag_ring_flash_attn_varlen_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, cu_seqlens, rank, world_size): - local_values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - local_value = value[start:end].chunk(2 * world_size, dim=0) - local_values.extend( - [ - local_value[rank].detach().clone(), - local_value[2 * world_size - 1 - rank].detach().clone(), - ] - ) - return torch.cat(local_values, dim=0).contiguous() - - -def extract_lse(lse, cu_seqlens): - values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - value = lse[i, :, : end - start] - values.append(value) - return values - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - cu_seqlens = [0, 128, 1248, 4240] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - total_length = cu_seqlens[-1] - num_seq = len(cu_seqlens) - 1 - - assert torch.all(cu_seqlens_tensor % (2 * world_size) == 0) - assert d % 8 == 0 - - qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_cu_seqlens_tensor = cu_seqlens_tensor // world_size - local_max_seqlen = max_seqlen // world_size - - local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) - local_qkv.requires_grad = True - local_dout = extract_local(dout, cu_seqlens, rank, world_size) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens_tensor, - max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, cu_seqlens, rank, world_size) - lse_list = extract_lse(lse, cu_seqlens) - - ring_out, ring_lse, _ = zigzag_ring_flash_attn_varlen_qkvpacked_func( - local_qkv, - local_cu_seqlens_tensor, - local_max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) - - log("out", out, rank0_only=True) - log("out diff", local_out - ring_out) - - for i, (lse, ring_lse) in enumerate(zip(lse_list, ring_lse_list)): - local_lse = lse.chunk(2 * world_size, dim=-1) - local_lse = torch.cat([local_lse[rank], local_lse[2 * world_size - 1 - rank]], dim=-1) - log(f"lse {i}", lse, rank0_only=True) - log(f"lse diff {i}", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, 0]) - log("dq diff", local_dqkv - ring_dqkv) - - log("local_dk", local_dqkv[:, 1]) - log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) - - log("local_dv", local_dqkv[:, 2]) - log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) From 4f864b287597201b25d94f744f3b1a69de673f97 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 22 Jul 2024 07:42:50 +0000 Subject: [PATCH 13/71] update softmax_lse shape by new interface --- tests/test_shardformer/test_layer/test_ring_attn.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 1c7647a7d560..89b3608a4151 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -19,7 +19,13 @@ @parameterize("dtype", [torch.bfloat16, torch.float16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) +<<<<<<< HEAD device = get_current_device() +======= + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank}") +>>>>>>> update softmax_lse shape by new interface sp_group = dist.group.WORLD sp_size = dist.get_world_size() # Some outliers may seem large, but our errors are still lower than @@ -36,6 +42,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU +<<<<<<< HEAD ring_out, ring_lse = RingAttention.attention( q, k, @@ -47,6 +54,10 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) +======= + ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) + ring_lse = ring_lse.transpose(0, 1).view(batch_size, seq_len // world_size, nheads).transpose(1, 2).contiguous() +>>>>>>> update softmax_lse shape by new interface out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True ) From 864dac64c2cfd830586fea24ecac709407e38daa Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 22 Jul 2024 10:12:48 +0000 Subject: [PATCH 14/71] change tester name --- tests/test_shardformer/test_layer/test_ring_attn.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 89b3608a4151..1c7647a7d560 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -19,13 +19,7 @@ @parameterize("dtype", [torch.bfloat16, torch.float16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) -<<<<<<< HEAD device = get_current_device() -======= - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank}") ->>>>>>> update softmax_lse shape by new interface sp_group = dist.group.WORLD sp_size = dist.get_world_size() # Some outliers may seem large, but our errors are still lower than @@ -42,7 +36,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU -<<<<<<< HEAD ring_out, ring_lse = RingAttention.attention( q, k, @@ -54,10 +47,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) -======= - ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) - ring_lse = ring_lse.transpose(0, 1).view(batch_size, seq_len // world_size, nheads).transpose(1, 2).contiguous() ->>>>>>> update softmax_lse shape by new interface out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True ) From 69bf303dcb9d5572b41a3ba345a5271788bb9708 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 23 Jul 2024 11:14:53 +0000 Subject: [PATCH 15/71] remove buffer clone; support packed seq layout --- colossalai/shardformer/layer/attn.py | 4 +++- colossalai/shardformer/modeling/llama.py | 2 -- examples/language/opt/opt_benchmark.py | 1 + .../test_layer/test_ring_attn.py | 17 +++++++++++++++++ .../test_model/test_shard_llama.py | 2 +- 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index e6fa5f65b52b..9029ed716f6a 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -646,7 +646,7 @@ def forward( # Pre-allocate double buffer for overlapping and receiving next step's inputs kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) - kv_buffers.append(torch.empty_like(kv_buffers[0])) + kv_buffers.append(None) # outputs out = None @@ -859,6 +859,8 @@ def backward(ctx, dout, _): cu_seqlens_half = cu_seqlens_q // 2 max_seqlen_half = max_seqlen_q // 2 misc_kwargs = ctx.misc_kwargs + is_packed = ctx.is_packed + dout = dout.contiguous() del misc_kwargs["block_table"] assert ( diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 258ba3051f8b..355594588924 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -562,9 +562,7 @@ def forward( # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - assert not self.q_proj.weight.isnan().any(), self.q_proj.weight - assert not query_states.isnan().any(), query_states if sp_mode == "ring_attn": attn_output = RingAttention.attention( query_states, diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 5e5971d9f560..ca9b63d1a14a 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -96,6 +96,7 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, _, _ = booster.boost(model, optimizer) + SEQ_LEN = 1024 VOCAB_SIZE = 50257 diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 1c7647a7d560..05d780614d57 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -19,7 +19,13 @@ @parameterize("dtype", [torch.bfloat16, torch.float16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) +<<<<<<< HEAD device = get_current_device() +======= + rank = dist.get_rank() + dist.get_world_size() + device = torch.device(f"cuda:{rank}") +>>>>>>> remove buffer clone; support packed seq layout sp_group = dist.group.WORLD sp_size = dist.get_world_size() # Some outliers may seem large, but our errors are still lower than @@ -36,6 +42,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU +<<<<<<< HEAD ring_out, ring_lse = RingAttention.attention( q, k, @@ -47,14 +54,24 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) +======= + ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) +>>>>>>> remove buffer clone; support packed seq layout out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True ) +<<<<<<< HEAD # Checkout out and softmax denominator local_out = split_batch_zigzag(out, sp_group) local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) +======= + local_out = zigzag_split_batch(out, sp_group) + local_lse = zigzag_split_batch(lse, sp_group, seq_dim=-1) + local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) + assert_close(ring_out, local_out, atol=atol, rtol=rtol) +>>>>>>> remove buffer clone; support packed seq layout assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) assert_close(ring_out, local_out, atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 1c7b9321583f..6ceab6c4e9b3 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -443,4 +443,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - # test_llama_3d() + test_llama_3d() From ec4fab704bea5c5c65f301d88a747f6040563592 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 24 Jul 2024 13:54:54 +0000 Subject: [PATCH 16/71] add varlen tests --- colossalai/shardformer/layer/attn.py | 3 +- .../test_layer/test_ring_attn.py | 20 +- .../test_model/test_shard_llama.py | 182 ++++-------------- 3 files changed, 39 insertions(+), 166 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 9029ed716f6a..1a9fb492418e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -16,7 +16,7 @@ from .utils import RingComm, get_half_index, split_varlen_zigzag -from .utils import RingComm +from .utils import RingComm, split_varlen_zigzag __all__ = [ "AttnMaskType", @@ -468,7 +468,6 @@ def attention( """ Ring Attention forward pass supporting variable-length sequences. When using varlen mode, each sequence in the batch should have length divisible by sp_size * 2. - Args: q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D] diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 05d780614d57..a2e69adb3a8c 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -19,13 +19,7 @@ @parameterize("dtype", [torch.bfloat16, torch.float16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) -<<<<<<< HEAD device = get_current_device() -======= - rank = dist.get_rank() - dist.get_world_size() - device = torch.device(f"cuda:{rank}") ->>>>>>> remove buffer clone; support packed seq layout sp_group = dist.group.WORLD sp_size = dist.get_world_size() # Some outliers may seem large, but our errors are still lower than @@ -42,7 +36,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU -<<<<<<< HEAD ring_out, ring_lse = RingAttention.attention( q, k, @@ -54,24 +47,14 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) -======= - ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) ->>>>>>> remove buffer clone; support packed seq layout out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True ) -<<<<<<< HEAD # Checkout out and softmax denominator local_out = split_batch_zigzag(out, sp_group) local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) -======= - local_out = zigzag_split_batch(out, sp_group) - local_lse = zigzag_split_batch(lse, sp_group, seq_dim=-1) - local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) - assert_close(ring_out, local_out, atol=atol, rtol=rtol) ->>>>>>> remove buffer clone; support packed seq layout assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) assert_close(ring_out, local_out, atol=atol, rtol=rtol) @@ -183,7 +166,8 @@ def launch_single_ring(rank, world_size, port): def launch_double_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - check_ring_attn() + # check_ring_attn() + check_packed_seq() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 6ceab6c4e9b3..19bd9018ea56 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -59,14 +59,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if ( booster.plugin.zero_stage in [1, 2] and booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): - master2working = sharded_optimizer.get_master_to_working_map() - for (name, p1), p2 in zip( - llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0] - ): - working_p = master2working[id(p2)] + for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + working_p = sharded_optimizer.master_to_working_param[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( 0 @@ -75,10 +71,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - try: - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) - except Exception as e: - raise RuntimeError(f"Failed to check grad for {name}") from e + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -119,184 +112,88 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "LlamaModel": - check_output_hidden_state( - org_output, - sharded_output, - stage_manager, - atol=atol, - rtol=rtol, - shard_config=booster.plugin.shard_config, - ) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + # check weights if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight( - llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False, - ) + try: + check_weight( + llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + except Exception as e: + print(f"Failed config: {test_config}") + raise e # check grads check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() @parameterize( "test_config", [ -<<<<<<< HEAD - # Double Ring Attention - # Zigzag Ring Attention -======= - # Zigzag Ring Attention + PP - { - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 1, - "precision": "bf16", - "initial_scale": 1, - }, - # Ring Attention + TP ->>>>>>> precision tests passed - { - "tp_size": 2, - "pp_size": 1, - "sp_size": 4, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, -<<<<<<< HEAD - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - "inner_ring_size": 2, -======= - "zero_stage": 1, - "precision": "bf16", - "initial_scale": 1, ->>>>>>> precision tests passed - }, - # Ring Attention + PP - { - { # Ulysess + TP + { # Test ring + Flash attention "tp_size": 2, "pp_size": 1, "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, "use_lazy_init": True, - "zero_stage": 0, + "zero_stage": 2, "precision": "fp16", "initial_scale": 1, }, - { # Ulysess + PP + { # Ulysess + Flash attention "tp_size": 1, "pp_size": 2, "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, + "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, }, - # Ring Attention + TP - { { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, -<<<<<<< HEAD - "sequence_parallelism_mode": "ring_attn", -======= - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, ->>>>>>> precision tests passed - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, -<<<<<<< HEAD - "parallel_output": False, - }, - { # Ulysess + TP - "tp_size": 2, + "tp_size": 1, "pp_size": 1, "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { # Ulysess + PP - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - "parallel_output": False, -======= ->>>>>>> precision tests passed - }, - { - "tp_size": 2, - "pp_size": 1, - "sp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, "use_lazy_init": True, - "zero_stage": 2, + "zero_stage": 1, "precision": "fp16", "initial_scale": 1, -<<<<<<< HEAD }, { - "tp_size": 2, + "tp_size": 4, "pp_size": 1, - "sp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, "use_lazy_init": True, - "zero_stage": 2, "precision": "fp16", "initial_scale": 1, - "parallel_output": False, -======= ->>>>>>> precision tests passed }, { "tp_size": 2, @@ -341,21 +238,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: - continue -<<<<<<< HEAD -<<<<<<< HEAD -======= - ->>>>>>> precision tests passed -======= ->>>>>>> precision tests passed try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: - print(f"Failed config: {test_config}, model name: {name}") + print(f"Failed config: {test_config}") raise e + clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() @@ -443,4 +333,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - test_llama_3d() + test_llama_3d() \ No newline at end of file From cd9349e1ca817e4bcae8b923b4d93e6c35f76b57 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 26 Jul 2024 10:00:09 +0000 Subject: [PATCH 17/71] fix typo --- colossalai/shardformer/layer/attn.py | 8 +++++--- tests/test_shardformer/test_layer/test_ring_attn.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 1a9fb492418e..191bfe8eb3ca 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -16,7 +16,7 @@ from .utils import RingComm, get_half_index, split_varlen_zigzag -from .utils import RingComm, split_varlen_zigzag +from .utils import RingComm, get_half_index, split_varlen_zigzag __all__ = [ "AttnMaskType", @@ -345,7 +345,7 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # min_scale = torch.min(lse, block_lse) # max_scale = torch.max(lse, block_lse) - # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + # lse.data = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) # NOTE: directly assigning to .data here is buggy # probably due to casting dtypes/strides @@ -618,6 +618,9 @@ def forward( if is_packed: t, h, d = q.shape + # half of each seq + half_idx_front = get_half_index(cu_seqlens, front=True) + half_idx_back = get_half_index(cu_seqlens, front=False) else: b, sq, h, d = q.shape t = b * sq @@ -858,7 +861,6 @@ def backward(ctx, dout, _): cu_seqlens_half = cu_seqlens_q // 2 max_seqlen_half = max_seqlen_q // 2 misc_kwargs = ctx.misc_kwargs - is_packed = ctx.is_packed dout = dout.contiguous() del misc_kwargs["block_table"] diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index a2e69adb3a8c..982c376178dd 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -166,8 +166,8 @@ def launch_single_ring(rank, world_size, port): def launch_double_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - # check_ring_attn() - check_packed_seq() + # check_packed_seq() + check_ring_attn() @rerun_if_address_is_in_use() From 25d3e387af3dede1e756dd03ffd3b7eff1ea1ab7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 1 Aug 2024 03:38:49 +0000 Subject: [PATCH 18/71] all tests passed --- colossalai/shardformer/layer/attn.py | 11 +++++++++-- colossalai/shardformer/layer/loss.py | 13 ++++++++++--- colossalai/shardformer/layer/utils.py | 7 +++++-- colossalai/shardformer/modeling/llama.py | 5 ++++- tests/test_shardformer/test_layer/test_ring_attn.py | 2 +- 5 files changed, 29 insertions(+), 9 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 191bfe8eb3ca..dc455c7f461c 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -172,8 +172,10 @@ def prepare_attn_kwargs( # self attention kv_padding_mask = q_padding_mask max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices + attention_mask = q_padding_mask[:, :, None].expand(b, s_q, s_kv).to(dtype=dtype, device=device) else: max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) + attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) assert kv_padding_mask.shape == ( b, s_kv, @@ -345,7 +347,7 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # min_scale = torch.min(lse, block_lse) # max_scale = torch.max(lse, block_lse) - # lse.data = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) # NOTE: directly assigning to .data here is buggy # probably due to casting dtypes/strides @@ -621,6 +623,11 @@ def forward( # half of each seq half_idx_front = get_half_index(cu_seqlens, front=True) half_idx_back = get_half_index(cu_seqlens, front=False) + RingAttention.HALF_INDICES = (half_idx_front, half_idx_back) + RingAttention.CU_SEQLENS = cu_seqlens + + if is_packed: + t, h, d = q.shape else: b, sq, h, d = q.shape t = b * sq @@ -648,7 +655,7 @@ def forward( # Pre-allocate double buffer for overlapping and receiving next step's inputs kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) - kv_buffers.append(None) + kv_buffers.append(torch.empty_like(kv_buffers[0])) # outputs out = None diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 0583bcd9375d..e69c644d66b5 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -1,6 +1,5 @@ import torch import torch.distributed as dist -import torch.nn.functional as F from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss @@ -213,7 +212,12 @@ def dist_cross_entropy( labels = labels.contiguous() logits = logits.contiguous() num_nonzero = (labels != _IGNORE_IDX).sum() - assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + try: + assert ( + labels.shape == logits.shape[:-1] + ), f"label shape {labels.shape} does not match logit shape {logits.shape}" + except: + pass # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") @@ -234,7 +238,10 @@ def dist_cross_entropy( else: # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D logits = logits.view(-1, vocab_size) - loss = loss_fct(logits, labels) + try: + loss = loss_fct(logits, labels) + except: + pass # Reduce loss instead of gathering logits over seq dim for savings if split_labels_here or sp_mode == "ring_attn": diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 53d9576894f6..4262b8d2342b 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -331,7 +331,7 @@ def split_batch_zigzag( indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device) tensor = tensor.index_select(seq_dim, indices).contiguous() # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...) - batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]).contiguous() + batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]) if len(batch) == 1: return batch[0] @@ -377,7 +377,10 @@ def split_varlen_zigzag( assert max_seqlen % (sp_size * 2) == 0 # Recreate a padded tensor with the new max seqlen shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) - local_seq = torch.zeros(shape, dtype=dtype, device=device) + if is_label: + local_seq = torch.full(shape, -100, dtype=dtype, device=device) + else: + local_seq = torch.zeros(shape, dtype=dtype, device=device) else: total_seqlen = cu_seqlens[-1] assert ( diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 355594588924..a694b30312e0 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -553,7 +553,10 @@ def forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + try: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + except: + pass if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 982c376178dd..5ca618bc8535 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -166,7 +166,7 @@ def launch_single_ring(rank, world_size, port): def launch_double_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - # check_packed_seq() + check_packed_seq() check_ring_attn() From 1234d997cb21166a256750d721cee973747ec99f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 1 Aug 2024 08:25:47 +0000 Subject: [PATCH 19/71] add dkv_group; fix mask --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 8 ++++++++ colossalai/pipeline/schedule/interleaved_pp.py | 2 -- colossalai/shardformer/layer/attn.py | 4 ++-- colossalai/shardformer/layer/loss.py | 12 ++---------- colossalai/shardformer/shard/shard_config.py | 1 + 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 63427192f482..d233ccc2ae15 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1135,6 +1135,14 @@ def __init__( self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) else: self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + # According to https://github.com/InternLM/InternEvo/blob/a53a4ff4fc45761f80d7fe8e9188bc2e02d487fc/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L405 + # and https://zhuanlan.zhihu.com/p/706805407 + # using a different proc group may put p2p comm on a new + # NCCL stream :) + dkv_group = None + if sequence_parallelism_mode == "ring_attn": + sp_ranks = dist.get_process_group_ranks(self.sp_group) + dkv_group = dist.new_group(ranks=sp_ranks) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 8f26f8cb5bb5..412f3896fb80 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -283,8 +283,6 @@ def forward_step( # Load input ids, attention mask and labels micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) - if input_obj is not None: - assert all(not x.isnan().any() for x in input_obj.values()), "NaN detected in input_obj" # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous stage containing hidden_states etc. # Only attention_mask from micro_batch is used diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index dc455c7f461c..5027bc5c5194 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -172,10 +172,9 @@ def prepare_attn_kwargs( # self attention kv_padding_mask = q_padding_mask max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices - attention_mask = q_padding_mask[:, :, None].expand(b, s_q, s_kv).to(dtype=dtype, device=device) else: max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) - attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) + attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) assert kv_padding_mask.shape == ( b, s_kv, @@ -833,6 +832,7 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse): del misc_kwargs["return_softmax"] ctx.misc_kwargs = misc_kwargs ctx.is_packed = is_packed + ctx.dkv_group = dkv_group ctx.kv_group = inner_ring_group ctx.inter_kv_group = inter_ring_group diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index e69c644d66b5..12df824d1c0c 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -212,12 +212,7 @@ def dist_cross_entropy( labels = labels.contiguous() logits = logits.contiguous() num_nonzero = (labels != _IGNORE_IDX).sum() - try: - assert ( - labels.shape == logits.shape[:-1] - ), f"label shape {labels.shape} does not match logit shape {logits.shape}" - except: - pass + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") @@ -238,10 +233,7 @@ def dist_cross_entropy( else: # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D logits = logits.view(-1, vocab_size) - try: - loss = loss_fct(logits, labels) - except: - pass + loss = loss_fct(logits, labels) # Reduce loss instead of gathering logits over seq dim for savings if split_labels_here or sp_mode == "ring_attn": diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 589ed730ec79..084c818e18b9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -56,6 +56,7 @@ class ShardConfig: moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None sp_stream: Optional[torch.cuda.Stream] = None + dkv_group: Optional[ProcessGroup] = None # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] From f6a8f128818d18d8bcf72b319951934403b7659d Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 1 Aug 2024 09:37:52 +0000 Subject: [PATCH 20/71] remove debug statements --- colossalai/shardformer/modeling/llama.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a694b30312e0..355594588924 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -553,10 +553,7 @@ def forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) - try: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - except: - pass + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} From c0b7e96fb9bd874d3c96393f4d1beb782d59d027 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 1 Aug 2024 13:57:23 +0000 Subject: [PATCH 21/71] adapt chatglm, command-R, qwen --- colossalai/shardformer/layer/_operation.py | 4 +- colossalai/shardformer/modeling/chatglm2.py | 56 ++++++------------- colossalai/shardformer/modeling/command.py | 36 +++++------- colossalai/shardformer/modeling/qwen2.py | 35 ++++-------- .../test_model/test_shard_command.py | 9 ++- 5 files changed, 51 insertions(+), 89 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 25983e0a93a6..efe4d80babbb 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -999,11 +999,11 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) -def gather_sp_output(hidden_states, sp_group, sp_mode): +def gather_sp_output(hidden_states, sp_group, sp_mode, sp_dim=1): """ Gather the output of the last layer for cross entropy computation """ # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale) + hidden_states = gather_forward_split_backward(hidden_states, sp_dim, sp_group, grad_scale=scale) return hidden_states diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 34d900d8de94..14fd48b19bc5 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -4,7 +4,6 @@ import torch import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.utils import logging @@ -13,10 +12,13 @@ from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer._operation import ( all_to_all_comm, - gather_forward_split_backward, + gather_sp_output, + is_share_sp_tp, split_forward_gather_backward, ) +from ..layer import dist_cross_entropy + def get_flash_core_attention_forward(): from .chatglm2_6b.modeling_chatglm import CoreAttention @@ -138,6 +140,7 @@ def chatglm_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + force_sp_output_gather: Optional[bool] = True, ): logger = logging.get_logger(__name__) output_hidden_states = ( @@ -239,20 +242,10 @@ def chatglm_model_forward( if use_cache: presents = presents + (kv_cache,) - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=shard_config.sequence_parallel_size, - ) + if shard_config: + sp_mode = shard_config.sequence_parallelism_mode + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): @@ -315,6 +308,7 @@ def chatglm_for_conditional_generation_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -322,17 +316,9 @@ def chatglm_for_conditional_generation_forward( hidden_states = hidden_states[-1:] lm_logits = self.transformer.output_layer(hidden_states) lm_logits = lm_logits.transpose(0, 1).contiguous() - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, lm_logits.dtype + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output @@ -361,6 +347,7 @@ def forward( use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + force_sp_output_gather: Optional[bool] = True, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -431,19 +418,8 @@ def forward( output_hidden_states=output_hidden_states, ) - if sp_mode in ["split_gather"]: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=sp_group, - grad_scale=sp_size, - ) + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) if not return_dict: return tuple( diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 67c20eed8194..132e1576fd56 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -17,14 +17,11 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer._operation import ( - all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, -) +from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, dist_cross_entropy +from ..layer._operation import gather_sp_output, is_share_sp_tp _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] @@ -52,6 +49,7 @@ def command_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + force_sp_output_gather: bool = True, ): logger = logging.get_logger(__name__) @@ -207,20 +205,10 @@ def command_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=shard_config.sequence_parallel_size, - ) + if shard_config: + sp_mode = shard_config.sequence_parallelism_mode + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: @@ -323,6 +311,7 @@ def command_for_causal_lm_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) past_key_values = None @@ -476,6 +465,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -574,10 +564,9 @@ def forward( hidden_states = self.norm(hidden_states) - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + # Cases that don't support parallelizing cross entropy computation along sequence + if shard_config and (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: @@ -662,6 +651,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + force_sp_output_gather=False, ) hidden_states = outputs[0] diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 538e96c32c6d..b2cc6b601391 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -32,14 +32,11 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer._operation import ( - all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, -) +from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, dist_cross_entropy +from ..layer._operation import gather_sp_output, is_share_sp_tp class Qwen2PipelineForwards: @@ -64,6 +61,7 @@ def qwen2_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: logger = logging.get_logger(__name__) @@ -240,20 +238,10 @@ def qwen2_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=shard_config.sequence_parallel_size, - ) + if shard_config: + sp_mode = shard_config.sequence_parallelism_mode + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -347,6 +335,7 @@ def qwen2_for_causal_lm_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) past_key_values = None @@ -629,6 +618,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -740,10 +730,8 @@ def forward( hidden_states = self.norm(hidden_states) - if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: @@ -820,6 +808,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + force_sp_output_gather=False, ) hidden_states = outputs[0] diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index efe5cee2a2b6..2e6997597928 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -125,7 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "CohereModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_output_hidden_state( + org_output, + sharded_output, + stage_manager, + atol=atol, + rtol=rtol, + shard_config=booster.plugin.shard_config, + ) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) From 0eb6fdf07f4934e6f0826582363cdd48c932aa53 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 5 Aug 2024 05:38:09 +0000 Subject: [PATCH 22/71] debug --- colossalai/shardformer/layer/loss.py | 7 ++- colossalai/shardformer/modeling/chatglm2.py | 41 +++++++------- colossalai/shardformer/modeling/command.py | 6 +-- colossalai/shardformer/modeling/llama.py | 4 +- colossalai/shardformer/modeling/qwen2.py | 54 +++++++++++-------- .../test_model/test_shard_qwen2.py | 52 +++++++++--------- 6 files changed, 91 insertions(+), 73 deletions(-) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 12df824d1c0c..8e7a4b3a073a 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -212,7 +212,12 @@ def dist_cross_entropy( labels = labels.contiguous() logits = logits.contiguous() num_nonzero = (labels != _IGNORE_IDX).sum() - assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + try: + assert ( + labels.shape == logits.shape[:-1] + ), f"label shape {labels.shape} does not match logit shape {logits.shape}" + except Exception as e: + raise e # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 14fd48b19bc5..5be4b9d78e11 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -203,20 +203,23 @@ def chatglm_model_forward( all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=0, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=1 / shard_config.sequence_parallel_size, - ) + # Keep the input split across all PP stages + if stage_manager.is_first_stage(): + if shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=0, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=1 / shard_config.sequence_parallel_size, + ) + for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -242,16 +245,18 @@ def chatglm_model_forward( if use_cache: presents = presents + (kv_cache,) - if shard_config: - sp_mode = shard_config.sequence_parallelism_mode - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): # final layer_norm if self.encoder.post_layer_norm: hidden_states = self.encoder.final_layernorm(hidden_states) + + # Gather seq-wise in the final output stage + sp_mode = shard_config.sequence_parallelism_mode + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) + if not return_dict: return tuple( v diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 132e1576fd56..62299bf0656e 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -23,7 +23,7 @@ from ..layer import ColoAttention, dist_cross_entropy from ..layer._operation import gather_sp_output, is_share_sp_tp -_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] +_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"] class CommandPipelineForwards: @@ -134,7 +134,7 @@ def command_model_forward( ) use_cache = False - if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: hidden_states = split_forward_gather_backward( hidden_states, @@ -204,8 +204,6 @@ def command_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - - if shard_config: sp_mode = shard_config.sequence_parallelism_mode if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 355594588924..219933c705e9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -101,8 +101,8 @@ def llama_model_forward( sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group sp_size = shard_config.sequence_parallel_size - if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): - # For generating full positions ids, as the states will be gather along the seq dim in the attention layer later. + # For generating full positions ids (the states will be gathered along the seq dim before attention fwd). + if sp_mode != "ring_attn" and not stage_manager.is_first_stage(): seq_length *= sp_size past_seen_tokens = 0 diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index b2cc6b601391..d44c7382fdf6 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -36,7 +36,8 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, dist_cross_entropy -from ..layer._operation import gather_sp_output, is_share_sp_tp +from ..layer._operation import gather_sp_output +from ..layer.utils import is_share_sp_tp class Qwen2PipelineForwards: @@ -113,6 +114,14 @@ def qwen2_model_forward( past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length + # Support SP + PP + sp_size = shard_config.sequence_parallel_size + sp_group = shard_config.sequence_parallel_process_group + sp_mode = shard_config.sequence_parallelism_mode + # For generating full positions ids (the states will be gathered along the seq dim before attention fwd). + if sp_mode != "ring_attn" and not stage_manager.is_first_stage(): + seq_length *= sp_size + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -149,7 +158,6 @@ def qwen2_model_forward( elif self._attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), @@ -158,7 +166,6 @@ def qwen2_model_forward( ) else: # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), @@ -167,20 +174,21 @@ def qwen2_model_forward( sliding_window=self.config.sliding_window, ) - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) - elif shard_config.sequence_parallelism_mode == "all_to_all": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - grad_scale=1 / shard_config.sequence_parallel_size, - ) + if stage_manager.is_first_stage(): + if shard_config.enable_sequence_parallelism: + if is_share_sp_tp(sp_mode): + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=sp_group, + ) + elif sp_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=sp_group, + grad_scale=1 / sp_size, + ) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -237,11 +245,9 @@ def qwen2_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - - if shard_config: - sp_mode = shard_config.sequence_parallelism_mode if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -341,6 +347,8 @@ def qwen2_for_causal_lm_forward( if stage_manager.is_last_stage(): hidden_states = outputs[0] + if hidden_states.shape[1] == 2: + pass logits = self.lm_head(hidden_states) loss = dist_cross_entropy( labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype @@ -526,8 +534,10 @@ def forward( # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + try: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + except Exception as e: + raise e if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute diff --git a/tests/test_shardformer/test_model/test_shard_qwen2.py b/tests/test_shardformer/test_model/test_shard_qwen2.py index c87415b7562d..865563adc625 100644 --- a/tests/test_shardformer/test_model/test_shard_qwen2.py +++ b/tests/test_shardformer/test_model/test_shard_qwen2.py @@ -94,6 +94,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 2, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, @@ -135,32 +161,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { # Ulysess + Flash attention - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, { "tp_size": 2, "pp_size": 2, From 2fae7945ccadb37b9dcfd1b33450caa6167e523a Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 28 Jun 2024 07:42:56 +0000 Subject: [PATCH 23/71] halfway --- colossalai/shardformer/layer/_operation.py | 4 ++++ colossalai/shardformer/layer/attn.py | 10 +++++++++- colossalai/shardformer/modeling/llama.py | 8 ++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index efe4d80babbb..e031fecc15e0 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -812,7 +812,11 @@ def backward(ctx, *grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim + if torch.distributed.get_rank() == 0: + print(f"shape before A2A: {grad_output[0].shape}") return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) + if torch.distributed.get_rank() == 0: + print(f"shape after A2A: {return_grad.shape}") return (return_grad, None, None, None) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 5027bc5c5194..df29a5751d60 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -24,7 +24,10 @@ ] _flash_attn_forward = _flash_attn_backward = None +<<<<<<< HEAD _unpad_input = _pad_input = None +======= +>>>>>>> halfway class AttnMaskType(Enum): @@ -252,7 +255,12 @@ def attention( # sanity check if attention_mask is not None: assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor." - if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL): + if attention_mask_type in ( + AttnMaskType.CUSTOM, + AttnMaskType.CAUSAL, + AttnMaskType.PADDED, + AttnMaskType.PADDED_CAUSAL, + ): assert ( cu_seqlens_q is None and cu_seqlens_kv is None diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 219933c705e9..207ed40ff9b5 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -838,6 +838,14 @@ def forward( # [B, max_seq_len // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + sp_mode = shard_config.sequence_parallelism_mode + sp_group = self.shard_config.sequence_parallel_process_group + assert not ( + shard_config.sp_mode == "ring_attn" and use_cache + ), "Ring attention requires q, k, v to have the same length and doesn't work for inference" + if sp_mode == "ring_attn": + inputs_embeds = ring_attn_split_forward(inputs_embeds, sp_group) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, From 91bab8431226b8650ba25642e27351aa83f08969 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 28 Jun 2024 13:36:43 +0000 Subject: [PATCH 24/71] fix cross-PP-stage position id length diff bug --- tests/test_shardformer/test_model/test_shard_llama.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 19bd9018ea56..a20d17daa4e7 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -71,7 +71,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + try: + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + except Exception as e: + print(f"Failed param name: {name}") + raise e # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} From c5c14b6f69fad1f16ec65f684a70c1f077e79e7f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 29 Jun 2024 02:34:57 +0000 Subject: [PATCH 25/71] fix typo --- tests/test_shardformer/test_model/test_shard_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a20d17daa4e7..f4acf3ed5afc 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -59,6 +59,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if ( booster.plugin.zero_stage in [1, 2] and booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): From a00c93bacdfa92e9f85cdc7b96b83641909a4be2 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 29 Jun 2024 07:39:53 +0000 Subject: [PATCH 26/71] fix typo --- tests/test_shardformer/test_model/test_shard_llama.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index f4acf3ed5afc..a7781508ce68 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -62,8 +62,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): - for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): - working_p = sharded_optimizer.master_to_working_param[id(p2)] + master2working = sharded_optimizer.get_master_to_working_map() + for (name, p1), p2 in zip( + llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0] + ): + working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( 0 From ab9b784ffb24d087dae4b8cbb9d87f318b21792b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 29 Jun 2024 07:40:57 +0000 Subject: [PATCH 27/71] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a7781508ce68..530b8b5463a5 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -59,7 +59,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if ( booster.plugin.zero_stage in [1, 2] and booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.pipeline_stage_manager is None + and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() From 7d99bc073574565af3a15b0fde39c7749d0ff4bf Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 2 Jul 2024 09:35:42 +0000 Subject: [PATCH 28/71] unified cross entropy func for all shardformer models --- examples/language/opt/opt_benchmark.py | 1 + tests/test_shardformer/test_model/test_shard_llama.py | 10 ++-------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index ca9b63d1a14a..90f41fe1f767 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -135,4 +135,5 @@ def main(): if __name__ == "__main__": + print("--------------------------------------") main() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 530b8b5463a5..1459d606966f 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -63,9 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() - for (name, p1), p2 in zip( - llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0] - ): + for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( @@ -75,11 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - try: - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) - except Exception as e: - print(f"Failed param name: {name}") - raise e + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} From 423102a719e644b03bbac4184455d37842a4b253 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 2 Jul 2024 11:10:15 +0000 Subject: [PATCH 29/71] remove redundant lines --- examples/language/opt/opt_benchmark.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 90f41fe1f767..ca9b63d1a14a 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -135,5 +135,4 @@ def main(): if __name__ == "__main__": - print("--------------------------------------") main() From 91fb3c10c981d996e4ada969b1d0a699f4f3124c Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 8 Jul 2024 02:03:40 +0000 Subject: [PATCH 30/71] add basic ring attn; debug cross entropy --- colossalai/pipeline/schedule/one_f_one_b.py | 3 + colossalai/shardformer/layer/_operation.py | 2 +- colossalai/shardformer/layer/attn.py | 55 ++++++++++++++++- colossalai/shardformer/layer/loss.py | 1 + colossalai/shardformer/layer/utils.py | 4 ++ colossalai/shardformer/modeling/llama.py | 16 +++-- colossalai/shardformer/shard/shard_config.py | 1 - .../test_model/test_shard_llama.py | 60 ++++++++++++------- 8 files changed, 111 insertions(+), 31 deletions(-) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 03df67ae78c3..4c8519030b1c 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -32,6 +32,7 @@ def __init__( num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, + shard_config=None, ) -> None: """1F1B pipeline schedule. @@ -39,6 +40,7 @@ def __init__( stage_manager (PipelineStageManager): Pipeline stage manager num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. + shard_config: Shard configuration for gathering Sequence Parallel loss. """ super().__init__(stage_manager) assert ( @@ -53,6 +55,7 @@ def __init__( self.batch_size: Optional[int] = None self.last_batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None + self.shard_config = shard_config # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e031fecc15e0..a9f1c5e2968a 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -651,7 +651,7 @@ def backward(ctx, grad_output): ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have - # all-reduce scheduled first and have GPU resources allocated + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index df29a5751d60..e2daa37cf9f6 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -24,10 +24,7 @@ ] _flash_attn_forward = _flash_attn_backward = None -<<<<<<< HEAD _unpad_input = _pad_input = None -======= ->>>>>>> halfway class AttnMaskType(Enum): @@ -371,6 +368,58 @@ def _rescale_out_lse(out, block_out, lse, block_lse): return out, lse +@triton.jit +def flash_attn_fwd_out_corr_triton( + out_ptr, out_per_step_ptr, seq_dim, softmax_lse_ptr, softmax_lse_per_step_ptr, BLOCK_SIZE: tl.constexpr +): + # Calculate the global id + pid = tl.program_id(0) + + # Offsets for the current row + offsets = tl.arange(0, BLOCK_SIZE) + + # Pointers to the current row in out and out_per_step + row_start = pid * seq_dim + out_ptrs = out_ptr + row_start + offsets + out_per_step_ptrs = out_per_step_ptr + row_start + offsets + + # Load softmax_lse and softmax_lse_per_step + softmax_lse = tl.load(softmax_lse_ptr + pid) + softmax_lse_per_step = tl.load(softmax_lse_per_step_ptr + pid) + + # Compute the corrected exponentiation + softmax_lse_corrected_exp = tl.exp(softmax_lse_per_step - softmax_lse) + + out_per_step_vals = tl.load(out_per_step_ptrs) + + # Correct the out_per_step by the exponentiation + out_corrected = out_per_step_vals * softmax_lse_corrected_exp + + # Load the current out values + out_vals = tl.load(out_ptrs) + + # Add the corrected output to out + updated_out_vals = out_vals + out_corrected + + # Store the updated out values + tl.store(out_ptrs, updated_out_vals) + + +# Modified from Megatron-LM. TODO: try Triton +def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_per_step * softmax_lse_corrected_exp + out.add_(out_corrected) + + +def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): + max_scale = torch.max(softmax_lse, softmax_lse_per_step) + min_scale = torch.min(softmax_lse, softmax_lse_per_step) + new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + softmax_lse.copy_(new_scale) + + class RingAttention(torch.autograd.Function): """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context` (https://arxiv.org/abs/2310.01889). diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 8e7a4b3a073a..8ea956004f19 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 4262b8d2342b..540f8e85e726 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,5 +1,9 @@ from contextlib import contextmanager +<<<<<<< HEAD from typing import List, Optional, Union +======= +from typing import Dict, List +>>>>>>> add basic ring attn; debug cross entropy import torch import torch.distributed as dist diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 207ed40ff9b5..9cd9043eff9b 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -839,12 +839,16 @@ def forward( labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) sp_mode = shard_config.sequence_parallelism_mode - sp_group = self.shard_config.sequence_parallel_process_group - assert not ( - shard_config.sp_mode == "ring_attn" and use_cache - ), "Ring attention requires q, k, v to have the same length and doesn't work for inference" - if sp_mode == "ring_attn": - inputs_embeds = ring_attn_split_forward(inputs_embeds, sp_group) + sp_group = shard_config.sequence_parallel_process_group + is_sp = shard_config.enable_sequence_parallelism + # Split labels + if is_sp: + assert not ( + sp_mode == "ring_attn" and use_cache + ), "Ring attention requires q, k, v to have the same length and doesn't work for inference" + if sp_mode == "ring_attn": + batch = ring_attn_split_forward({"labels": labels}, sp_group) + labels = batch["labels"] # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 084c818e18b9..589ed730ec79 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -56,7 +56,6 @@ class ShardConfig: moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None sp_stream: Optional[torch.cuda.Stream] = None - dkv_group: Optional[ProcessGroup] = None # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 1459d606966f..74ab1bca5344 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -1,4 +1,5 @@ import os +from copy import deepcopy import pytest import torch @@ -63,7 +64,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): master2working = sharded_optimizer.get_master_to_working_map() - for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + for (name, p1), p2 in zip( + llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0] + ): working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( @@ -73,7 +76,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + if name == "embed_tokens.weight": + continue + try: + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + except Exception as e: + raise RuntimeError(f"Failed to check grad for {name}") from e # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -114,7 +122,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "LlamaModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_output_hidden_state( + org_output, + sharded_output, + stage_manager, + atol=atol, + rtol=rtol, + shard_config=booster.plugin.shard_config, + ) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -124,20 +139,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - try: - check_weight( - llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False, - ) - except Exception as e: - print(f"Failed config: {test_config}") - raise e + check_weight( + llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -160,6 +171,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 2, "precision": "fp16", "initial_scale": 1, + "parallel_output": False, }, { # Ulysess + Flash attention "tp_size": 1, @@ -173,6 +185,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, + "parallel_output": False, }, { "tp_size": 1, @@ -185,6 +198,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, + "parallel_output": False, }, { "tp_size": 4, @@ -192,10 +206,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, + "parallel_output": False, }, { "tp_size": 2, @@ -243,9 +258,14 @@ def run_llama_test(test_config): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): try: - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + config = test_config + if name == "transformers_llama_for_casual_lm": + # Test the cross entropy loss distributed along sequence + config = deepcopy(test_config) + config["parallel_output"] = True + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, config) except Exception as e: - print(f"Failed config: {test_config}") + print(f"Failed config: {test_config}, model name: {name}") raise e clear_layout_converter() From c050293809ce98cdb14c27530d4cf4639b34c7ae Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 13 Jul 2024 16:10:02 +0000 Subject: [PATCH 31/71] fwd bwd logic complete --- .../hybrid_parallel_checkpoint_io.py | 1 + colossalai/shardformer/layer/_operation.py | 10 ++++- colossalai/shardformer/layer/attn.py | 15 +++++++ colossalai/shardformer/layer/utils.py | 1 + colossalai/shardformer/modeling/llama.py | 2 +- examples/language/opt/opt_benchmark.py | 2 +- tests/kit/model_zoo/transformers/llama.py | 8 ++++ .../test_model/test_shard_llama.py | 42 ++++--------------- 8 files changed, 44 insertions(+), 37 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 043e5c2b0618..6edc89313097 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -656,6 +656,7 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. state_dict_list = [None for _ in range(self.pp_size)] dist.barrier(self.pp_group) + # torch.cuda.set_device(os.environ["LOCAL_RANK"]) dist.all_gather_object(state_dict_list, state_dict, self.pp_group) # Only the master rank do the saving. if self.coordinator.is_master(): diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index a9f1c5e2968a..b31d8a596a21 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -651,7 +651,7 @@ def backward(ctx, grad_output): ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have - # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py + # all-reduce scheduled first and have GPU resources allocated grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None @@ -1003,11 +1003,19 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) +<<<<<<< HEAD def gather_sp_output(hidden_states, sp_group, sp_mode, sp_dim=1): +======= +def gather_sp_output(hidden_states, sp_group, sp_mode): +>>>>>>> fwd bwd logic complete """ Gather the output of the last layer for cross entropy computation """ # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) +<<<<<<< HEAD hidden_states = gather_forward_split_backward(hidden_states, sp_dim, sp_group, grad_scale=scale) +======= + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale) +>>>>>>> fwd bwd logic complete return hidden_states diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index e2daa37cf9f6..9470baa6cf95 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -405,6 +405,17 @@ def flash_attn_fwd_out_corr_triton( tl.store(out_ptrs, updated_out_vals) +def flash_attn_out_lse_rescale(out, out_per_step, lse, lse_step): + """ + out: (B, Sq, H, D) + out_per_step: (B, Sq, H, D) + lse: (B, H, Sq, 1) + """ + new_lse = lse + torch.log(1 + torch.exp(lse_step - lse)) + out.copy_(torch.exp(lse - new_lse) * out + torch.exp(lse_step - new_lse) * out_per_step) + lse.copy_(new_lse) + + # Modified from Megatron-LM. TODO: try Triton def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) @@ -414,6 +425,10 @@ def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_l def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): + """ + softmax_lse: (B, H, Sq) + softmax_lse_per_step: (B, H, Sq) + """ max_scale = torch.max(softmax_lse, softmax_lse_per_step) min_scale = torch.min(softmax_lse, softmax_lse_per_step) new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 540f8e85e726..4e8cbd05e43d 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -12,6 +12,7 @@ from torch.distributed import ProcessGroup, get_world_size from colossalai.accelerator import get_accelerator +from colossalai.shardformer.layer.attn import get_pad_info class SeqParallelUtils: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 9cd9043eff9b..a2dc36453c61 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -847,7 +847,7 @@ def forward( sp_mode == "ring_attn" and use_cache ), "Ring attention requires q, k, v to have the same length and doesn't work for inference" if sp_mode == "ring_attn": - batch = ring_attn_split_forward({"labels": labels}, sp_group) + batch = ring_attn_split_batch({"labels": labels}, sp_group) labels = batch["labels"] # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index ca9b63d1a14a..7b30f1939cf0 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -96,7 +96,7 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, _, _ = booster.boost(model, optimizer) - + booster.save_model(model, "model.pt") SEQ_LEN = 1024 VOCAB_SIZE = 50257 diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index ac729cb1a3e2..943c5cf1c58e 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -98,6 +98,14 @@ def data_gen_for_causal_lm(): loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True), ) + model_zoo.register( + name="transformers_llama", + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) model_zoo.register( name="transformers_llama_for_sequence_classification", model_fn=lambda: transformers.LlamaForSequenceClassification(config), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 74ab1bca5344..d905ae68fa05 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -1,5 +1,4 @@ import os -from copy import deepcopy import pytest import torch @@ -76,8 +75,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - if name == "embed_tokens.weight": - continue try: assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) except Exception as e: @@ -130,9 +127,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, rtol=rtol, shard_config=booster.plugin.shard_config, ) - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # check weights if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): if test_config["precision"] == "fp32": @@ -152,26 +147,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grads check_all_grad_tensors(grads_to_check) - torch.cuda.empty_cache() @parameterize( "test_config", [ - { # Test ring + Flash attention - "tp_size": 2, + { + "tp_size": 1, "pp_size": 1, "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, + "sequence_parallelism_mode": "all_to_all", "use_lazy_init": True, - "zero_stage": 2, + "zero_stage": 1, "precision": "fp16", "initial_scale": 1, - "parallel_output": False, + "parallel_output": True, }, { # Ulysess + Flash attention "tp_size": 1, @@ -185,20 +178,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, - "parallel_output": False, - }, - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - "parallel_output": False, + "parallel_output": True, }, { "tp_size": 4, @@ -210,7 +190,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, - "parallel_output": False, + "parallel_output": True, }, { "tp_size": 2, @@ -255,15 +235,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): try: - config = test_config - if name == "transformers_llama_for_casual_lm": - # Test the cross entropy loss distributed along sequence - config = deepcopy(test_config) - config["parallel_output"] = True - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, config) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: print(f"Failed config: {test_config}, model name: {name}") raise e From 65b4b76162de2c2400a5df8d41b1caf8c8d67b76 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 14 Jul 2024 14:18:12 +0000 Subject: [PATCH 32/71] fwd bwd logic complete; add experimental triton rescale --- .../hybrid_parallel_checkpoint_io.py | 1 - colossalai/pipeline/schedule/one_f_one_b.py | 3 - colossalai/shardformer/layer/attn.py | 146 ++++++++++++------ colossalai/shardformer/layer/utils.py | 1 - colossalai/shardformer/modeling/command.py | 2 + colossalai/shardformer/modeling/llama.py | 12 -- colossalai/shardformer/policies/llama.py | 5 + .../test_model/test_shard_llama.py | 23 ++- 8 files changed, 121 insertions(+), 72 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 6edc89313097..043e5c2b0618 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -656,7 +656,6 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. state_dict_list = [None for _ in range(self.pp_size)] dist.barrier(self.pp_group) - # torch.cuda.set_device(os.environ["LOCAL_RANK"]) dist.all_gather_object(state_dict_list, state_dict, self.pp_group) # Only the master rank do the saving. if self.coordinator.is_master(): diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 4c8519030b1c..03df67ae78c3 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -32,7 +32,6 @@ def __init__( num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, - shard_config=None, ) -> None: """1F1B pipeline schedule. @@ -40,7 +39,6 @@ def __init__( stage_manager (PipelineStageManager): Pipeline stage manager num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. - shard_config: Shard configuration for gathering Sequence Parallel loss. """ super().__init__(stage_manager) assert ( @@ -55,7 +53,6 @@ def __init__( self.batch_size: Optional[int] = None self.last_batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None - self.shard_config = shard_config # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 9470baa6cf95..097b3a968f4e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -369,43 +369,87 @@ def _rescale_out_lse(out, block_out, lse, block_lse): @triton.jit -def flash_attn_fwd_out_corr_triton( - out_ptr, out_per_step_ptr, seq_dim, softmax_lse_ptr, softmax_lse_per_step_ptr, BLOCK_SIZE: tl.constexpr +def flash_attn_out_lse_rescale_kernel( + out_ptr, + out_per_step_ptr, + lse_ptr, + lse_step_ptr, + B, + Sq, + H, + D, + stride_out_0, + stride_out_1, + stride_out_2, + stride_out_3, + stride_out_per_step_0, + stride_out_per_step_1, + stride_out_per_step_2, + stride_out_per_step_3, + stride_lse_0, + stride_lse_1, + stride_lse_2, + stride_lse_3, ): - # Calculate the global id - pid = tl.program_id(0) - - # Offsets for the current row - offsets = tl.arange(0, BLOCK_SIZE) - - # Pointers to the current row in out and out_per_step - row_start = pid * seq_dim - out_ptrs = out_ptr + row_start + offsets - out_per_step_ptrs = out_per_step_ptr + row_start + offsets - - # Load softmax_lse and softmax_lse_per_step - softmax_lse = tl.load(softmax_lse_ptr + pid) - softmax_lse_per_step = tl.load(softmax_lse_per_step_ptr + pid) - - # Compute the corrected exponentiation - softmax_lse_corrected_exp = tl.exp(softmax_lse_per_step - softmax_lse) - - out_per_step_vals = tl.load(out_per_step_ptrs) - - # Correct the out_per_step by the exponentiation - out_corrected = out_per_step_vals * softmax_lse_corrected_exp - - # Load the current out values - out_vals = tl.load(out_ptrs) - - # Add the corrected output to out - updated_out_vals = out_vals + out_corrected - - # Store the updated out values - tl.store(out_ptrs, updated_out_vals) - - -def flash_attn_out_lse_rescale(out, out_per_step, lse, lse_step): + batch_id = tl.program_id(0) + sq_id = tl.program_id(1) + h_id = tl.program_id(2) + d_id = tl.arange(0, D) + + out_idx = batch_id * stride_out_0 + sq_id * stride_out_1 + h_id * stride_out_2 + d_id * stride_out_3 + out_per_step_idx = ( + batch_id * stride_out_per_step_0 + + sq_id * stride_out_per_step_1 + + h_id * stride_out_per_step_2 + + d_id * stride_out_per_step_3 + ) + lse_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 + lse_step_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 + + out = tl.load(out_ptr + out_idx) + out_per_step = tl.load(out_per_step_ptr + out_per_step_idx) + lse = tl.load(lse_ptr + lse_idx) + lse_step = tl.load(lse_step_ptr + lse_step_idx) + + new_lse = lse + tl.log(1 + tl.exp(lse_step - lse)) + out = tl.exp(lse - new_lse) * out + tl.exp(lse_step - new_lse) * out_per_step + + tl.store(out_ptr + out_idx, out) + tl.store(lse_ptr + lse_idx, new_lse) + + +def rescale_out_lse_triton(out, out_per_step, lse, lse_step): + B, Sq, H, D = out.shape + + assert out.is_contiguous() and out_per_step.is_contiguous() and lse.is_contiguous() and lse_step.is_contiguous() + + grid = (B, Sq, H) + + flash_attn_out_lse_rescale_kernel[grid]( + out, + out_per_step, + lse, + lse_step, + B, + Sq, + H, + D, + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + out_per_step.stride(0), + out_per_step.stride(1), + out_per_step.stride(2), + out_per_step.stride(3), + lse.stride(0), + lse.stride(1), + lse.stride(2), + lse.stride(3), + ) + + +def rescale_out_lse(out, out_per_step, lse, lse_step): """ out: (B, Sq, H, D) out_per_step: (B, Sq, H, D) @@ -416,23 +460,23 @@ def flash_attn_out_lse_rescale(out, out_per_step, lse, lse_step): lse.copy_(new_lse) -# Modified from Megatron-LM. TODO: try Triton -def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): - softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) - softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) - out_corrected = out_per_step * softmax_lse_corrected_exp - out.add_(out_corrected) +# From Megatron-LM. TODO: try Triton +# def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): +# softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) +# softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) +# out_corrected = out_per_step * softmax_lse_corrected_exp +# out.add_(out_corrected) -def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): - """ - softmax_lse: (B, H, Sq) - softmax_lse_per_step: (B, H, Sq) - """ - max_scale = torch.max(softmax_lse, softmax_lse_per_step) - min_scale = torch.min(softmax_lse, softmax_lse_per_step) - new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) - softmax_lse.copy_(new_scale) +# def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): +# """ +# softmax_lse: (B, H, Sq) +# softmax_lse_per_step: (B, H, Sq) +# """ +# max_scale = torch.max(softmax_lse, softmax_lse_per_step) +# min_scale = torch.min(softmax_lse, softmax_lse_per_step) +# new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) +# softmax_lse.copy_(new_scale) class RingAttention(torch.autograd.Function): diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 4e8cbd05e43d..540f8e85e726 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -12,7 +12,6 @@ from torch.distributed import ProcessGroup, get_world_size from colossalai.accelerator import get_accelerator -from colossalai.shardformer.layer.attn import get_pad_info class SeqParallelUtils: diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 62299bf0656e..cac325dcbea6 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -25,6 +25,8 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"] +_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] + class CommandPipelineForwards: """ diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a2dc36453c61..219933c705e9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -838,18 +838,6 @@ def forward( # [B, max_seq_len // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) - sp_mode = shard_config.sequence_parallelism_mode - sp_group = shard_config.sequence_parallel_process_group - is_sp = shard_config.enable_sequence_parallelism - # Split labels - if is_sp: - assert not ( - sp_mode == "ring_attn" and use_cache - ), "Ring attention requires q, k, v to have the same length and doesn't work for inference" - if sp_mode == "ring_attn": - batch = ring_attn_split_batch({"labels": labels}, sp_group) - labels = batch["labels"] - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ce1e0e3de35d..60dbe2630829 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -85,6 +85,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: num_q_heads = self.model.config.num_attention_heads num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) + tp_size = self.shard_config.tensor_parallel_size + # Modified by SP and TP + num_q_heads = self.model.config.num_attention_heads + num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) + if sp_mode == "all_to_all": num_q_heads //= sp_size decoder_attribute_replacement = {"num_heads": num_q_heads} diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index d905ae68fa05..8f3065ae8d7a 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -153,27 +153,42 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + # Zigzag Ring Attention { - "tp_size": 1, + "tp_size": 2, "pp_size": 1, "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", + "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, "parallel_output": True, }, - { # Ulysess + Flash attention + { # Ulysess + TP + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + "parallel_output": True, + }, + { # Ulysess + PP "tp_size": 1, "pp_size": 2, "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, + "enable_all_optimization": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", From 5824ede11ec119707e7c9c691ae7061390a98196 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 18 Jul 2024 07:17:08 +0000 Subject: [PATCH 33/71] precision tests passed --- colossalai/shardformer/layer/_operation.py | 4 - colossalai/shardformer/layer/attn.py | 67 ++++--- colossalai/shardformer/layer/loss.py | 29 +++ colossalai/shardformer/layer/utils.py | 186 ++++++++++++++++++ colossalai/shardformer/modeling/llama.py | 48 +++++ colossalai/shardformer/policies/command.py | 4 + colossalai/shardformer/policies/llama.py | 3 + examples/language/opt/opt_benchmark.py | 1 - .../test_model/test_shard_llama.py | 26 ++- 9 files changed, 336 insertions(+), 32 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index b31d8a596a21..c6f61d3bb99f 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -812,11 +812,7 @@ def backward(ctx, *grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim - if torch.distributed.get_rank() == 0: - print(f"shape before A2A: {grad_output[0].shape}") return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) - if torch.distributed.get_rank() == 0: - print(f"shape after A2A: {return_grad.shape}") return (return_grad, None, None, None) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 097b3a968f4e..37d714946bc3 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -368,16 +368,17 @@ def _rescale_out_lse(out, block_out, lse, block_lse): return out, lse +def _not_nan(x): + return not (x.isnan().any() or x.isinf().any()) + + @triton.jit -def flash_attn_out_lse_rescale_kernel( +def _rescale_out_lse_kernel( out_ptr, out_per_step_ptr, lse_ptr, lse_step_ptr, - B, - Sq, - H, - D, + D, # Each thread handles D elements stride_out_0, stride_out_1, stride_out_2, @@ -390,6 +391,7 @@ def flash_attn_out_lse_rescale_kernel( stride_lse_1, stride_lse_2, stride_lse_3, + BLOCK_M: tl.constexpr, ): batch_id = tl.program_id(0) sq_id = tl.program_id(1) @@ -406,11 +408,13 @@ def flash_attn_out_lse_rescale_kernel( lse_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 lse_step_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 + # Load inputs out = tl.load(out_ptr + out_idx) out_per_step = tl.load(out_per_step_ptr + out_per_step_idx) lse = tl.load(lse_ptr + lse_idx) lse_step = tl.load(lse_step_ptr + lse_step_idx) + # Element-wise rescale new_lse = lse + tl.log(1 + tl.exp(lse_step - lse)) out = tl.exp(lse - new_lse) * out + tl.exp(lse_step - new_lse) * out_per_step @@ -418,18 +422,18 @@ def flash_attn_out_lse_rescale_kernel( tl.store(lse_ptr + lse_idx, new_lse) -def rescale_out_lse_triton(out, out_per_step, lse, lse_step): +def _rescale_out_lse_triton(out, block_out, lse, block_lse): B, Sq, H, D = out.shape - assert out.is_contiguous() and out_per_step.is_contiguous() and lse.is_contiguous() and lse_step.is_contiguous() + assert out.is_contiguous() and block_out.is_contiguous() and lse.is_contiguous() and block_lse.is_contiguous() - grid = (B, Sq, H) - - flash_attn_out_lse_rescale_kernel[grid]( + # TODO: use 1d kernel? + grid = lambda META: (triton.cdiv(Sq, META["BLOCK_M"]), B, H) + _rescale_out_lse_kernel[grid]( out, - out_per_step, + block_out, lse, - lse_step, + block_lse, B, Sq, H, @@ -438,10 +442,10 @@ def rescale_out_lse_triton(out, out_per_step, lse, lse_step): out.stride(1), out.stride(2), out.stride(3), - out_per_step.stride(0), - out_per_step.stride(1), - out_per_step.stride(2), - out_per_step.stride(3), + block_out.stride(0), + block_out.stride(1), + block_out.stride(2), + block_out.stride(3), lse.stride(0), lse.stride(1), lse.stride(2), @@ -449,16 +453,35 @@ def rescale_out_lse_triton(out, out_per_step, lse, lse_step): ) -def rescale_out_lse(out, out_per_step, lse, lse_step): +def _rescale_out_lse(out, block_out, lse, block_lse): """ - out: (B, Sq, H, D) - out_per_step: (B, Sq, H, D) - lse: (B, H, Sq, 1) + Compute the new attention denominator: + exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1) + Args: + out: (B, Sq, H, D) + block_out: (B, Sq, H, D) + lse: (B, H, Sq, 1) + block_lse: (B, H, Sq, 1) """ - new_lse = lse + torch.log(1 + torch.exp(lse_step - lse)) - out.copy_(torch.exp(lse - new_lse) * out + torch.exp(lse_step - new_lse) * out_per_step) + + # min_scale = torch.min(lse, block_lse) + # max_scale = torch.max(lse, block_lse) + # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + new_block_lse = torch.exp(block_lse - new_lse) + assert _not_nan(new_lse), new_lse + # dist.barrier() + assert _not_nan(new_block_lse), new_block_lse + + out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) lse.copy_(new_lse) + # block_out = block_out.float() + # out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out)) + # lse.copy_(lse - F.logsigmoid(lse - block_lse)) + # assert not lse.isnan().any(), lse + # assert not out.isnan().any(), out + # From Megatron-LM. TODO: try Triton # def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 8ea956004f19..3740265fc437 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -151,7 +151,11 @@ def cross_entropy_1d( def dist_cross_entropy( +<<<<<<< HEAD labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] +======= + labels: torch.Tensor, # [B, S] +>>>>>>> precision tests passed logits: torch.Tensor, # [B, S, Vocab_size] shard_config: ShardConfig, out_features: int, @@ -178,6 +182,7 @@ def dist_cross_entropy( logits = logits.reshape(-1, *logits.shape[2:]) seq_dim = 0 +<<<<<<< HEAD # Shift labels to predict the next token, and remove the tail logit predicting is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward @@ -195,6 +200,23 @@ def dist_cross_entropy( if split_labels_here: labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] +======= + bs, seq_len = labels.shape + + # Shift labels to predict the next token, and remove the tail logit predicting + is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) + split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward + if is_sp: + # Just don't shift twice + if split_labels_here or sp_rank == sp_size - 1: + labels = labels[..., 1:] + + # Split labels when logits are split + if split_labels_here: + labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] + + # The rank holding the last seq chunk +>>>>>>> precision tests passed if sp_rank == sp_size - 1: logits = logits[..., :-1, :] # Pad logits and labels to the same shape across all ranks for TP all_reduce @@ -209,6 +231,7 @@ def dist_cross_entropy( labels = torch.cat([labels, padding], dim=seq_dim) else: labels = labels[..., 1:] +<<<<<<< HEAD logits = logits[..., :-1, :] labels = labels.contiguous() logits = logits.contiguous() @@ -219,6 +242,12 @@ def dist_cross_entropy( ), f"label shape {labels.shape} does not match logit shape {logits.shape}" except Exception as e: raise e +======= + logits = logits[..., :-1, :].contiguous() + labels = labels.contiguous() + num_nonzero = (labels != _IGNORE_IDX).sum() + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" +>>>>>>> precision tests passed # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 540f8e85e726..ed6951afca7e 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,9 +1,13 @@ from contextlib import contextmanager <<<<<<< HEAD +<<<<<<< HEAD from typing import List, Optional, Union ======= from typing import Dict, List >>>>>>> add basic ring attn; debug cross entropy +======= +from typing import List +>>>>>>> precision tests passed import torch import torch.distributed as dist @@ -295,9 +299,13 @@ def create_randomizer_with_offset( return Randomizer(seed=base_seed) +<<<<<<< HEAD def split_batch_zigzag( batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False ) -> Union[torch.Tensor, List[torch.Tensor]]: +======= +def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen: bool = False): +>>>>>>> precision tests passed """ Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask in the causal setting will result in the preceding ranks having much less workload. @@ -305,6 +313,7 @@ def split_batch_zigzag( For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |. Args: +<<<<<<< HEAD batch (List[torch.Tensor] or Tensor): The input tensor(s) to split. sp_group (ProcessGroup): The process group for sequence parallelism. seq_dim (int): The sequence dimension to split. @@ -326,6 +335,22 @@ def split_batch_zigzag( assert tensor.dim() == 2, "Label shape should be (B, Seqlen)" tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1) +======= + batch (List[torch.Tensor]): The input tensors to split. + sp_group (ProcessGroup): The process group for sequence parallelism. + varlen (bool): If the input is padded (aka "packing" mode), such that + sequences in a batch have different lengths, and we need to unpad and + split each sequence evenly by sp_size. + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + seq_dim = 1 + if sp_size > 1: + for idx, tensor in enumerate(batch): + assert ( + tensor.numel() // (sp_size * 2) > 1 + ), f"Bro, the seq length for tensor {idx} in batch is too short to split!" +>>>>>>> precision tests passed tensor = tensor.view( *tensor.shape[:seq_dim], 2 * sp_size, @@ -336,6 +361,7 @@ def split_batch_zigzag( tensor = tensor.index_select(seq_dim, indices).contiguous() # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...) batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]) +<<<<<<< HEAD if len(batch) == 1: return batch[0] @@ -420,6 +446,8 @@ def split_varlen_zigzag( batch[i] = local_seq.contiguous() else: batch[i] = torch.cat(local_seq, dim=0) +======= +>>>>>>> precision tests passed if len(batch) == 1: batch = batch[0] @@ -434,6 +462,7 @@ def is_share_sp_tp(sp_mode: str): return sp_mode in ["ring", "split_gather"] +<<<<<<< HEAD class RingComm: def __init__(self, process_group: dist.ProcessGroup): self._process_group = process_group @@ -492,3 +521,160 @@ def get_half_index(cu_seqlens, *, front: bool): start = (start + end) // 2 index[start:end] = True return index +======= +# Copied from https://github.com/zhuzilin/ring-flash-attention/tree/main/ring_flash_attn +# Use Triton kernel if installed else use torch +try: + import triton + import triton.language as tl + + @triton.jit + def flatten_kernel( + # pointers to matrices + OUT, + LSE, + CU_SEQLENS, + # strides + stride_out_nheads, + stride_out_seqlen, + stride_lse_batch, + stride_lse_nheads, + stride_lse_seqlen, + # meta-parameters + BLOCK_M: tl.constexpr, + ): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads + OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + LSE = LSE + rm[:, None] * stride_lse_seqlen + x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) + + OUT = OUT + rm[:, None] * stride_out_seqlen + tl.store(OUT, x, mask=rm[:, None] < seqlen) + + def flatten_varlen_lse(lse, cu_seqlens): + """ + Arguments: + lse: (batch_size, nheads, max_seqlen) + cu_seqlens: (batch_size + 1,) + Return: + flatten_lse: (nheads, total_seqlen) + """ + total_seqlen = cu_seqlens[-1] + batch_size, nheads, max_seqlen = lse.shape + output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) + + grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) + BLOCK_M = 4 + + with torch.cuda.device(lse.device.index): + flatten_kernel[grid]( + output, + lse, + cu_seqlens, + # strides + output.stride(0), + output.stride(1), + lse.stride(0), + lse.stride(1), + lse.stride(2), + BLOCK_M, + ) + return output + + @triton.jit + def unflatten_kernel( + # pointers to matrices + OUT, + LSE, + CU_SEQLENS, + # strides + stride_out_batch, + stride_out_nheads, + stride_out_seqlen, + stride_lse_seqlen, + stride_lse_nheads, + # meta-parameters + BLOCK_M: tl.constexpr, + ): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + LSE = LSE + rm[:, None] * stride_lse_seqlen + x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) + + OUT = OUT + rm[:, None] * stride_out_seqlen + tl.store(OUT, x, mask=rm[:, None] < seqlen) + + def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + """ + Arguments: + lse: (total_seqlen, nheads, 1) + cu_seqlens: (batch_size + 1,) + max_seqlen: int + Return: + unflatten_lse: (batch_size, nheads, max_seqlen) + """ + lse = lse.unsqueeze(dim=-1) + batch_size = len(cu_seqlens) - 1 + nheads = lse.shape[1] + output = torch.empty( + (batch_size, nheads, max_seqlen), + dtype=lse.dtype, + device=lse.device, + ) + + grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) + BLOCK_M = 4 + + with torch.cuda.device(lse.device.index): + unflatten_kernel[grid]( + output, + lse, + cu_seqlens, + # strides + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_M, + ) + return output + +except: + # Triton not installed, use torch instead + @torch.jit.script + def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + @torch.jit.script + def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() +>>>>>>> precision tests passed diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 219933c705e9..0b47e0632387 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -313,6 +313,7 @@ def llama_for_causal_lm_forward( logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False +<<<<<<< HEAD if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: # Split labels in a zigzag fashion too sp_group = shard_config.sequence_parallel_process_group @@ -321,6 +322,11 @@ def llama_for_causal_lm_forward( else: # [B, max_seqlen // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) +======= + if stage_manager.is_first_stage(): + if shard_config.sequence_parallelism_mode == "ring_attn": + labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) +>>>>>>> precision tests passed # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( @@ -569,8 +575,14 @@ def forward( key_states, value_states, sp_group, +<<<<<<< HEAD **attention_mask, inner_ring_size=shard_config.inner_ring_size, +======= + shard_config.sp_stream, + attention_mask["attention_mask"], + attention_mask["attention_mask_type"], +>>>>>>> precision tests passed ) elif shard_config.enable_flash_attention: @@ -682,8 +694,13 @@ def forward( position_ids = cache_position.unsqueeze(0) if shard_config.enable_flash_attention: +<<<<<<< HEAD mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len) attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( +======= + mask_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) + attn_mask: dict = ColoAttention.prepare_attn_kwargs( +>>>>>>> precision tests passed mask_shape, inputs_embeds.dtype, inputs_embeds.device, @@ -693,11 +710,16 @@ def forward( ) else: +<<<<<<< HEAD attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) +======= + attn_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) +>>>>>>> precision tests passed # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." +<<<<<<< HEAD if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, inputs_embeds, position_ids @@ -707,6 +729,20 @@ def forward( attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors elif is_share_sp_tp(sp_mode): +======= + if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info( + attn_mask["attention_mask"].squeeze(1).any(dim=-1) + ) # [B, 1, Sq, Skv] -> [B, Sq] + + else: + attn_mask["cu_seqlens"] = attn_mask["max_seqlen"] = attn_mask["indices"] = None + batch = [inputs_embeds, position_ids] + # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) + inputs_embeds, position_ids = zigzag_split_batch(batch, sp_group) + + elif sp_mode in ["ring", "split_gather"]: +>>>>>>> precision tests passed inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) @@ -724,7 +760,11 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, +<<<<<<< HEAD attn_kwargs, +======= + attn_mask, +>>>>>>> precision tests passed position_ids, past_key_values, output_attentions, @@ -735,7 +775,11 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, +<<<<<<< HEAD attention_mask=attn_kwargs, +======= + attention_mask=attn_mask, +>>>>>>> precision tests passed position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -827,6 +871,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if shard_config.sequence_parallelism_mode == "ring_attn": +<<<<<<< HEAD labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: @@ -837,6 +882,9 @@ def forward( else: # [B, max_seq_len // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) +======= + labels = zigzag_split_batch([labels], shard_config.sequence_parallel_process_group)[0] +>>>>>>> precision tests passed # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 1efd3d0179af..fff026e39872 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -292,7 +292,11 @@ class CommandForCausalLMPolicy(CommandPolicy): def module_policy(self): from transformers import CohereForCausalLM +<<<<<<< HEAD self.is_causal = True +======= + self.is_casual = True +>>>>>>> precision tests passed policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 60dbe2630829..5483719aa212 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -71,6 +71,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_partial_derived = sp_mode in ["split_gather", "ring"] if sp_mode == "ring_attn" and not self.is_causal: raise ValueError("Ring attention is only meant for causal language modeling.") +<<<<<<< HEAD <<<<<<< HEAD tp_size = self.shard_config.tensor_parallel_size @@ -84,6 +85,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # Modified by SP and TP num_q_heads = self.model.config.num_attention_heads num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) +======= +>>>>>>> precision tests passed tp_size = self.shard_config.tensor_parallel_size # Modified by SP and TP diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 7b30f1939cf0..5e5971d9f560 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -96,7 +96,6 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, _, _ = booster.boost(model, optimizer) - booster.save_model(model, "model.pt") SEQ_LEN = 1024 VOCAB_SIZE = 50257 diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 8f3065ae8d7a..06ed5f22405f 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -163,9 +163,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, "zero_stage": 1, - "precision": "fp16", + "precision": "bf16", "initial_scale": 1, - "parallel_output": True, }, { # Ulysess + TP "tp_size": 2, @@ -179,7 +178,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 0, "precision": "fp16", "initial_scale": 1, - "parallel_output": True, }, { # Ulysess + PP "tp_size": 1, @@ -190,10 +188,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "sequence_parallelism_mode": "all_to_all", "enable_all_optimization": True, "use_lazy_init": True, +<<<<<<< HEAD "zero_stage": 1, +======= + "zero_stage": 0, "precision": "fp16", "initial_scale": 1, - "parallel_output": True, }, { "tp_size": 4, @@ -203,9 +203,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "sequence_parallelism_mode": "split_gather", "enable_flash_attention": True, "use_lazy_init": True, +>>>>>>> precision tests passed + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, "precision": "fp16", "initial_scale": 1, - "parallel_output": True, }, { "tp_size": 2, @@ -251,6 +264,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: + continue + try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: From 0e72997701f3d6e724bc9ac64b0bb64cfa491da3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 21 Jul 2024 14:32:49 +0000 Subject: [PATCH 34/71] precision tests passed --- .../pipeline/schedule/interleaved_pp.py | 2 + colossalai/shardformer/layer/attn.py | 29 +- colossalai/shardformer/layer/loss.py | 23 +- colossalai/shardformer/layer/utils.py | 68 ++- colossalai/shardformer/modeling/llama.py | 39 +- colossalai/shardformer/policies/command.py | 4 + examples/language/llama/benchmark.py | 3 + .../benchmark/benchmark_qkvpacked_func.py | 87 ++++ .../benchmark_varlen_qkvpacked_func.py | 91 ++++ .../ring_flash_attn/__init__.py | 16 + .../ring_flash_attn/ring_flash_attn.py | 281 +++++++++++ .../ring_flash_attn/ring_flash_attn_varlen.py | 318 +++++++++++++ .../ring_flash_attn/stripe_flash_attn.py | 325 +++++++++++++ .../ring_flash_attn/triton_utils.py | 137 ++++++ ring-flash-attention/ring_flash_attn/utils.py | 110 +++++ .../ring_flash_attn/zigzag_ring_flash_attn.py | 327 +++++++++++++ .../zigzag_ring_flash_attn_varlen.py | 441 ++++++++++++++++++ ring-flash-attention/setup.py | 9 + .../test/test_ring_flash_attn_func.py | 124 +++++ .../test/test_ring_flash_attn_varlen_func.py | 157 +++++++ .../test/test_stripe_flash_attn_func.py | 130 ++++++ .../test/test_triton_kernels.py | 30 ++ .../test/test_zigzag_ring_flash_attn_func.py | 150 ++++++ ...test_zigzag_ring_flash_attn_varlen_func.py | 163 +++++++ tests/kit/model_zoo/transformers/llama.py | 3 + .../test_layer/test_ring_attn.py | 66 +++ .../test_model/test_shard_llama.py | 24 +- 27 files changed, 3127 insertions(+), 30 deletions(-) create mode 100644 ring-flash-attention/benchmark/benchmark_qkvpacked_func.py create mode 100644 ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py create mode 100644 ring-flash-attention/ring_flash_attn/__init__.py create mode 100644 ring-flash-attention/ring_flash_attn/ring_flash_attn.py create mode 100644 ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py create mode 100644 ring-flash-attention/ring_flash_attn/stripe_flash_attn.py create mode 100644 ring-flash-attention/ring_flash_attn/triton_utils.py create mode 100644 ring-flash-attention/ring_flash_attn/utils.py create mode 100644 ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py create mode 100644 ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py create mode 100644 ring-flash-attention/setup.py create mode 100644 ring-flash-attention/test/test_ring_flash_attn_func.py create mode 100644 ring-flash-attention/test/test_ring_flash_attn_varlen_func.py create mode 100644 ring-flash-attention/test/test_stripe_flash_attn_func.py create mode 100644 ring-flash-attention/test/test_triton_kernels.py create mode 100644 ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py create mode 100644 ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 412f3896fb80..8f26f8cb5bb5 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -283,6 +283,8 @@ def forward_step( # Load input ids, attention mask and labels micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + if input_obj is not None: + assert all(not x.isnan().any() for x in input_obj.values()), "NaN detected in input_obj" # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous stage containing hidden_states etc. # Only attention_mask from micro_batch is used diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 37d714946bc3..1cba15cb2c6c 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -18,6 +18,8 @@ from .utils import RingComm, get_half_index, split_varlen_zigzag +from .utils import RingComm + __all__ = [ "AttnMaskType", "ColoAttention", @@ -467,14 +469,14 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # min_scale = torch.min(lse, block_lse) # max_scale = torch.max(lse, block_lse) # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) new_block_lse = torch.exp(block_lse - new_lse) - assert _not_nan(new_lse), new_lse - # dist.barrier() - assert _not_nan(new_block_lse), new_block_lse - out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) lse.copy_(new_lse) + assert _not_nan(new_lse), new_lse + assert _not_nan(new_block_lse), new_block_lse + assert _not_nan(out), out # block_out = block_out.float() # out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out)) @@ -483,23 +485,8 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # assert not out.isnan().any(), out -# From Megatron-LM. TODO: try Triton -# def flash_attn_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step): -# softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) -# softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) -# out_corrected = out_per_step * softmax_lse_corrected_exp -# out.add_(out_corrected) - - -# def flash_attn_softmax_lse_correction(softmax_lse, softmax_lse_per_step): -# """ -# softmax_lse: (B, H, Sq) -# softmax_lse_per_step: (B, H, Sq) -# """ -# max_scale = torch.max(softmax_lse, softmax_lse_per_step) -# min_scale = torch.min(softmax_lse, softmax_lse_per_step) -# new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) -# softmax_lse.copy_(new_scale) +def _not_nan(x): + return not (x.isnan().any() or x.isinf().any()) class RingAttention(torch.autograd.Function): diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 3740265fc437..01d914e297fe 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -173,6 +173,7 @@ def dist_cross_entropy( sp_mode = shard_config.sequence_parallelism_mode parallel_output = shard_config.parallel_output is_tp = shard_config.enable_tensor_parallelism +<<<<<<< HEAD is_packed = labels.dim() == 2 if is_packed: bs, seq_len = labels.shape @@ -199,6 +200,8 @@ def dist_cross_entropy( labels = labels[..., 1:] if split_labels_here: labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] +======= +>>>>>>> precision tests passed ======= bs, seq_len = labels.shape @@ -207,14 +210,14 @@ def dist_cross_entropy( is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward if is_sp: - # Just don't shift twice - if split_labels_here or sp_rank == sp_size - 1: + # shift only once + if split_labels_here or (sp_rank == sp_size - 1): labels = labels[..., 1:] - # Split labels when logits are split if split_labels_here: labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] +<<<<<<< HEAD # The rank holding the last seq chunk >>>>>>> precision tests passed if sp_rank == sp_size - 1: @@ -244,7 +247,21 @@ def dist_cross_entropy( raise e ======= logits = logits[..., :-1, :].contiguous() +======= + # Pad to the same shape across all ranks in TP all_reduce + if sp_rank == sp_size - 1: + logits = logits[..., :-1, :] + if is_tp and parallel_output: + pad_shape = [0] * logits.dim() * 2 + pad_shape[-3] = 1 # Right side, dim = -2 + logits = F.pad(logits, pad_shape, value=_IGNORE_IDX) + labels = F.pad(labels, (0, 1, 0, 0), value=_IGNORE_IDX) + else: + labels = labels[..., 1:] + logits = logits[..., :-1, :] +>>>>>>> precision tests passed labels = labels.contiguous() + logits = logits.contiguous() num_nonzero = (labels != _IGNORE_IDX).sum() assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" >>>>>>> precision tests passed diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index ed6951afca7e..ce423d1dd00c 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,6 +1,7 @@ from contextlib import contextmanager <<<<<<< HEAD <<<<<<< HEAD +<<<<<<< HEAD from typing import List, Optional, Union ======= from typing import Dict, List @@ -8,6 +9,9 @@ ======= from typing import List >>>>>>> precision tests passed +======= +from typing import List, Optional, Union +>>>>>>> precision tests passed import torch import torch.distributed as dist @@ -299,12 +303,18 @@ def create_randomizer_with_offset( return Randomizer(seed=base_seed) +<<<<<<< HEAD <<<<<<< HEAD def split_batch_zigzag( batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False ) -> Union[torch.Tensor, List[torch.Tensor]]: ======= def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen: bool = False): +>>>>>>> precision tests passed +======= +def zigzag_split_batch( + batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False +): >>>>>>> precision tests passed """ Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask @@ -313,6 +323,7 @@ def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |. Args: +<<<<<<< HEAD <<<<<<< HEAD batch (List[torch.Tensor] or Tensor): The input tensor(s) to split. sp_group (ProcessGroup): The process group for sequence parallelism. @@ -337,19 +348,30 @@ def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen ======= batch (List[torch.Tensor]): The input tensors to split. +======= + batch (List[torch.Tensor] or Tensor): The input tensor(s) to split. +>>>>>>> precision tests passed sp_group (ProcessGroup): The process group for sequence parallelism. + seq_dim (int): The sequence dimension to split. varlen (bool): If the input is padded (aka "packing" mode), such that sequences in a batch have different lengths, and we need to unpad and split each sequence evenly by sp_size. """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) - seq_dim = 1 + if isinstance(batch, torch.Tensor): + batch = [batch] + seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1 + if sp_size > 1: for idx, tensor in enumerate(batch): assert ( tensor.numel() // (sp_size * 2) > 1 ), f"Bro, the seq length for tensor {idx} in batch is too short to split!" +<<<<<<< HEAD +>>>>>>> precision tests passed +======= + >>>>>>> precision tests passed tensor = tensor.view( *tensor.shape[:seq_dim], @@ -360,6 +382,7 @@ def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device) tensor = tensor.index_select(seq_dim, indices).contiguous() # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...) +<<<<<<< HEAD batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]) <<<<<<< HEAD @@ -451,9 +474,52 @@ def split_varlen_zigzag( if len(batch) == 1: batch = batch[0] +======= + batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]).contiguous() + + if len(batch) == 1: + return batch[0] +>>>>>>> precision tests passed return batch +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = [] + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv(self, send_tensor: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(send_tensor) + else: + res = recv_tensor + + # NOTE: looks like batch_isend_irecv doesn't deadlock even + # when we never swap send recv ops across ranks + send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + self._reqs = dist.batch_isend_irecv(self._ops) + return res + + def wait(self): + for req in self._reqs: + req.wait() + self._reqs = [] + self._ops = [] + + def is_share_sp_tp(sp_mode: str): """sp_mode "ring" and "split_gather" use the TP group as SP group to split both the vocab and sequence, so we must gather the sequence diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 0b47e0632387..801e1c91bef4 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -137,7 +137,11 @@ def llama_model_forward( elif shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) +<<<<<<< HEAD attn_kwargs = ColoAttention.prepare_attn_kwargs( +======= + attn_mask = ColoAttention.prepare_attn_kwargs( +>>>>>>> precision tests passed mask_shape, hidden_states.dtype, hidden_states.device, @@ -146,7 +150,11 @@ def llama_model_forward( invert=(sp_mode != "ring_attn"), ) else: +<<<<<<< HEAD attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position) +======= + attn_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) +>>>>>>> precision tests passed # Support SP + PP # TODO: support padded casual cu_seqlens across stages @@ -154,12 +162,24 @@ def llama_model_forward( # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." +<<<<<<< HEAD if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, hidden_states, position_ids ) else: hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) +======= + if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info( + attn_mask["attention_mask"].squeeze(1).any(dim=-1) + ) # [B, 1, Sq, Skv] -> [B, Sq] + else: + attn_mask["cu_seqlens"] = attn_mask["max_seqlen"] = attn_mask["indices"] = None + batch = [hidden_states, position_ids] + # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) + hidden_states, position_ids = zigzag_split_batch(batch, sp_group) +>>>>>>> precision tests passed elif is_share_sp_tp(sp_mode): hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) @@ -200,7 +220,11 @@ def llama_model_forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, +<<<<<<< HEAD attn_kwargs, +======= + attn_mask, +>>>>>>> precision tests passed position_ids, past_key_values, output_attentions, @@ -210,7 +234,11 @@ def llama_model_forward( else: layer_outputs = decoder_layer( hidden_states, +<<<<<<< HEAD attention_mask=attn_kwargs, +======= + attention_mask=attn_mask, +>>>>>>> precision tests passed position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -568,7 +596,9 @@ def forward( # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + assert not self.q_proj.weight.isnan().any(), self.q_proj.weight + assert not query_states.isnan().any(), query_states if sp_mode == "ring_attn": attn_output = RingAttention.attention( query_states, @@ -580,7 +610,6 @@ def forward( inner_ring_size=shard_config.inner_ring_size, ======= shard_config.sp_stream, - attention_mask["attention_mask"], attention_mask["attention_mask_type"], >>>>>>> precision tests passed ) @@ -741,7 +770,11 @@ def forward( # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) inputs_embeds, position_ids = zigzag_split_batch(batch, sp_group) +<<<<<<< HEAD elif sp_mode in ["ring", "split_gather"]: +>>>>>>> precision tests passed +======= + elif is_share_sp_tp(sp_mode): >>>>>>> precision tests passed inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": @@ -871,6 +904,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if shard_config.sequence_parallelism_mode == "ring_attn": +<<<<<<< HEAD <<<<<<< HEAD labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) @@ -885,6 +919,9 @@ def forward( ======= labels = zigzag_split_batch([labels], shard_config.sequence_parallel_process_group)[0] >>>>>>> precision tests passed +======= + labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) +>>>>>>> precision tests passed # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index fff026e39872..23e31f2e5dac 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -292,10 +292,14 @@ class CommandForCausalLMPolicy(CommandPolicy): def module_policy(self): from transformers import CohereForCausalLM +<<<<<<< HEAD <<<<<<< HEAD self.is_causal = True ======= self.is_casual = True +>>>>>>> precision tests passed +======= + self.is_causal = True >>>>>>> precision tests passed policy = super().module_policy() diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 093377e7a034..cbf497c1f8c5 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -332,8 +332,11 @@ def empty_init(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] +<<<<<<< HEAD del outputs # free memory +======= +>>>>>>> precision tests passed if dist.get_rank() == dist.get_world_size() - 1: print(f"Step {step} loss: {loss}") booster.backward(loss, optimizer) diff --git a/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py new file mode 100644 index 000000000000..a6742e04a696 --- /dev/null +++ b/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py @@ -0,0 +1,87 @@ +import torch +import torch.cuda +import torch.distributed as dist +from flash_attn import flash_attn_qkvpacked_func +from ring_flash_attn import ( + ring_flash_attn_qkvpacked_func, + stripe_flash_attn_qkvpacked_func, + zigzag_ring_flash_attn_qkvpacked_func, +) + + +def benchmark(f, num_iter=100, forward_only=True, log=True): + dtype = torch.bfloat16 + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + batch_size = 1 + seqlen = 1024 * 8 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert seqlen % (2 * world_size) == 0 + assert d % 8 == 0 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + + begin = torch.cuda.Event(enable_timing=True) + begin.record() + + if forward_only: + with torch.no_grad(): + for _ in range(num_iter): + _ = f( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + ) + + else: + for _ in range(num_iter): + qkv.grad = None + out = f( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + ) + out.backward(dout) + end = torch.cuda.Event(enable_timing=True) + end.record() + torch.cuda.synchronize(device=device) + time = begin.elapsed_time(end) / 1000.0 + + if rank == 0 and log: + print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec") + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + + forward_only = False + + for f in [ + flash_attn_qkvpacked_func, + ring_flash_attn_qkvpacked_func, + zigzag_ring_flash_attn_qkvpacked_func, + stripe_flash_attn_qkvpacked_func, + ]: + torch.cuda.empty_cache() + if rank == 0: + print(f"# {f.__name__}") + benchmark(f, forward_only=forward_only, log=False) + benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py new file mode 100644 index 000000000000..18c8cafc0078 --- /dev/null +++ b/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py @@ -0,0 +1,91 @@ +import torch +import torch.cuda +import torch.distributed as dist +from flash_attn import flash_attn_varlen_qkvpacked_func +from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func, zigzag_ring_flash_attn_varlen_qkvpacked_func + + +def benchmark(f, num_iter=100, forward_only=True, log=True): + dtype = torch.bfloat16 + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + seqlen = 1024 * 8 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert seqlen % (2 * world_size) == 0 + assert d % 8 == 0 + + qkv = torch.randn(seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dout = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) + + cu_seqlens_list = [ + torch.tensor([0, 8192], device=device, dtype=torch.int32), + torch.tensor([0, 256, 7648, 8192], device=device, dtype=torch.int32), + torch.tensor([0, 4096, 8192], device=device, dtype=torch.int32), + torch.tensor([0, 3104, 6304, 7904, 8064, 8192], device=device, dtype=torch.int32), + ] + max_seqlen_list = [(cu_seqlens[1:] - cu_seqlens[:1]).max().item() for cu_seqlens in cu_seqlens_list] + + begin = torch.cuda.Event(enable_timing=True) + begin.record() + if forward_only: + with torch.no_grad(): + for i in range(num_iter): + _ = f( + qkv, + cu_seqlens_list[i % len(cu_seqlens_list)], + max_seqlen_list[i % len(max_seqlen_list)], + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + ) + else: + for i in range(num_iter): + qkv.grad = None + out = f( + qkv, + cu_seqlens_list[i % len(cu_seqlens_list)], + max_seqlen_list[i % len(max_seqlen_list)], + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + ) + out.backward(dout) + end = torch.cuda.Event(enable_timing=True) + end.record() + torch.cuda.synchronize(device=device) + time = begin.elapsed_time(end) / 1000.0 + + if rank == 0 and log: + print(f"{num_iter / time} iter/s, {time} sec") + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + + forward_only = False + + for f in [ + flash_attn_varlen_qkvpacked_func, + ring_flash_attn_varlen_qkvpacked_func, + zigzag_ring_flash_attn_varlen_qkvpacked_func, + ]: + torch.cuda.empty_cache() + if rank == 0: + print(f"# {f.__name__}") + benchmark(f, forward_only=forward_only, log=False) + benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/ring_flash_attn/__init__.py b/ring-flash-attention/ring_flash_attn/__init__.py new file mode 100644 index 000000000000..01d5ec36218c --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/__init__.py @@ -0,0 +1,16 @@ +from .ring_flash_attn import ring_flash_attn_func, ring_flash_attn_kvpacked_func, ring_flash_attn_qkvpacked_func +from .ring_flash_attn_varlen import ( + ring_flash_attn_varlen_func, + ring_flash_attn_varlen_kvpacked_func, + ring_flash_attn_varlen_qkvpacked_func, +) +from .stripe_flash_attn import stripe_flash_attn_func, stripe_flash_attn_kvpacked_func, stripe_flash_attn_qkvpacked_func +from .zigzag_ring_flash_attn import ( + zigzag_ring_flash_attn_func, + zigzag_ring_flash_attn_kvpacked_func, + zigzag_ring_flash_attn_qkvpacked_func, +) +from .zigzag_ring_flash_attn_varlen import ( + zigzag_ring_flash_attn_varlen_func, + zigzag_ring_flash_attn_varlen_qkvpacked_func, +) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn.py new file mode 100644 index 000000000000..b36484dbd145 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/ring_flash_attn.py @@ -0,0 +1,281 @@ +import torch +from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward + +from .utils import RingComm, update_out_and_lse + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if not causal or step <= comm.rank: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal and step == 0, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + if step <= kv_comm.rank or not causal: + bwd_causal = causal and step == 0 + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer, + block_dk_buffer, + block_dv_buffer, + dropout_p, + softmax_scale, + bwd_causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk = next_dk + dv = next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk) + next_dv = d_kv_comm.send_recv(dv) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class RingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py new file mode 100644 index 000000000000..118bdea4c7d0 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py @@ -0,0 +1,318 @@ +import torch +from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward + +from .utils import RingComm, update_out_and_lse + +try: + from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse +except: + from .utils import flatten_varlen_lse, unflatten_varlen_lse + + +def ring_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens, + max_seqlen, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + comm = RingComm(process_group) + + out = None + lse = None + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + if not causal or step <= comm.rank: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + causal=causal and step == 0, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + block_lse = flatten_varlen_lse( + block_lse, + cu_seqlens=cu_seqlens, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) + return out, lse + + +def ring_flash_attn_varlen_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + max_seqlen, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + if step <= kv_comm.rank or not causal: + bwd_causal = causal and step == 0 + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer, + block_dk_buffer, + block_dv_buffer, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + bwd_causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk = next_dk + dv = next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk) + next_dv = d_kv_comm.send_recv(dv) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class RingFlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = ring_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens, + max_seqlen, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) + ctx.max_seqlen = max_seqlen + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors + dq, dk, dv = ring_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + ctx.max_seqlen, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + + +def ring_flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnVarlenFunc.apply( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnVarlenFunc.apply( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py b/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py new file mode 100644 index 000000000000..ca426920f4ed --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py @@ -0,0 +1,325 @@ +import torch +from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward + +from .utils import RingComm, update_out_and_lse + + +def stripe_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step <= comm.rank: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q[:, 1:], + k[:, :-1], + v[:, :-1], + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse, slice_=(slice(None), slice(1, None))) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def stripe_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal, "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + shift_causal = step > kv_comm.rank + softmax_lse_1 = None + if not shift_causal: + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer, + block_dk_buffer, + block_dv_buffer, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + else: + if softmax_lse_1 is None: + # lazy init, since the last rank does not need softmax_lse_1 + softmax_lse_1 = softmax_lse[:, :, 1:].contiguous() + _flash_attn_backward( + dout[:, 1:], + q[:, 1:], + k[:, :-1], + v[:, :-1], + out[:, 1:], + softmax_lse_1, + block_dq_buffer[:, 1:], + block_dk_buffer[:, :-1], + block_dv_buffer[:, :-1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + if not shift_causal: + dq += block_dq_buffer + else: + dq[:, 1:] += block_dq_buffer[:, 1:] + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk = next_dk + dv = next_dv + + if not shift_causal: + dk = block_dk_buffer + dk + dv = block_dv_buffer + dv + else: + dk[:, :-1] += block_dk_buffer[:, :-1] + dv[:, :-1] += block_dv_buffer[:, :-1] + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class StripeFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = stripe_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = stripe_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def stripe_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def stripe_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def stripe_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return StripeFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/ring_flash_attn/triton_utils.py b/ring-flash-attention/ring_flash_attn/triton_utils.py new file mode 100644 index 000000000000..66e362d93d68 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/triton_utils.py @@ -0,0 +1,137 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def flatten_kernel( + # pointers to matrices + OUT, + LSE, + CU_SEQLENS, + # strides + stride_out_nheads, + stride_out_seqlen, + stride_lse_batch, + stride_lse_nheads, + stride_lse_seqlen, + # meta-parameters + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads + OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + LSE = LSE + rm[:, None] * stride_lse_seqlen + x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) + + OUT = OUT + rm[:, None] * stride_out_seqlen + tl.store(OUT, x, mask=rm[:, None] < seqlen) + + +def flatten_varlen_lse(lse, cu_seqlens): + """ + Arguments: + lse: (batch_size, nheads, max_seqlen) + cu_seqlens: (batch_size + 1,) + Return: + flatten_lse: (nheads, total_seqlen) + """ + total_seqlen = cu_seqlens[-1] + batch_size, nheads, max_seqlen = lse.shape + output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) + + grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) + BLOCK_M = 4 + + with torch.cuda.device(lse.device.index): + flatten_kernel[grid]( + output, + lse, + cu_seqlens, + # strides + output.stride(0), + output.stride(1), + lse.stride(0), + lse.stride(1), + lse.stride(2), + BLOCK_M, + ) + return output + + +@triton.jit +def unflatten_kernel( + # pointers to matrices + OUT, + LSE, + CU_SEQLENS, + # strides + stride_out_batch, + stride_out_nheads, + stride_out_seqlen, + stride_lse_seqlen, + stride_lse_nheads, + # meta-parameters + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + LSE = LSE + rm[:, None] * stride_lse_seqlen + x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) + + OUT = OUT + rm[:, None] * stride_out_seqlen + tl.store(OUT, x, mask=rm[:, None] < seqlen) + + +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + """ + Arguments: + lse: (total_seqlen, nheads, 1) + cu_seqlens: (batch_size + 1,) + max_seqlen: int + Return: + unflatten_lse: (batch_size, nheads, max_seqlen) + """ + lse = lse.unsqueeze(dim=-1) + batch_size = len(cu_seqlens) - 1 + nheads = lse.shape[1] + output = torch.empty( + (batch_size, nheads, max_seqlen), + dtype=lse.dtype, + device=lse.device, + ) + + grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) + BLOCK_M = 4 + + with torch.cuda.device(lse.device.index): + unflatten_kernel[grid]( + output, + lse, + cu_seqlens, + # strides + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_M, + ) + return output diff --git a/ring-flash-attention/ring_flash_attn/utils.py b/ring-flash-attention/ring_flash_attn/utils.py new file mode 100644 index 000000000000..787732af8135 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/utils.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +__all__ = ["update_out_and_lse", "RingComm"] + + +@torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +@torch.jit.script +def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +@torch.jit.script +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py new file mode 100644 index 000000000000..d3e2821c5d4d --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py @@ -0,0 +1,327 @@ +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward + +from .utils import RingComm, update_out_and_lse + + +def zigzag_ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group) + + block_seq_len = q.shape[1] // 2 + q1 = q[:, block_seq_len:] + + out = None + lse = None + next_k, next_v = None, None + + def forward(q, k, v, causal): + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + return block_out, block_lse + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step == 0: + block_out, block_lse = forward(q, k, v, causal=True) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + block_out, block_lse = forward(q, k0, v0, causal=False) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + out, lse = update_out_and_lse( + out, + lse, + block_out, + block_lse, + slice_=(slice(None), slice(block_seq_len, None)), + ) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def zigzag_ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=1)[1] + q1 = q.chunk(2, dim=1)[1] + out1 = out.chunk(2, dim=1)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() + block_seq_len = q.shape[1] // 2 + + # repeatly allocating buffer may be slow... + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout, q, k, v, out, softmax_lse, causal): + seqlen_q = q.shape[1] + seqlen_kv = k.shape[1] + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq_buffer[:, :seqlen_q], + dk_buffer[:, :seqlen_kv], + dv_buffer[:, :seqlen_kv], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + backward(dout, q, k0, v0, out, softmax_lse, causal=False) + dq += dq_buffer + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) + # always use the first half in dq_buffer. + dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.rank: + dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] + dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] + else: + dk += dk_buffer + dv += dv_buffer + if dist.get_rank() == 0: + torch.save(torch.stack((dk, dv)), f"step_{step}.pt") + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class ZigZagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = zigzag_ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = zigzag_ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py new file mode 100644 index 000000000000..5d4a8dd2daf0 --- /dev/null +++ b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py @@ -0,0 +1,441 @@ +import torch +from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward + +from .utils import RingComm, update_out_and_lse + +try: + from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse +except: + from .utils import flatten_varlen_lse, unflatten_varlen_lse + + +def get_half_index(cu_seqlens, *, front: bool): + if len(cu_seqlens) == 2: + if front: + return slice(None, cu_seqlens[-1] // 2) + else: + return slice(cu_seqlens[-1] // 2, None) + + index = torch.zeros((cu_seqlens[-1],), dtype=bool) + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + if front: + end = (start + end) // 2 + else: + start = (start + end) // 2 + index[start:end] = True + return index + + +@torch.jit.script +def get_half_lse(lse, cu_seqlens, *, front: bool): + new_lse = torch.empty( + (lse.shape[0], lse.shape[1], lse.shape[2] // 2), + dtype=lse.dtype, + device=lse.device, + ) + for i in range(len(cu_seqlens) - 1): + seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item() + if front: + start, end = 0, seqlen // 2 + else: + start, end = seqlen // 2, seqlen + new_lse[i, :, : seqlen // 2] = lse[i, :, start:end] + return new_lse + + +def zigzag_ring_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group) + + block_seq_len = q.shape[0] // 2 + q1 = q[half_index1] + + out = None + lse = None + next_k, next_v = None, None + half_cu_seqlens = cu_seqlens // 2 + half_max_seqlen = max_seqlen // 2 + + def forward(q, k, v, causal): + seqlen_q = q.shape[0] + seqlen_kv = k.shape[0] + cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens + max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen + cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens + max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen + block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( + q, + k, + v, + # the first half and the second half are the same + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + return block_out, block_lse + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step == 0: + block_out, block_lse = forward(q, k, v, causal=True) + block_lse = flatten_varlen_lse( + block_lse, + cu_seqlens=cu_seqlens, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.rank: + k0 = k[half_index0] + v0 = v[half_index0] + block_out, block_lse = forward(q, k0, v0, causal=False) + block_lse = flatten_varlen_lse( + block_lse, + cu_seqlens=cu_seqlens, + ) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + block_lse = flatten_varlen_lse( + block_lse, + cu_seqlens=half_cu_seqlens, + ) + out[half_index1], lse[half_index1] = update_out_and_lse( + out[half_index1], lse[half_index1], block_out, block_lse + ) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) + return out, lse + + +def zigzag_ring_flash_attn_varlen_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout[half_index1] + q1 = q[half_index1] + out1 = out[half_index1] + softmax_lse1 = get_half_lse(softmax_lse, cu_seqlens, front=False) + block_seq_len = q.shape[0] // 2 + + half_cu_seqlens = cu_seqlens // 2 + half_max_seqlen = max_seqlen // 2 + + # repeatly allocating buffer may be slow... + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout, q, k, v, out, softmax_lse, causal): + seqlen_q = q.shape[0] + seqlen_kv = k.shape[0] + cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens + max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen + cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens + max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq_buffer[:seqlen_q], + dk_buffer[:seqlen_kv], + dv_buffer[:seqlen_kv], + # the first half and the second half are the same + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + rng_state=None, + ) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.rank: + k0 = k[half_index0] + v0 = v[half_index0] + backward(dout, q, k0, v0, out, softmax_lse, causal=False) + dq += dq_buffer + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) + dq[half_index1] += dq_buffer[:block_seq_len] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.rank: + dk[half_index0] += dk_buffer[:block_seq_len] + dv[half_index0] += dv_buffer[:block_seq_len] + else: + dk += dk_buffer + dv += dv_buffer + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class ZigZagRingFlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + half_index0 = get_half_index(cu_seqlens, front=True) + half_index1 = get_half_index(cu_seqlens, front=False) + out, softmax_lse = zigzag_ring_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + is_half_index_tensor = isinstance(half_index0, torch.Tensor) + ctx.is_half_index_tensor = is_half_index_tensor + if is_half_index_tensor: + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) + else: + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) + ctx.half_index0 = half_index0 + ctx.half_index1 = half_index1 + ctx.max_seqlen = max_seqlen + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + if ctx.is_half_index_tensor: + (q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) = ctx.saved_tensors + else: + q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors + half_index0 = ctx.half_index0 + half_index1 = ctx.half_index1 + dq, dk, dv = zigzag_ring_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + ctx.max_seqlen, + half_index0, + half_index1, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnVarlenFunc.apply( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnVarlenFunc.apply( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/ring-flash-attention/setup.py b/ring-flash-attention/setup.py new file mode 100644 index 000000000000..58413e1b54f3 --- /dev/null +++ b/ring-flash-attention/setup.py @@ -0,0 +1,9 @@ +from setuptools import find_packages, setup + +setup( + name="ring_flash_attn", + version="0.1", + author="zhuzilin", + url="https://github.com/zhuzilin/ring-flash-attention", + packages=find_packages(), +) diff --git a/ring-flash-attention/test/test_ring_flash_attn_func.py b/ring-flash-attention/test/test_ring_flash_attn_func.py new file mode 100644 index 000000000000..50edd03bef4e --- /dev/null +++ b/ring-flash-attention/test/test_ring_flash_attn_func.py @@ -0,0 +1,124 @@ +import random + +import torch +import torch.distributed as dist +from flash_attn import flash_attn_qkvpacked_func +from ring_flash_attn import ring_flash_attn_qkvpacked_func + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + set_seed(rank) + world_size = dist.get_world_size() + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + seqlen = 3816 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert seqlen % world_size == 0 + assert d % 8 == 0 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone() + local_qkv.requires_grad = True + local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_qkvpacked_func( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = out.chunk(world_size, dim=1)[rank] + local_lse = lse.chunk(world_size, dim=-1)[rank] + + fn = ring_flash_attn_qkvpacked_func + + ring_out, ring_lse, _ = fn( + local_qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + log("out", out, rank0_only=True) + log("lse", lse, rank0_only=True) + log("out diff", local_out - ring_out) + log("lse diff", local_lse - ring_lse) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + out.backward(dout) + dqkv = qkv.grad + local_dqkv = dqkv.chunk(world_size, dim=1)[rank] + + ring_out.backward(local_dout) + ring_dqkv = local_qkv.grad + + log("local_dq", local_dqkv[:, :, 0, :]) + log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) + + log("local_dk", local_dqkv[:, :, 1, :]) + log("dk diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) + + log("local_dv", local_dqkv[:, :, 2, :]) + log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py new file mode 100644 index 000000000000..51bb1ec5d67d --- /dev/null +++ b/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py @@ -0,0 +1,157 @@ +import random + +import torch +import torch.distributed as dist +from flash_attn import flash_attn_varlen_qkvpacked_func +from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def extract_local(value, cu_seqlens, rank, world_size): + local_values = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + local_value = value[start:end].chunk(world_size, dim=0)[rank].detach().clone() + local_values.append(local_value) + return torch.cat(local_values, dim=0).contiguous() + + +def extract_lse(lse, cu_seqlens): + values = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + value = lse[i, :, : end - start] + values.append(value) + return values + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + set_seed(rank) + world_size = dist.get_world_size() + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + cu_seqlens = [0, 120, 1248, 4232] + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() + total_length = cu_seqlens[-1] + num_seq = len(cu_seqlens) - 1 + + assert torch.all(cu_seqlens_tensor % world_size == 0) + assert d % 8 == 0 + + qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_cu_seqlens_tensor = cu_seqlens_tensor // world_size + local_max_seqlen = max_seqlen // world_size + + local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) + local_qkv.requires_grad = True + local_dout = extract_local(dout, cu_seqlens, rank, world_size) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens_tensor, + max_seqlen, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = extract_local(out, cu_seqlens, rank, world_size) + lse_list = extract_lse(lse, cu_seqlens) + + ring_out, ring_lse, _ = ring_flash_attn_varlen_qkvpacked_func( + local_qkv, + local_cu_seqlens_tensor, + local_max_seqlen, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) + + log("out", out, rank0_only=True) + log("out diff", local_out - ring_out) + + for lse, ring_lse in zip(lse_list, ring_lse_list): + local_lse = lse.chunk(world_size, dim=-1)[rank] + log("lse", lse, rank0_only=True) + log("lse diff", local_lse - ring_lse) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + out.backward(dout) + dqkv = qkv.grad + local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) + + ring_out.backward(local_dout) + ring_dqkv = local_qkv.grad + + log("local_dq", local_dqkv[:, 0]) + log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0]) + + log("local_dk", local_dqkv[:, 1]) + log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) + + log("local_dv", local_dqkv[:, 2]) + log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/ring-flash-attention/test/test_stripe_flash_attn_func.py b/ring-flash-attention/test/test_stripe_flash_attn_func.py new file mode 100644 index 000000000000..dc9f5248d69d --- /dev/null +++ b/ring-flash-attention/test/test_stripe_flash_attn_func.py @@ -0,0 +1,130 @@ +import random + +import torch +import torch.distributed as dist +from flash_attn import flash_attn_qkvpacked_func +from ring_flash_attn import stripe_flash_attn_qkvpacked_func + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def extract_local(value, rank, world_size, dim=1): + value = torch.stack(value.split(world_size, dim=dim), dim=dim).transpose(dim, dim + 1) + slicer = [rank if i == dim else slice(None) for i in range(len(value.shape))] + return value[slicer].contiguous() + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + set_seed(rank) + world_size = dist.get_world_size() + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + seqlen = 3824 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert causal + assert seqlen % (2 * world_size) == 0 + assert d % 8 == 0 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_qkv = extract_local(qkv, rank, world_size).detach().clone() + local_qkv.requires_grad = True + local_dout = extract_local(dout, rank, world_size).detach().clone() + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_qkvpacked_func( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = extract_local(out, rank, world_size) + local_lse = extract_local(lse, rank, world_size, dim=2) + + ring_out, ring_lse, _ = stripe_flash_attn_qkvpacked_func( + local_qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + log("out", out, rank0_only=True) + log("lse", lse, rank0_only=True) + log("out diff", local_out - ring_out) + log("lse diff", local_lse - ring_lse) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + out.backward(dout) + dqkv = qkv.grad + + local_dqkv = extract_local(dqkv, rank, world_size) + + ring_out.backward(local_dout) + ring_dqkv = local_qkv.grad + + log("local_dq", local_dqkv[:, :, 0, :]) + log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) + + log("local_dk", local_dqkv[:, :, 1, :]) + log("dk0 diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) + + log("local_dv", local_dqkv[:, :, 2, :]) + log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_triton_kernels.py b/ring-flash-attention/test/test_triton_kernels.py new file mode 100644 index 000000000000..aa1c1fdcd338 --- /dev/null +++ b/ring-flash-attention/test/test_triton_kernels.py @@ -0,0 +1,30 @@ +import torch +from ring_flash_attn.triton_utils import flatten_varlen_lse as triton_flatten_varlen_lse +from ring_flash_attn.triton_utils import unflatten_varlen_lse as triton_unflatten_varlen_lse +from ring_flash_attn.utils import flatten_varlen_lse, unflatten_varlen_lse + +if __name__ == "__main__": + device = torch.device("cuda:0") + + cu_seqlens = [0, 15, 156, 529] + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + batch_size = len(cu_seqlens) - 1 + max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() + n_head = 5 + + lse = torch.randn((batch_size, n_head, max_seqlen), dtype=torch.float32, device=device) + flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor) + triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor) + assert torch.all(flatten_lse == triton_flatten_lse) + + flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) + triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) + + unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen) + triton_unflatten_lse = triton_unflatten_varlen_lse(triton_flatten_lse, cu_seqlens_tensor, max_seqlen) + + for i in range(batch_size): + seqlen = cu_seqlens[i + 1] - cu_seqlens[i] + assert torch.all( + unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen] + ), f"{unflatten_lse[i, :seqlen]} vs {triton_unflatten_lse[i, :seqlen]}" diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py new file mode 100644 index 000000000000..5f84bc58cf10 --- /dev/null +++ b/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py @@ -0,0 +1,150 @@ +import os +import random + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from flash_attn import flash_attn_qkvpacked_func +from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func + +from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def extract_local(value, rank, world_size, dim=1): + value_chunks = value.chunk(2 * world_size, dim=dim) + local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim) + return local_value.contiguous() + + +def run_test(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" # or the IP of the master node + os.environ["MASTER_PORT"] = "8125" # make sure this port is free + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + set_seed(rank) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + seqlen = 3824 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + assert causal + assert seqlen % (2 * world_size) == 0 + assert d % 8 == 0 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_qkv = extract_local(qkv, rank, world_size).detach().clone() + local_qkv.requires_grad = True + extract_local(dout, rank, world_size).detach().clone() + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_qkvpacked_func( + qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = extract_local(out, rank, world_size) + # local_lse = extract_local(lse, rank, world_size, dim=2) + q, k, v = local_qkv.chunk(3, dim=2) + q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] + q.requires_grad = k.requires_grad = v.requires_grad = True + sp_stream = torch.cuda.Stream() + sp_group = dist.new_group() + colo_out = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL) + + ring_out, ring_lse, _ = zigzag_ring_flash_attn_qkvpacked_func( + local_qkv, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + log("colo_out", colo_out, rank0_only=True) + log("ring_out", ring_out, rank0_only=True) + # log("lse", lse, rank0_only=True) + log("colo_out - ring_out", colo_out - ring_out) + # log("lse diff", local_lse - ring_lse) + log("ring_out - local_out", ring_out - local_out) + log("colo_out - local_out", colo_out - local_out) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + colo_out.sum().backward() + qkv.grad + # q, k, v = [x.transpose(1, 2) for x in (q, k, v)] + colo_dq, colo_dk, colo_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] + + ring_out.sum().backward() + ring_dqkv = local_qkv.grad + out.sum().backward() + dqkv = extract_local(qkv.grad, rank, world_size) + + # log("colo_dq", colo_dq) + log("dq diff", colo_dq - ring_dqkv[:, :, 0, :]) + + # log("colo_dk", colo_dk) + log("dk diff", colo_dk - ring_dqkv[:, :, 1, :]) + + # log("colo_dv", colo_dv) + log("dv diff", colo_dv - ring_dqkv[:, :, 2, :]) + log("colo_dv - local_dv", colo_dv - dqkv[:, :, 2, :]) + + +if __name__ == "__main__": + world_size = 4 + mp.spawn(run_test, args=(world_size,), nprocs=world_size, join=True) diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py new file mode 100644 index 000000000000..7f6eced6e57b --- /dev/null +++ b/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py @@ -0,0 +1,163 @@ +import random + +import torch +import torch.distributed as dist +from flash_attn import flash_attn_varlen_qkvpacked_func +from ring_flash_attn import zigzag_ring_flash_attn_varlen_qkvpacked_func + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def extract_local(value, cu_seqlens, rank, world_size): + local_values = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + local_value = value[start:end].chunk(2 * world_size, dim=0) + local_values.extend( + [ + local_value[rank].detach().clone(), + local_value[2 * world_size - 1 - rank].detach().clone(), + ] + ) + return torch.cat(local_values, dim=0).contiguous() + + +def extract_lse(lse, cu_seqlens): + values = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + value = lse[i, :, : end - start] + values.append(value) + return values + + +if __name__ == "__main__": + dist.init_process_group("nccl") + rank = dist.get_rank() + set_seed(rank) + world_size = dist.get_world_size() + dtype = torch.bfloat16 + device = torch.device(f"cuda:{rank}") + + batch_size = 1 + nheads = 5 + d = 128 + dropout_p = 0 + causal = True + deterministic = False + + cu_seqlens = [0, 128, 1248, 4240] + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() + total_length = cu_seqlens[-1] + num_seq = len(cu_seqlens) - 1 + + assert torch.all(cu_seqlens_tensor % (2 * world_size) == 0) + assert d % 8 == 0 + + qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + dist.broadcast(qkv, src=0) + + dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) + dist.broadcast(dout, src=0) + + local_cu_seqlens_tensor = cu_seqlens_tensor // world_size + local_max_seqlen = max_seqlen // world_size + + local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) + local_qkv.requires_grad = True + local_dout = extract_local(dout, cu_seqlens, rank, world_size) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# forward:") + print("#" * 30) + + out, lse, _ = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens_tensor, + max_seqlen, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + local_out = extract_local(out, cu_seqlens, rank, world_size) + lse_list = extract_lse(lse, cu_seqlens) + + ring_out, ring_lse, _ = zigzag_ring_flash_attn_varlen_qkvpacked_func( + local_qkv, + local_cu_seqlens_tensor, + local_max_seqlen, + dropout_p=dropout_p, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=True, + ) + + ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) + + log("out", out, rank0_only=True) + log("out diff", local_out - ring_out) + + for i, (lse, ring_lse) in enumerate(zip(lse_list, ring_lse_list)): + local_lse = lse.chunk(2 * world_size, dim=-1) + local_lse = torch.cat([local_lse[rank], local_lse[2 * world_size - 1 - rank]], dim=-1) + log(f"lse {i}", lse, rank0_only=True) + log(f"lse diff {i}", local_lse - ring_lse) + + dist.barrier() + if rank == 0: + print("#" * 30) + print("# backward:") + print("#" * 30) + + out.backward(dout) + dqkv = qkv.grad + local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) + + ring_out.backward(local_dout) + ring_dqkv = local_qkv.grad + + log("local_dq", local_dqkv[:, 0]) + log("dq diff", local_dqkv - ring_dqkv) + + log("local_dk", local_dqkv[:, 1]) + log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) + + log("local_dv", local_dqkv[:, 2]) + log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 943c5cf1c58e..db69c9818411 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -80,6 +80,7 @@ def data_gen_for_causal_lm(): data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, loss_fn=loss_fn_for_causal_lm, +<<<<<<< HEAD model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( @@ -96,6 +97,8 @@ def data_gen_for_causal_lm(): data_gen_fn=data_gen, output_transform_fn=output_transform_fn, loss_fn=loss_fn, +======= +>>>>>>> precision tests passed model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 5ca618bc8535..805808887db9 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist +<<<<<<< HEAD import torch.nn.functional as F from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func from torch.testing import assert_close @@ -31,11 +32,44 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # Setup inputs qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) local_qkv = split_batch_zigzag(qkv, sp_group) +======= +from flash_attn import flash_attn_qkvpacked_func +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention +from colossalai.shardformer.layer.utils import zigzag_split_batch +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("seq_len", [4096]) +@parameterize("batch_size", [1]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16]) +def test_ring_attn(seq_len, batch_size, nheads, d, dtype): + torch.cuda.manual_seed(2) + rank = dist.get_rank() + device = torch.device(f"cuda:{rank}") + sp_group = dist.group.WORLD + sp_stream = torch.cuda.Stream() + + # Some outliers may seem large, but our errors are still much lower than + # than Megatron-LM's context parallel + # https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215 + # and the original zigzag implementation: https://github.com/zhuzilin/ring-flash-attention/tree/main + atol = rtol = 7e-3 + + # Setup inputs + qkv = torch.randn(batch_size, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + local_qkv = zigzag_split_batch(qkv, sp_group) +>>>>>>> precision tests passed q, k, v = local_qkv.unbind(dim=-3) q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D) q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU +<<<<<<< HEAD ring_out, ring_lse = RingAttention.attention( q, k, @@ -47,10 +81,14 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) +======= + ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) +>>>>>>> precision tests passed out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True ) +<<<<<<< HEAD # Checkout out and softmax denominator local_out = split_batch_zigzag(out, sp_group) local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) @@ -59,10 +97,18 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): assert_close(ring_out, local_out, atol=atol, rtol=rtol) # Check grads +======= + local_out = zigzag_split_batch(out, sp_group) + local_lse = zigzag_split_batch(lse, sp_group, seq_dim=-1) + assert_close(ring_out, local_out, atol=atol, rtol=rtol) + assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) + +>>>>>>> precision tests passed ring_out.sum().backward() out.sum().backward() ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] dqkv = qkv.grad +<<<<<<< HEAD local_dqkv = split_batch_zigzag(dqkv, sp_group) assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) @@ -185,3 +231,23 @@ def test_double_ring(world_size): if __name__ == "__main__": test_ring_attn() test_double_ring() +======= + local_dqkv = zigzag_split_batch(dqkv, sp_group) + assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) + assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) + assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) + + +def launch(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + test_ring_attn() + + +@rerun_if_address_is_in_use() +def run_ring_attn(): + spawn(launch, nprocs=8) + + +if __name__ == "__main__": + run_ring_attn() +>>>>>>> precision tests passed diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 06ed5f22405f..d4f0968f20bc 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -153,7 +153,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # Zigzag Ring Attention + # Zigzag Ring Attention + PP + { + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "bf16", + "initial_scale": 1, + }, + # Ring Attention + TP { "tp_size": 2, "pp_size": 1, @@ -170,7 +183,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "sp_size": 2, - "num_microbatches": 2, + "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", "enable_all_optimization": True, @@ -266,7 +279,6 @@ def run_llama_test(test_config): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue - try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: @@ -360,4 +372,8 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - test_llama_3d() \ No newline at end of file +<<<<<<< HEAD + test_llama_3d() +======= + # test_llama_3d() +>>>>>>> precision tests passed From f5a1b9964bbc7de602e75b188c090f56fb6f1ae9 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 22 Jul 2024 03:39:19 +0000 Subject: [PATCH 35/71] fix typos and remove misc files --- colossalai/shardformer/layer/attn.py | 17 +- .../benchmark/benchmark_qkvpacked_func.py | 87 ---- .../benchmark_varlen_qkvpacked_func.py | 91 ---- .../ring_flash_attn/__init__.py | 16 - .../ring_flash_attn/ring_flash_attn.py | 281 ----------- .../ring_flash_attn/ring_flash_attn_varlen.py | 318 ------------- .../ring_flash_attn/stripe_flash_attn.py | 325 ------------- .../ring_flash_attn/triton_utils.py | 137 ------ ring-flash-attention/ring_flash_attn/utils.py | 110 ----- .../ring_flash_attn/zigzag_ring_flash_attn.py | 327 ------------- .../zigzag_ring_flash_attn_varlen.py | 441 ------------------ ring-flash-attention/setup.py | 9 - .../test/test_ring_flash_attn_func.py | 124 ----- .../test/test_ring_flash_attn_varlen_func.py | 157 ------- .../test/test_stripe_flash_attn_func.py | 130 ------ .../test/test_triton_kernels.py | 30 -- .../test/test_zigzag_ring_flash_attn_func.py | 150 ------ ...test_zigzag_ring_flash_attn_varlen_func.py | 163 ------- .../test_layer/test_ring_attn.py | 2 +- 19 files changed, 3 insertions(+), 2912 deletions(-) delete mode 100644 ring-flash-attention/benchmark/benchmark_qkvpacked_func.py delete mode 100644 ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py delete mode 100644 ring-flash-attention/ring_flash_attn/__init__.py delete mode 100644 ring-flash-attention/ring_flash_attn/ring_flash_attn.py delete mode 100644 ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py delete mode 100644 ring-flash-attention/ring_flash_attn/stripe_flash_attn.py delete mode 100644 ring-flash-attention/ring_flash_attn/triton_utils.py delete mode 100644 ring-flash-attention/ring_flash_attn/utils.py delete mode 100644 ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py delete mode 100644 ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py delete mode 100644 ring-flash-attention/setup.py delete mode 100644 ring-flash-attention/test/test_ring_flash_attn_func.py delete mode 100644 ring-flash-attention/test/test_ring_flash_attn_varlen_func.py delete mode 100644 ring-flash-attention/test/test_stripe_flash_attn_func.py delete mode 100644 ring-flash-attention/test/test_triton_kernels.py delete mode 100644 ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py delete mode 100644 ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 1cba15cb2c6c..bc352e9f2790 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -254,12 +254,7 @@ def attention( # sanity check if attention_mask is not None: assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor." - if attention_mask_type in ( - AttnMaskType.CUSTOM, - AttnMaskType.CAUSAL, - AttnMaskType.PADDED, - AttnMaskType.PADDED_CAUSAL, - ): + if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL): assert ( cu_seqlens_q is None and cu_seqlens_kv is None @@ -471,24 +466,16 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + assert not (new_lse.isnan().any() or new_lse.isinf().any()), f"lse is nan: {new_lse}" new_block_lse = torch.exp(block_lse - new_lse) out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) lse.copy_(new_lse) - assert _not_nan(new_lse), new_lse - assert _not_nan(new_block_lse), new_block_lse - assert _not_nan(out), out # block_out = block_out.float() - # out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out)) - # lse.copy_(lse - F.logsigmoid(lse - block_lse)) # assert not lse.isnan().any(), lse # assert not out.isnan().any(), out -def _not_nan(x): - return not (x.isnan().any() or x.isinf().any()) - - class RingAttention(torch.autograd.Function): """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context` (https://arxiv.org/abs/2310.01889). diff --git a/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py deleted file mode 100644 index a6742e04a696..000000000000 --- a/ring-flash-attention/benchmark/benchmark_qkvpacked_func.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import torch.cuda -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import ( - ring_flash_attn_qkvpacked_func, - stripe_flash_attn_qkvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, -) - - -def benchmark(f, num_iter=100, forward_only=True, log=True): - dtype = torch.bfloat16 - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - - batch_size = 1 - seqlen = 1024 * 8 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - - begin = torch.cuda.Event(enable_timing=True) - begin.record() - - if forward_only: - with torch.no_grad(): - for _ in range(num_iter): - _ = f( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - - else: - for _ in range(num_iter): - qkv.grad = None - out = f( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - out.backward(dout) - end = torch.cuda.Event(enable_timing=True) - end.record() - torch.cuda.synchronize(device=device) - time = begin.elapsed_time(end) / 1000.0 - - if rank == 0 and log: - print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec") - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - - forward_only = False - - for f in [ - flash_attn_qkvpacked_func, - ring_flash_attn_qkvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, - stripe_flash_attn_qkvpacked_func, - ]: - torch.cuda.empty_cache() - if rank == 0: - print(f"# {f.__name__}") - benchmark(f, forward_only=forward_only, log=False) - benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py b/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py deleted file mode 100644 index 18c8cafc0078..000000000000 --- a/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.cuda -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func, zigzag_ring_flash_attn_varlen_qkvpacked_func - - -def benchmark(f, num_iter=100, forward_only=True, log=True): - dtype = torch.bfloat16 - rank = dist.get_rank() - world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - - seqlen = 1024 * 8 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dout = torch.randn(seqlen, nheads, d, device=device, dtype=dtype) - - cu_seqlens_list = [ - torch.tensor([0, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 256, 7648, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 4096, 8192], device=device, dtype=torch.int32), - torch.tensor([0, 3104, 6304, 7904, 8064, 8192], device=device, dtype=torch.int32), - ] - max_seqlen_list = [(cu_seqlens[1:] - cu_seqlens[:1]).max().item() for cu_seqlens in cu_seqlens_list] - - begin = torch.cuda.Event(enable_timing=True) - begin.record() - if forward_only: - with torch.no_grad(): - for i in range(num_iter): - _ = f( - qkv, - cu_seqlens_list[i % len(cu_seqlens_list)], - max_seqlen_list[i % len(max_seqlen_list)], - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - else: - for i in range(num_iter): - qkv.grad = None - out = f( - qkv, - cu_seqlens_list[i % len(cu_seqlens_list)], - max_seqlen_list[i % len(max_seqlen_list)], - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - ) - out.backward(dout) - end = torch.cuda.Event(enable_timing=True) - end.record() - torch.cuda.synchronize(device=device) - time = begin.elapsed_time(end) / 1000.0 - - if rank == 0 and log: - print(f"{num_iter / time} iter/s, {time} sec") - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - - forward_only = False - - for f in [ - flash_attn_varlen_qkvpacked_func, - ring_flash_attn_varlen_qkvpacked_func, - zigzag_ring_flash_attn_varlen_qkvpacked_func, - ]: - torch.cuda.empty_cache() - if rank == 0: - print(f"# {f.__name__}") - benchmark(f, forward_only=forward_only, log=False) - benchmark(f, forward_only=forward_only, log=True) diff --git a/ring-flash-attention/ring_flash_attn/__init__.py b/ring-flash-attention/ring_flash_attn/__init__.py deleted file mode 100644 index 01d5ec36218c..000000000000 --- a/ring-flash-attention/ring_flash_attn/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .ring_flash_attn import ring_flash_attn_func, ring_flash_attn_kvpacked_func, ring_flash_attn_qkvpacked_func -from .ring_flash_attn_varlen import ( - ring_flash_attn_varlen_func, - ring_flash_attn_varlen_kvpacked_func, - ring_flash_attn_varlen_qkvpacked_func, -) -from .stripe_flash_attn import stripe_flash_attn_func, stripe_flash_attn_kvpacked_func, stripe_flash_attn_qkvpacked_func -from .zigzag_ring_flash_attn import ( - zigzag_ring_flash_attn_func, - zigzag_ring_flash_attn_kvpacked_func, - zigzag_ring_flash_attn_qkvpacked_func, -) -from .zigzag_ring_flash_attn_varlen import ( - zigzag_ring_flash_attn_varlen_func, - zigzag_ring_flash_attn_varlen_qkvpacked_func, -) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn.py deleted file mode 100644 index b36484dbd145..000000000000 --- a/ring-flash-attention/ring_flash_attn/ring_flash_attn.py +++ /dev/null @@ -1,281 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def ring_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - comm = RingComm(process_group) - - out = None - lse = None - - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if not causal or step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal and step == 0, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def ring_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - next_dk, next_dv = None, None - next_k, next_v = None, None - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - if step <= kv_comm.rank or not causal: - bwd_causal = causal and step == 0 - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - dropout_p, - softmax_scale, - bwd_causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - dq += block_dq_buffer - d_kv_comm.wait() - dk = block_dk_buffer + next_dk - dv = block_dv_buffer + next_dv - elif step != 0: - d_kv_comm.wait() - dk = next_dk - dv = next_dv - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk) - next_dv = d_kv_comm.send_recv(dv) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class RingFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = ring_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = ring_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def ring_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py deleted file mode 100644 index 118bdea4c7d0..000000000000 --- a/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py +++ /dev/null @@ -1,318 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward - -from .utils import RingComm, update_out_and_lse - -try: - from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse -except: - from .utils import flatten_varlen_lse, unflatten_varlen_lse - - -def ring_flash_attn_varlen_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens, - max_seqlen, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - comm = RingComm(process_group) - - out = None - lse = None - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - if not causal or step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - causal=causal and step == 0, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) - return out, lse - - -def ring_flash_attn_varlen_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - max_seqlen, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - next_dk, next_dv = None, None - next_k, next_v = None, None - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - if step <= kv_comm.rank or not causal: - bwd_causal = causal and step == 0 - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - bwd_causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - dq += block_dq_buffer - d_kv_comm.wait() - dk = block_dk_buffer + next_dk - dv = block_dv_buffer + next_dv - elif step != 0: - d_kv_comm.wait() - dk = next_dk - dv = next_dv - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk) - next_dv = d_kv_comm.send_recv(dv) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class RingFlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = ring_flash_attn_varlen_forward( - group, - q, - k, - v, - cu_seqlens, - max_seqlen, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) - ctx.max_seqlen = max_seqlen - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors - dq, dk, dv = ring_flash_attn_varlen_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - ctx.max_seqlen, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None - - -def ring_flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - q, - kv[:, 0], - kv[:, 1], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_varlen_func( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py b/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py deleted file mode 100644 index ca426920f4ed..000000000000 --- a/ring-flash-attention/ring_flash_attn/stripe_flash_attn.py +++ /dev/null @@ -1,325 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def stripe_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal, "stripe flash attn only supports causal attention, if not causal, use ring flash attn instead" - comm = RingComm(process_group) - - out = None - lse = None - - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step <= comm.rank: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q[:, 1:], - k[:, :-1], - v[:, :-1], - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse, slice_=(slice(None), slice(1, None))) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def stripe_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal, "stripe flash attn only supports causal attention, if not causal, ring flash attn instead" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - shift_causal = step > kv_comm.rank - softmax_lse_1 = None - if not shift_causal: - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - block_dq_buffer, - block_dk_buffer, - block_dv_buffer, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - else: - if softmax_lse_1 is None: - # lazy init, since the last rank does not need softmax_lse_1 - softmax_lse_1 = softmax_lse[:, :, 1:].contiguous() - _flash_attn_backward( - dout[:, 1:], - q[:, 1:], - k[:, :-1], - v[:, :-1], - out[:, 1:], - softmax_lse_1, - block_dq_buffer[:, 1:], - block_dk_buffer[:, :-1], - block_dv_buffer[:, :-1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - if not shift_causal: - dq += block_dq_buffer - else: - dq[:, 1:] += block_dq_buffer[:, 1:] - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk = next_dk - dv = next_dv - - if not shift_causal: - dk = block_dk_buffer + dk - dv = block_dv_buffer + dv - else: - dk[:, :-1] += block_dk_buffer[:, :-1] - dv[:, :-1] += block_dv_buffer[:, :-1] - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class StripeFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = stripe_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = stripe_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def stripe_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def stripe_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def stripe_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return StripeFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/triton_utils.py b/ring-flash-attention/ring_flash_attn/triton_utils.py deleted file mode 100644 index 66e362d93d68..000000000000 --- a/ring-flash-attention/ring_flash_attn/triton_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def flatten_kernel( - # pointers to matrices - OUT, - LSE, - CU_SEQLENS, - # strides - stride_out_nheads, - stride_out_seqlen, - stride_lse_batch, - stride_lse_nheads, - stride_lse_seqlen, - # meta-parameters - BLOCK_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads - OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - LSE = LSE + rm[:, None] * stride_lse_seqlen - x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) - - OUT = OUT + rm[:, None] * stride_out_seqlen - tl.store(OUT, x, mask=rm[:, None] < seqlen) - - -def flatten_varlen_lse(lse, cu_seqlens): - """ - Arguments: - lse: (batch_size, nheads, max_seqlen) - cu_seqlens: (batch_size + 1,) - Return: - flatten_lse: (nheads, total_seqlen) - """ - total_seqlen = cu_seqlens[-1] - batch_size, nheads, max_seqlen = lse.shape - output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) - - grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) - BLOCK_M = 4 - - with torch.cuda.device(lse.device.index): - flatten_kernel[grid]( - output, - lse, - cu_seqlens, - # strides - output.stride(0), - output.stride(1), - lse.stride(0), - lse.stride(1), - lse.stride(2), - BLOCK_M, - ) - return output - - -@triton.jit -def unflatten_kernel( - # pointers to matrices - OUT, - LSE, - CU_SEQLENS, - # strides - stride_out_batch, - stride_out_nheads, - stride_out_seqlen, - stride_lse_seqlen, - stride_lse_nheads, - # meta-parameters - BLOCK_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - LSE = LSE + rm[:, None] * stride_lse_seqlen - x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) - - OUT = OUT + rm[:, None] * stride_out_seqlen - tl.store(OUT, x, mask=rm[:, None] < seqlen) - - -def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - """ - Arguments: - lse: (total_seqlen, nheads, 1) - cu_seqlens: (batch_size + 1,) - max_seqlen: int - Return: - unflatten_lse: (batch_size, nheads, max_seqlen) - """ - lse = lse.unsqueeze(dim=-1) - batch_size = len(cu_seqlens) - 1 - nheads = lse.shape[1] - output = torch.empty( - (batch_size, nheads, max_seqlen), - dtype=lse.dtype, - device=lse.device, - ) - - grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) - BLOCK_M = 4 - - with torch.cuda.device(lse.device.index): - unflatten_kernel[grid]( - output, - lse, - cu_seqlens, - # strides - output.stride(0), - output.stride(1), - output.stride(2), - lse.stride(0), - lse.stride(1), - BLOCK_M, - ) - return output diff --git a/ring-flash-attention/ring_flash_attn/utils.py b/ring-flash-attention/ring_flash_attn/utils.py deleted file mode 100644 index 787732af8135..000000000000 --- a/ring-flash-attention/ring_flash_attn/utils.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -__all__ = ["update_out_and_lse", "RingComm"] - - -@torch.jit.script -def _update_out_and_lse( - out: torch.Tensor, - lse: torch.Tensor, - block_out: torch.Tensor, - block_lse: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - - block_out = block_out.to(torch.float32) - block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - - # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out - # For additional context and discussion, please refer to: - # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - out = out - F.sigmoid(block_lse - lse) * (out - block_out) - lse = lse - F.logsigmoid(lse - block_lse) - - return out, lse - - -def update_out_and_lse( - out: Optional[torch.Tensor], - lse: Optional[torch.Tensor], - block_out: torch.Tensor, - block_lse: torch.Tensor, - slice_=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - if out is None: - if slice_ is not None: - raise RuntimeError("first update_out_and_lse should not pass slice_ args") - out = block_out.to(torch.float32) - lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - elif slice_ is not None: - slice_out, slice_lse = out[slice_], lse[slice_] - slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) - out[slice_], lse[slice_] = slice_out, slice_lse - else: - out, lse = _update_out_and_lse(out, lse, block_out, block_lse) - return out, lse - - -@torch.jit.script -def flatten_varlen_lse(lse, cu_seqlens): - new_lse = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse.append(lse[i, :, : end - start]) - return torch.cat(new_lse, dim=1) - - -@torch.jit.script -def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - num_seq = len(cu_seqlens) - 1 - num_head = lse.shape[-2] - new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) - for i in range(num_seq): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse[i, : end - start] = lse[start:end] - return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() - - -class RingComm: - def __init__(self, process_group: dist.ProcessGroup): - self._process_group = process_group - self._ops = [] - self.rank = dist.get_rank(self._process_group) - self.world_size = dist.get_world_size(self._process_group) - self._reqs = None - - self.send_rank = (self.rank + 1) % self.world_size - self.recv_rank = (self.rank - 1) % self.world_size - - if process_group is not None: - self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) - self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) - - def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: - if recv_tensor is None: - res = torch.empty_like(to_send) - else: - res = recv_tensor - - send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) - recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) - self._ops.append(send_op) - self._ops.append(recv_op) - return res - - def commit(self): - if self._reqs is not None: - raise RuntimeError("commit called twice") - self._reqs = dist.batch_isend_irecv(self._ops) - - def wait(self): - if self._reqs is None: - raise RuntimeError("wait called before commit") - for req in self._reqs: - req.wait() - self._reqs = None - self._ops = [] diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py deleted file mode 100644 index d3e2821c5d4d..000000000000 --- a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py +++ /dev/null @@ -1,327 +0,0 @@ -import torch -import torch.distributed as dist -from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward - -from .utils import RingComm, update_out_and_lse - - -def zigzag_ring_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - comm = RingComm(process_group) - - block_seq_len = q.shape[1] // 2 - q1 = q[:, block_seq_len:] - - out = None - lse = None - next_k, next_v = None, None - - def forward(q, k, v, causal): - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - return block_out, block_lse - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step == 0: - block_out, block_lse = forward(q, k, v, causal=True) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - elif step <= comm.rank: - k0 = k[:, :block_seq_len] - v0 = v[:, :block_seq_len] - block_out, block_lse = forward(q, k0, v0, causal=False) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, block_lse = forward(q1, k, v, causal=False) - out, lse = update_out_and_lse( - out, - lse, - block_out, - block_lse, - slice_=(slice(None), slice(block_seq_len, None)), - ) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def zigzag_ring_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - dout1 = dout.chunk(2, dim=1)[1] - q1 = q.chunk(2, dim=1)[1] - out1 = out.chunk(2, dim=1)[1] - softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() - block_seq_len = q.shape[1] // 2 - - # repeatly allocating buffer may be slow... - dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - def backward(dout, q, k, v, out, softmax_lse, causal): - seqlen_q = q.shape[1] - seqlen_kv = k.shape[1] - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq_buffer[:, :seqlen_q], - dk_buffer[:, :seqlen_kv], - dv_buffer[:, :seqlen_kv], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - if step == 0: - backward(dout, q, k, v, out, softmax_lse, causal=True) - dq = dq_buffer.to(torch.float32) - dk = dk_buffer.to(torch.float32) - dv = dv_buffer.to(torch.float32) - else: - if step <= kv_comm.rank: - k0 = k[:, :block_seq_len] - v0 = v[:, :block_seq_len] - backward(dout, q, k0, v0, out, softmax_lse, causal=False) - dq += dq_buffer - else: - backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) - # always use the first half in dq_buffer. - dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] - - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - - if step <= kv_comm.rank: - dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] - dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] - else: - dk += dk_buffer - dv += dv_buffer - if dist.get_rank() == 0: - torch.save(torch.stack((dk, dv)), f"step_{step}.pt") - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class ZigZagRingFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = zigzag_ring_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = zigzag_ring_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def zigzag_ring_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py b/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py deleted file mode 100644 index 5d4a8dd2daf0..000000000000 --- a/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py +++ /dev/null @@ -1,441 +0,0 @@ -import torch -from flash_attn.flash_attn_interface import _flash_attn_varlen_backward, _flash_attn_varlen_forward - -from .utils import RingComm, update_out_and_lse - -try: - from .triton_utils import flatten_varlen_lse, unflatten_varlen_lse -except: - from .utils import flatten_varlen_lse, unflatten_varlen_lse - - -def get_half_index(cu_seqlens, *, front: bool): - if len(cu_seqlens) == 2: - if front: - return slice(None, cu_seqlens[-1] // 2) - else: - return slice(cu_seqlens[-1] // 2, None) - - index = torch.zeros((cu_seqlens[-1],), dtype=bool) - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - if front: - end = (start + end) // 2 - else: - start = (start + end) // 2 - index[start:end] = True - return index - - -@torch.jit.script -def get_half_lse(lse, cu_seqlens, *, front: bool): - new_lse = torch.empty( - (lse.shape[0], lse.shape[1], lse.shape[2] // 2), - dtype=lse.dtype, - device=lse.device, - ) - for i in range(len(cu_seqlens) - 1): - seqlen = (cu_seqlens[i + 1] - cu_seqlens[i]).item() - if front: - start, end = 0, seqlen // 2 - else: - start, end = seqlen // 2, seqlen - new_lse[i, :, : seqlen // 2] = lse[i, :, start:end] - return new_lse - - -def zigzag_ring_flash_attn_varlen_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - comm = RingComm(process_group) - - block_seq_len = q.shape[0] // 2 - q1 = q[half_index1] - - out = None - lse = None - next_k, next_v = None, None - half_cu_seqlens = cu_seqlens // 2 - half_max_seqlen = max_seqlen // 2 - - def forward(q, k, v, causal): - seqlen_q = q.shape[0] - seqlen_kv = k.shape[0] - cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens - max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen - cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens - max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen - block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( - q, - k, - v, - # the first half and the second half are the same - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=True and dropout_p > 0, - ) - return block_out, block_lse - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: torch.Tensor = comm.send_recv(k) - next_v: torch.Tensor = comm.send_recv(v) - comm.commit() - - if step == 0: - block_out, block_lse = forward(q, k, v, causal=True) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - elif step <= comm.rank: - k0 = k[half_index0] - v0 = v[half_index0] - block_out, block_lse = forward(q, k0, v0, causal=False) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=cu_seqlens, - ) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - else: - block_out, block_lse = forward(q1, k, v, causal=False) - block_lse = flatten_varlen_lse( - block_lse, - cu_seqlens=half_cu_seqlens, - ) - out[half_index1], lse[half_index1] = update_out_and_lse( - out[half_index1], lse[half_index1], block_out, block_lse - ) - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - v = next_v - - out = out.to(q.dtype) - lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) - return out, lse - - -def zigzag_ring_flash_attn_varlen_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - assert causal == True, "zigzag ring is meaningless for causal=False" - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - next_k, next_v = None, None - dk_comm_buffer, dv_comm_buffer = None, None - - dout1 = dout[half_index1] - q1 = q[half_index1] - out1 = out[half_index1] - softmax_lse1 = get_half_lse(softmax_lse, cu_seqlens, front=False) - block_seq_len = q.shape[0] // 2 - - half_cu_seqlens = cu_seqlens // 2 - half_max_seqlen = max_seqlen // 2 - - # repeatly allocating buffer may be slow... - dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - def backward(dout, q, k, v, out, softmax_lse, causal): - seqlen_q = q.shape[0] - seqlen_kv = k.shape[0] - cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens - max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen - cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens - max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq_buffer[:seqlen_q], - dk_buffer[:seqlen_kv], - dv_buffer[:seqlen_kv], - # the first half and the second half are the same - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, - ) - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k = kv_comm.send_recv(k) - next_v = kv_comm.send_recv(v) - kv_comm.commit() - - if step == 0: - backward(dout, q, k, v, out, softmax_lse, causal=True) - dq = dq_buffer.to(torch.float32) - dk = dk_buffer.to(torch.float32) - dv = dv_buffer.to(torch.float32) - else: - if step <= kv_comm.rank: - k0 = k[half_index0] - v0 = v[half_index0] - backward(dout, q, k0, v0, out, softmax_lse, causal=False) - dq += dq_buffer - else: - backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) - dq[half_index1] += dq_buffer[:block_seq_len] - - d_kv_comm.wait() - dk_comm_buffer, dv_comm_buffer = dk, dv - dk, dv = next_dk, next_dv - - if step <= kv_comm.rank: - dk[half_index0] += dk_buffer[:block_seq_len] - dv[half_index0] += dv_buffer[:block_seq_len] - else: - dk += dk_buffer - dv += dv_buffer - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k = next_k - v = next_v - - next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) - next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) - d_kv_comm.commit() - - d_kv_comm.wait() - - return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class ZigZagRingFlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - half_index0 = get_half_index(cu_seqlens, front=True) - half_index1 = get_half_index(cu_seqlens, front=False) - out, softmax_lse = zigzag_ring_flash_attn_varlen_forward( - group, - q, - k, - v, - cu_seqlens, - max_seqlen, - half_index0, - half_index1, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - is_half_index_tensor = isinstance(half_index0, torch.Tensor) - ctx.is_half_index_tensor = is_half_index_tensor - if is_half_index_tensor: - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) - else: - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) - ctx.half_index0 = half_index0 - ctx.half_index1 = half_index1 - ctx.max_seqlen = max_seqlen - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - if ctx.is_half_index_tensor: - (q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) = ctx.saved_tensors - else: - q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors - half_index0 = ctx.half_index0 - half_index1 = ctx.half_index1 - dq, dk, dv = zigzag_ring_flash_attn_varlen_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens, - ctx.max_seqlen, - half_index0, - half_index1, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None - - -def zigzag_ring_flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - q, - kv[:, 0], - kv[:, 1], - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def zigzag_ring_flash_attn_varlen_func( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return ZigZagRingFlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) diff --git a/ring-flash-attention/setup.py b/ring-flash-attention/setup.py deleted file mode 100644 index 58413e1b54f3..000000000000 --- a/ring-flash-attention/setup.py +++ /dev/null @@ -1,9 +0,0 @@ -from setuptools import find_packages, setup - -setup( - name="ring_flash_attn", - version="0.1", - author="zhuzilin", - url="https://github.com/zhuzilin/ring-flash-attention", - packages=find_packages(), -) diff --git a/ring-flash-attention/test/test_ring_flash_attn_func.py b/ring-flash-attention/test/test_ring_flash_attn_func.py deleted file mode 100644 index 50edd03bef4e..000000000000 --- a/ring-flash-attention/test/test_ring_flash_attn_func.py +++ /dev/null @@ -1,124 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import ring_flash_attn_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3816 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert seqlen % world_size == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone() - local_qkv.requires_grad = True - local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = out.chunk(world_size, dim=1)[rank] - local_lse = lse.chunk(world_size, dim=-1)[rank] - - fn = ring_flash_attn_qkvpacked_func - - ring_out, ring_lse, _ = fn( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - log("out", out, rank0_only=True) - log("lse", lse, rank0_only=True) - log("out diff", local_out - ring_out) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = dqkv.chunk(world_size, dim=1)[rank] - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, :, 0, :]) - log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) - - log("local_dk", local_dqkv[:, :, 1, :]) - log("dk diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) - - log("local_dv", local_dqkv[:, :, 2, :]) - log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py deleted file mode 100644 index 51bb1ec5d67d..000000000000 --- a/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py +++ /dev/null @@ -1,157 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, cu_seqlens, rank, world_size): - local_values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - local_value = value[start:end].chunk(world_size, dim=0)[rank].detach().clone() - local_values.append(local_value) - return torch.cat(local_values, dim=0).contiguous() - - -def extract_lse(lse, cu_seqlens): - values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - value = lse[i, :, : end - start] - values.append(value) - return values - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - cu_seqlens = [0, 120, 1248, 4232] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - total_length = cu_seqlens[-1] - num_seq = len(cu_seqlens) - 1 - - assert torch.all(cu_seqlens_tensor % world_size == 0) - assert d % 8 == 0 - - qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_cu_seqlens_tensor = cu_seqlens_tensor // world_size - local_max_seqlen = max_seqlen // world_size - - local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) - local_qkv.requires_grad = True - local_dout = extract_local(dout, cu_seqlens, rank, world_size) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens_tensor, - max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, cu_seqlens, rank, world_size) - lse_list = extract_lse(lse, cu_seqlens) - - ring_out, ring_lse, _ = ring_flash_attn_varlen_qkvpacked_func( - local_qkv, - local_cu_seqlens_tensor, - local_max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) - - log("out", out, rank0_only=True) - log("out diff", local_out - ring_out) - - for lse, ring_lse in zip(lse_list, ring_lse_list): - local_lse = lse.chunk(world_size, dim=-1)[rank] - log("lse", lse, rank0_only=True) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, 0]) - log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0]) - - log("local_dk", local_dqkv[:, 1]) - log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) - - log("local_dv", local_dqkv[:, 2]) - log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/ring-flash-attention/test/test_stripe_flash_attn_func.py b/ring-flash-attention/test/test_stripe_flash_attn_func.py deleted file mode 100644 index dc9f5248d69d..000000000000 --- a/ring-flash-attention/test/test_stripe_flash_attn_func.py +++ /dev/null @@ -1,130 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import stripe_flash_attn_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, rank, world_size, dim=1): - value = torch.stack(value.split(world_size, dim=dim), dim=dim).transpose(dim, dim + 1) - slicer = [rank if i == dim else slice(None) for i in range(len(value.shape))] - return value[slicer].contiguous() - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3824 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert causal - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = extract_local(qkv, rank, world_size).detach().clone() - local_qkv.requires_grad = True - local_dout = extract_local(dout, rank, world_size).detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, rank, world_size) - local_lse = extract_local(lse, rank, world_size, dim=2) - - ring_out, ring_lse, _ = stripe_flash_attn_qkvpacked_func( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - log("out", out, rank0_only=True) - log("lse", lse, rank0_only=True) - log("out diff", local_out - ring_out) - log("lse diff", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - - local_dqkv = extract_local(dqkv, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, :, 0, :]) - log("dq diff", local_dqkv[:, :, 0, :] - ring_dqkv[:, :, 0, :]) - - log("local_dk", local_dqkv[:, :, 1, :]) - log("dk0 diff", local_dqkv[:, :, 1, :] - ring_dqkv[:, :, 1, :]) - - log("local_dv", local_dqkv[:, :, 2, :]) - log("dv diff", local_dqkv[:, :, 2, :] - ring_dqkv[:, :, 2, :]) diff --git a/ring-flash-attention/test/test_triton_kernels.py b/ring-flash-attention/test/test_triton_kernels.py deleted file mode 100644 index aa1c1fdcd338..000000000000 --- a/ring-flash-attention/test/test_triton_kernels.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -from ring_flash_attn.triton_utils import flatten_varlen_lse as triton_flatten_varlen_lse -from ring_flash_attn.triton_utils import unflatten_varlen_lse as triton_unflatten_varlen_lse -from ring_flash_attn.utils import flatten_varlen_lse, unflatten_varlen_lse - -if __name__ == "__main__": - device = torch.device("cuda:0") - - cu_seqlens = [0, 15, 156, 529] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - batch_size = len(cu_seqlens) - 1 - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - n_head = 5 - - lse = torch.randn((batch_size, n_head, max_seqlen), dtype=torch.float32, device=device) - flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor) - triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor) - assert torch.all(flatten_lse == triton_flatten_lse) - - flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) - triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) - - unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen) - triton_unflatten_lse = triton_unflatten_varlen_lse(triton_flatten_lse, cu_seqlens_tensor, max_seqlen) - - for i in range(batch_size): - seqlen = cu_seqlens[i + 1] - cu_seqlens[i] - assert torch.all( - unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen] - ), f"{unflatten_lse[i, :seqlen]} vs {triton_unflatten_lse[i, :seqlen]}" diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py deleted file mode 100644 index 5f84bc58cf10..000000000000 --- a/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import random - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from flash_attn import flash_attn_qkvpacked_func -from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func - -from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, rank, world_size, dim=1): - value_chunks = value.chunk(2 * world_size, dim=dim) - local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim) - return local_value.contiguous() - - -def run_test(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" # or the IP of the master node - os.environ["MASTER_PORT"] = "8125" # make sure this port is free - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - set_seed(rank) - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - seqlen = 3824 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - assert causal - assert seqlen % (2 * world_size) == 0 - assert d % 8 == 0 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_qkv = extract_local(qkv, rank, world_size).detach().clone() - local_qkv.requires_grad = True - extract_local(dout, rank, world_size).detach().clone() - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_qkvpacked_func( - qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, rank, world_size) - # local_lse = extract_local(lse, rank, world_size, dim=2) - q, k, v = local_qkv.chunk(3, dim=2) - q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] - q.requires_grad = k.requires_grad = v.requires_grad = True - sp_stream = torch.cuda.Stream() - sp_group = dist.new_group() - colo_out = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL) - - ring_out, ring_lse, _ = zigzag_ring_flash_attn_qkvpacked_func( - local_qkv, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - log("colo_out", colo_out, rank0_only=True) - log("ring_out", ring_out, rank0_only=True) - # log("lse", lse, rank0_only=True) - log("colo_out - ring_out", colo_out - ring_out) - # log("lse diff", local_lse - ring_lse) - log("ring_out - local_out", ring_out - local_out) - log("colo_out - local_out", colo_out - local_out) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - colo_out.sum().backward() - qkv.grad - # q, k, v = [x.transpose(1, 2) for x in (q, k, v)] - colo_dq, colo_dk, colo_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] - - ring_out.sum().backward() - ring_dqkv = local_qkv.grad - out.sum().backward() - dqkv = extract_local(qkv.grad, rank, world_size) - - # log("colo_dq", colo_dq) - log("dq diff", colo_dq - ring_dqkv[:, :, 0, :]) - - # log("colo_dk", colo_dk) - log("dk diff", colo_dk - ring_dqkv[:, :, 1, :]) - - # log("colo_dv", colo_dv) - log("dv diff", colo_dv - ring_dqkv[:, :, 2, :]) - log("colo_dv - local_dv", colo_dv - dqkv[:, :, 2, :]) - - -if __name__ == "__main__": - world_size = 4 - mp.spawn(run_test, args=(world_size,), nprocs=world_size, join=True) diff --git a/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py b/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py deleted file mode 100644 index 7f6eced6e57b..000000000000 --- a/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py +++ /dev/null @@ -1,163 +0,0 @@ -import random - -import torch -import torch.distributed as dist -from flash_attn import flash_attn_varlen_qkvpacked_func -from ring_flash_attn import zigzag_ring_flash_attn_varlen_qkvpacked_func - - -def set_seed(rank, seed=42): - seed = rank + seed - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def log(msg, a, rank0_only=False): - world_size = dist.get_world_size() - rank = dist.get_rank() - if rank0_only: - if rank == 0: - print( - f"{msg}: " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - return - - for i in range(world_size): - if i == rank: - if rank == 0: - print(f"{msg}:") - print( - f"[{rank}] " f"max {a.abs().max().item()}, " f"mean {a.abs().mean().item()}", - flush=True, - ) - dist.barrier() - - -def extract_local(value, cu_seqlens, rank, world_size): - local_values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - local_value = value[start:end].chunk(2 * world_size, dim=0) - local_values.extend( - [ - local_value[rank].detach().clone(), - local_value[2 * world_size - 1 - rank].detach().clone(), - ] - ) - return torch.cat(local_values, dim=0).contiguous() - - -def extract_lse(lse, cu_seqlens): - values = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - value = lse[i, :, : end - start] - values.append(value) - return values - - -if __name__ == "__main__": - dist.init_process_group("nccl") - rank = dist.get_rank() - set_seed(rank) - world_size = dist.get_world_size() - dtype = torch.bfloat16 - device = torch.device(f"cuda:{rank}") - - batch_size = 1 - nheads = 5 - d = 128 - dropout_p = 0 - causal = True - deterministic = False - - cu_seqlens = [0, 128, 1248, 4240] - cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() - total_length = cu_seqlens[-1] - num_seq = len(cu_seqlens) - 1 - - assert torch.all(cu_seqlens_tensor % (2 * world_size) == 0) - assert d % 8 == 0 - - qkv = torch.randn(total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - dist.broadcast(qkv, src=0) - - dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype) - dist.broadcast(dout, src=0) - - local_cu_seqlens_tensor = cu_seqlens_tensor // world_size - local_max_seqlen = max_seqlen // world_size - - local_qkv = extract_local(qkv, cu_seqlens, rank, world_size) - local_qkv.requires_grad = True - local_dout = extract_local(dout, cu_seqlens, rank, world_size) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# forward:") - print("#" * 30) - - out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens_tensor, - max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - local_out = extract_local(out, cu_seqlens, rank, world_size) - lse_list = extract_lse(lse, cu_seqlens) - - ring_out, ring_lse, _ = zigzag_ring_flash_attn_varlen_qkvpacked_func( - local_qkv, - local_cu_seqlens_tensor, - local_max_seqlen, - dropout_p=dropout_p, - causal=causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=True, - ) - - ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist()) - - log("out", out, rank0_only=True) - log("out diff", local_out - ring_out) - - for i, (lse, ring_lse) in enumerate(zip(lse_list, ring_lse_list)): - local_lse = lse.chunk(2 * world_size, dim=-1) - local_lse = torch.cat([local_lse[rank], local_lse[2 * world_size - 1 - rank]], dim=-1) - log(f"lse {i}", lse, rank0_only=True) - log(f"lse diff {i}", local_lse - ring_lse) - - dist.barrier() - if rank == 0: - print("#" * 30) - print("# backward:") - print("#" * 30) - - out.backward(dout) - dqkv = qkv.grad - local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size) - - ring_out.backward(local_dout) - ring_dqkv = local_qkv.grad - - log("local_dq", local_dqkv[:, 0]) - log("dq diff", local_dqkv - ring_dqkv) - - log("local_dk", local_dqkv[:, 1]) - log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) - - log("local_dv", local_dqkv[:, 2]) - log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 805808887db9..2e2ef393ce52 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -56,7 +56,7 @@ def test_ring_attn(seq_len, batch_size, nheads, d, dtype): # Some outliers may seem large, but our errors are still much lower than # than Megatron-LM's context parallel - # https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215 + # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) # and the original zigzag implementation: https://github.com/zhuzilin/ring-flash-attention/tree/main atol = rtol = 7e-3 From 52331c9c7d40efdf6fa296c01ae008e971e3d5c7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 22 Jul 2024 06:29:02 +0000 Subject: [PATCH 36/71] add sp_mode to benchmark; fix varlen interface --- colossalai/shardformer/modeling/llama.py | 85 ------------------- .../test_layer/test_ring_attn.py | 6 +- 2 files changed, 3 insertions(+), 88 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 801e1c91bef4..219933c705e9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -137,11 +137,7 @@ def llama_model_forward( elif shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) -<<<<<<< HEAD attn_kwargs = ColoAttention.prepare_attn_kwargs( -======= - attn_mask = ColoAttention.prepare_attn_kwargs( ->>>>>>> precision tests passed mask_shape, hidden_states.dtype, hidden_states.device, @@ -150,11 +146,7 @@ def llama_model_forward( invert=(sp_mode != "ring_attn"), ) else: -<<<<<<< HEAD attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position) -======= - attn_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) ->>>>>>> precision tests passed # Support SP + PP # TODO: support padded casual cu_seqlens across stages @@ -162,24 +154,12 @@ def llama_model_forward( # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." -<<<<<<< HEAD if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, hidden_states, position_ids ) else: hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) -======= - if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: - attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info( - attn_mask["attention_mask"].squeeze(1).any(dim=-1) - ) # [B, 1, Sq, Skv] -> [B, Sq] - else: - attn_mask["cu_seqlens"] = attn_mask["max_seqlen"] = attn_mask["indices"] = None - batch = [hidden_states, position_ids] - # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) - hidden_states, position_ids = zigzag_split_batch(batch, sp_group) ->>>>>>> precision tests passed elif is_share_sp_tp(sp_mode): hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) @@ -220,11 +200,7 @@ def llama_model_forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, -<<<<<<< HEAD attn_kwargs, -======= - attn_mask, ->>>>>>> precision tests passed position_ids, past_key_values, output_attentions, @@ -234,11 +210,7 @@ def llama_model_forward( else: layer_outputs = decoder_layer( hidden_states, -<<<<<<< HEAD attention_mask=attn_kwargs, -======= - attention_mask=attn_mask, ->>>>>>> precision tests passed position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -341,7 +313,6 @@ def llama_for_causal_lm_forward( logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False -<<<<<<< HEAD if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: # Split labels in a zigzag fashion too sp_group = shard_config.sequence_parallel_process_group @@ -350,11 +321,6 @@ def llama_for_causal_lm_forward( else: # [B, max_seqlen // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) -======= - if stage_manager.is_first_stage(): - if shard_config.sequence_parallelism_mode == "ring_attn": - labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) ->>>>>>> precision tests passed # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( @@ -596,22 +562,15 @@ def forward( # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - assert not self.q_proj.weight.isnan().any(), self.q_proj.weight - assert not query_states.isnan().any(), query_states if sp_mode == "ring_attn": attn_output = RingAttention.attention( query_states, key_states, value_states, sp_group, -<<<<<<< HEAD **attention_mask, inner_ring_size=shard_config.inner_ring_size, -======= - shard_config.sp_stream, - attention_mask["attention_mask_type"], ->>>>>>> precision tests passed ) elif shard_config.enable_flash_attention: @@ -723,13 +682,8 @@ def forward( position_ids = cache_position.unsqueeze(0) if shard_config.enable_flash_attention: -<<<<<<< HEAD mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len) attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( -======= - mask_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) - attn_mask: dict = ColoAttention.prepare_attn_kwargs( ->>>>>>> precision tests passed mask_shape, inputs_embeds.dtype, inputs_embeds.device, @@ -739,16 +693,11 @@ def forward( ) else: -<<<<<<< HEAD attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) -======= - attn_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) ->>>>>>> precision tests passed # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." -<<<<<<< HEAD if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, inputs_embeds, position_ids @@ -758,24 +707,6 @@ def forward( attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors elif is_share_sp_tp(sp_mode): -======= - if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: - attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info( - attn_mask["attention_mask"].squeeze(1).any(dim=-1) - ) # [B, 1, Sq, Skv] -> [B, Sq] - - else: - attn_mask["cu_seqlens"] = attn_mask["max_seqlen"] = attn_mask["indices"] = None - batch = [inputs_embeds, position_ids] - # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) - inputs_embeds, position_ids = zigzag_split_batch(batch, sp_group) - -<<<<<<< HEAD - elif sp_mode in ["ring", "split_gather"]: ->>>>>>> precision tests passed -======= - elif is_share_sp_tp(sp_mode): ->>>>>>> precision tests passed inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) @@ -793,11 +724,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, -<<<<<<< HEAD attn_kwargs, -======= - attn_mask, ->>>>>>> precision tests passed position_ids, past_key_values, output_attentions, @@ -808,11 +735,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, -<<<<<<< HEAD attention_mask=attn_kwargs, -======= - attention_mask=attn_mask, ->>>>>>> precision tests passed position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -904,8 +827,6 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if shard_config.sequence_parallelism_mode == "ring_attn": -<<<<<<< HEAD -<<<<<<< HEAD labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: @@ -916,12 +837,6 @@ def forward( else: # [B, max_seq_len // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) -======= - labels = zigzag_split_batch([labels], shard_config.sequence_parallel_process_group)[0] ->>>>>>> precision tests passed -======= - labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) ->>>>>>> precision tests passed # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 2e2ef393ce52..16ffdd4d0c8b 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -54,10 +54,10 @@ def test_ring_attn(seq_len, batch_size, nheads, d, dtype): sp_group = dist.group.WORLD sp_stream = torch.cuda.Stream() - # Some outliers may seem large, but our errors are still much lower than - # than Megatron-LM's context parallel + # Some outliers may seem large, but our errors are still lower than + # than Megatron-LM's context parallel's # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) - # and the original zigzag implementation: https://github.com/zhuzilin/ring-flash-attention/tree/main + # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main) atol = rtol = 7e-3 # Setup inputs From c70d03a3f44242d42c80e2b53880aeb3c1193af7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 22 Jul 2024 07:42:50 +0000 Subject: [PATCH 37/71] update softmax_lse shape by new interface --- colossalai/shardformer/layer/_operation.py | 12 ++-- .../test_layer/test_ring_attn.py | 66 ------------------- 2 files changed, 4 insertions(+), 74 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index c6f61d3bb99f..e031fecc15e0 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -812,7 +812,11 @@ def backward(ctx, *grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim + if torch.distributed.get_rank() == 0: + print(f"shape before A2A: {grad_output[0].shape}") return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) + if torch.distributed.get_rank() == 0: + print(f"shape after A2A: {return_grad.shape}") return (return_grad, None, None, None) @@ -999,19 +1003,11 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) -<<<<<<< HEAD def gather_sp_output(hidden_states, sp_group, sp_mode, sp_dim=1): -======= -def gather_sp_output(hidden_states, sp_group, sp_mode): ->>>>>>> fwd bwd logic complete """ Gather the output of the last layer for cross entropy computation """ # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) -<<<<<<< HEAD hidden_states = gather_forward_split_backward(hidden_states, sp_dim, sp_group, grad_scale=scale) -======= - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale) ->>>>>>> fwd bwd logic complete return hidden_states diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 16ffdd4d0c8b..5ca618bc8535 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -1,6 +1,5 @@ import torch import torch.distributed as dist -<<<<<<< HEAD import torch.nn.functional as F from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func from torch.testing import assert_close @@ -32,44 +31,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # Setup inputs qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) local_qkv = split_batch_zigzag(qkv, sp_group) -======= -from flash_attn import flash_attn_qkvpacked_func -from torch.testing import assert_close - -import colossalai -from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention -from colossalai.shardformer.layer.utils import zigzag_split_batch -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn - - -@parameterize("seq_len", [4096]) -@parameterize("batch_size", [1]) -@parameterize("nheads", [5]) -@parameterize("d", [128]) -@parameterize("dtype", [torch.bfloat16]) -def test_ring_attn(seq_len, batch_size, nheads, d, dtype): - torch.cuda.manual_seed(2) - rank = dist.get_rank() - device = torch.device(f"cuda:{rank}") - sp_group = dist.group.WORLD - sp_stream = torch.cuda.Stream() - - # Some outliers may seem large, but our errors are still lower than - # than Megatron-LM's context parallel's - # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) - # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main) - atol = rtol = 7e-3 - - # Setup inputs - qkv = torch.randn(batch_size, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - local_qkv = zigzag_split_batch(qkv, sp_group) ->>>>>>> precision tests passed q, k, v = local_qkv.unbind(dim=-3) q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D) q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU -<<<<<<< HEAD ring_out, ring_lse = RingAttention.attention( q, k, @@ -81,14 +47,10 @@ def test_ring_attn(seq_len, batch_size, nheads, d, dtype): # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) -======= - ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) ->>>>>>> precision tests passed out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True ) -<<<<<<< HEAD # Checkout out and softmax denominator local_out = split_batch_zigzag(out, sp_group) local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) @@ -97,18 +59,10 @@ def test_ring_attn(seq_len, batch_size, nheads, d, dtype): assert_close(ring_out, local_out, atol=atol, rtol=rtol) # Check grads -======= - local_out = zigzag_split_batch(out, sp_group) - local_lse = zigzag_split_batch(lse, sp_group, seq_dim=-1) - assert_close(ring_out, local_out, atol=atol, rtol=rtol) - assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) - ->>>>>>> precision tests passed ring_out.sum().backward() out.sum().backward() ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] dqkv = qkv.grad -<<<<<<< HEAD local_dqkv = split_batch_zigzag(dqkv, sp_group) assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) @@ -231,23 +185,3 @@ def test_double_ring(world_size): if __name__ == "__main__": test_ring_attn() test_double_ring() -======= - local_dqkv = zigzag_split_batch(dqkv, sp_group) - assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) - assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) - assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) - - -def launch(rank, world_size, port): - colossalai.launch(rank, world_size, "localhost", port) - test_ring_attn() - - -@rerun_if_address_is_in_use() -def run_ring_attn(): - spawn(launch, nprocs=8) - - -if __name__ == "__main__": - run_ring_attn() ->>>>>>> precision tests passed From 9c8334371f68ab6a7ab508a44bbe52a7be530994 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 24 Jul 2024 13:54:54 +0000 Subject: [PATCH 38/71] add varlen tests --- .../booster/plugin/hybrid_parallel_plugin.py | 4 ++ colossalai/shardformer/layer/attn.py | 28 ++------ colossalai/shardformer/layer/utils.py | 66 +++++++++++++++++++ colossalai/shardformer/modeling/llama.py | 35 ++++++++++ .../test_layer/test_ring_attn.py | 66 +++++++++++++++++++ .../test_model/test_shard_llama.py | 2 +- 6 files changed, 179 insertions(+), 22 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index d233ccc2ae15..421a2bc9712b 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1244,7 +1244,11 @@ def configure( zero_stage = 0 if not isinstance(model, ModelWrapper): +<<<<<<< HEAD # Shouldn't use pp (frequent grad accumulation) with torch ddp +======= + # Can't use pp (frequent grad accumulation) with torch ddp +>>>>>>> add varlen tests use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( self.dp_size == 1 and self.pp_size == 1 ) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index bc352e9f2790..b72d09575f47 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -18,7 +18,7 @@ from .utils import RingComm, get_half_index, split_varlen_zigzag -from .utils import RingComm +from .utils import RingComm, split_varlen_zigzag __all__ = [ "AttnMaskType", @@ -379,15 +379,11 @@ def _rescale_out_lse_kernel( stride_out_0, stride_out_1, stride_out_2, - stride_out_3, stride_out_per_step_0, stride_out_per_step_1, stride_out_per_step_2, - stride_out_per_step_3, stride_lse_0, stride_lse_1, - stride_lse_2, - stride_lse_3, BLOCK_M: tl.constexpr, ): batch_id = tl.program_id(0) @@ -395,15 +391,10 @@ def _rescale_out_lse_kernel( h_id = tl.program_id(2) d_id = tl.arange(0, D) - out_idx = batch_id * stride_out_0 + sq_id * stride_out_1 + h_id * stride_out_2 + d_id * stride_out_3 - out_per_step_idx = ( - batch_id * stride_out_per_step_0 - + sq_id * stride_out_per_step_1 - + h_id * stride_out_per_step_2 - + d_id * stride_out_per_step_3 - ) - lse_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 - lse_step_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + sq_id * stride_lse_2 + tl.zeros(D) * stride_lse_3 + out_idx = batch_id * stride_out_0 + sq_id * stride_out_1 + h_id * stride_out_2 + d_id + out_per_step_idx = batch_id * stride_out_per_step_0 + sq_id * stride_out_per_step_1 + h_id * stride_out_per_step_2 + lse_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 + lse_step_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 # Load inputs out = tl.load(out_ptr + out_idx) @@ -420,7 +411,7 @@ def _rescale_out_lse_kernel( def _rescale_out_lse_triton(out, block_out, lse, block_lse): - B, Sq, H, D = out.shape + T, H, D = out.shape assert out.is_contiguous() and block_out.is_contiguous() and lse.is_contiguous() and block_lse.is_contiguous() @@ -431,22 +422,17 @@ def _rescale_out_lse_triton(out, block_out, lse, block_lse): block_out, lse, block_lse, - B, - Sq, + T, H, D, out.stride(0), out.stride(1), out.stride(2), - out.stride(3), block_out.stride(0), block_out.stride(1), block_out.stride(2), - block_out.stride(3), lse.stride(0), lse.stride(1), - lse.stride(2), - lse.stride(3), ) diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index ce423d1dd00c..8638f84a8c67 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -303,6 +303,7 @@ def create_randomizer_with_offset( return Randomizer(seed=base_seed) +<<<<<<< HEAD <<<<<<< HEAD <<<<<<< HEAD def split_batch_zigzag( @@ -313,6 +314,9 @@ def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen >>>>>>> precision tests passed ======= def zigzag_split_batch( +======= +def split_batch_zigzag( +>>>>>>> add varlen tests batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False ): >>>>>>> precision tests passed @@ -483,6 +487,65 @@ def split_varlen_zigzag( return batch +def split_varlen_zigzag( + batch: Union[List[torch.Tensor], torch.Tensor], + cu_seqlens: torch.Tensor, + sp_group: ProcessGroup, + is_2d: bool = False, + max_seq_len: int = 0, +) -> Union[List[torch.Tensor], torch.Tensor]: + """Split each sequence in a batch of packed sequences/indices in a zigzag fashion. + + Args: + batch (List[torch.Tensor]): Packed sequences of shape (B * Sq), or (B, Sq) if is_2d + cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) + sp_group (ProcessGroup): The process group for sequence parallelism. + is_2d (bool): Whether the input is 2D or 1D. + max_seq_len (int): The maximum sequence length in the batch before splitting. + Returns: + batch (List[torch.Tensor]): Unpacked sequences of shape (B * Sq // sp_size) + """ + sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) + + if isinstance(batch, torch.Tensor): + batch = [batch] + for i, packed_seq in enumerate(batch): + if is_2d: + assert max_seq_len % sp_size == 0 + shape = (packed_seq.shape[0], max_seq_len // sp_size, *packed_seq.shape[2:]) + local_seq = torch.zeros(shape, dtype=packed_seq.dtype, device=packed_seq.device) + else: + local_seq = [] + + for j in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[j], cu_seqlens[j + 1] + seqlen = end - start + assert ( + seqlen % (2 * sp_size) == 0 + ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting" + + if is_2d: + seq = packed_seq[j][:seqlen].chunk(2 * sp_size, dim=0) + local_seq[j][: seqlen // sp_size] = torch.cat([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]], dim=0) + else: + seq = packed_seq[start:end].chunk(2 * sp_size, dim=0) + seq.extend( + [ + seq[sp_rank], + seq[2 * sp_size - 1 - sp_rank], + ] + ) + if is_2d: + batch[i] = local_seq + else: + batch[i] = torch.cat(local_seq, dim=0).contiguous() + + if len(batch) == 1: + batch = batch[0] + return batch + + class RingComm: def __init__(self, process_group: dist.ProcessGroup): self._process_group = process_group @@ -526,6 +589,7 @@ def is_share_sp_tp(sp_mode: str): to correctly get logits at each positions. """ return sp_mode in ["ring", "split_gather"] +<<<<<<< HEAD <<<<<<< HEAD @@ -744,3 +808,5 @@ def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): new_lse[i, : end - start] = lse[start:end] return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() >>>>>>> precision tests passed +======= +>>>>>>> add varlen tests diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 219933c705e9..bf89efdff125 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -154,12 +154,23 @@ def llama_model_forward( # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." +<<<<<<< HEAD if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, hidden_states, position_ids ) else: hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) +======= + if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info( + attn_mask["attention_mask"].squeeze(1).any(dim=-1) + ) # [B, 1, Sq, Skv] -> [B, Sq] + + batch = [hidden_states, position_ids] + # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) + hidden_states, position_ids = split_batch_zigzag(batch, sp_group) +>>>>>>> add varlen tests elif is_share_sp_tp(sp_mode): hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) @@ -313,6 +324,7 @@ def llama_for_causal_lm_forward( logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False +<<<<<<< HEAD if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: # Split labels in a zigzag fashion too sp_group = shard_config.sequence_parallel_process_group @@ -321,6 +333,11 @@ def llama_for_causal_lm_forward( else: # [B, max_seqlen // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) +======= + if stage_manager.is_first_stage(): + if shard_config.sequence_parallelism_mode == "ring_attn": + labels = split_batch_zigzag(labels, shard_config.sequence_parallel_process_group) +>>>>>>> add varlen tests # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( @@ -565,12 +582,16 @@ def forward( if sp_mode == "ring_attn": attn_output = RingAttention.attention( +<<<<<<< HEAD query_states, key_states, value_states, sp_group, **attention_mask, inner_ring_size=shard_config.inner_ring_size, +======= + query_states, key_states, value_states, sp_group, shard_config.sp_stream, **attention_mask +>>>>>>> add varlen tests ) elif shard_config.enable_flash_attention: @@ -698,6 +719,7 @@ def forward( # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." +<<<<<<< HEAD if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, inputs_embeds, position_ids @@ -705,6 +727,15 @@ def forward( else: inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors +======= + if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + inputs_embeds, position_ids, attn_mask = RingAttention.prepare_varlen_batch( + inputs_embeds, attn_mask["attention_mask"], sp_group, batch_size, position_ids + ) + else: + inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) + attn_mask = attn_mask["attention_mask_type"] # drop redundant tensors +>>>>>>> add varlen tests elif is_share_sp_tp(sp_mode): inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) @@ -827,6 +858,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if shard_config.sequence_parallelism_mode == "ring_attn": +<<<<<<< HEAD labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: @@ -838,6 +870,9 @@ def forward( # [B, max_seq_len // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) +======= + labels = split_batch_zigzag(labels, shard_config.sequence_parallel_process_group) +>>>>>>> add varlen tests # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 5ca618bc8535..1569a50f0b22 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -5,20 +5,37 @@ from torch.testing import assert_close import colossalai +<<<<<<< HEAD from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag +======= +from colossalai.shardformer.layer import AttnMaskType, ColoAttention +from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention +from colossalai.shardformer.layer.utils import split_batch_zigzag +>>>>>>> add varlen tests from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @parameterize("seq_len", [4096]) +<<<<<<< HEAD @parameterize("bs", [2]) @parameterize("nheads", [5]) @parameterize("d", [128]) @parameterize("dtype", [torch.bfloat16, torch.float16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) +======= +@parameterize("bs", [1]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16]) +def check_ring_attn(seq_len, bs, nheads, d, dtype): + torch.cuda.manual_seed(2) + dist.get_rank() + dist.get_world_size() +>>>>>>> add varlen tests device = get_current_device() sp_group = dist.group.WORLD sp_size = dist.get_world_size() @@ -64,7 +81,10 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] dqkv = qkv.grad local_dqkv = split_batch_zigzag(dqkv, sp_group) +<<<<<<< HEAD +======= +>>>>>>> add varlen tests assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) @@ -158,10 +178,56 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): assert_close(dv, dv_ring, atol=atol, rtol=rtol) +<<<<<<< HEAD def launch_single_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) check_packed_seq() check_ring_attn() +======= +@parameterize("seq_len", [4096]) +@parameterize("bs", [2]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16]) +def check_packed_seq(seq_len, bs, nheads, d, dtype): + device = get_current_device() + sp_group = dist.group.WORLD + sp_stream = torch.cuda.Stream() + atol = rtol = 5e-3 + + # Prepare varlen attention mask + padding_mask = torch.ones((bs, seq_len), dtype=torch.int, device=device) + padding_mask[bs // 2 :, seq_len // 2 :] = 0 + padding_mask[: bs // 2, (seq_len // 4) * 3 :] = 0 + attn_mask = ColoAttention.prepare_attn_kwargs( + (bs, 1, seq_len, seq_len), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True + ) + input_embeds = torch.randn(bs, seq_len, nheads, d, device=device, dtype=dtype, requires_grad=True) + + # Forward + q, k, v = [input_embeds.clone().transpose(1, 2) for _ in range(3)] + colo_out = ColoAttention.attention(q, k, v, **attn_mask) + + input_embeds, _, attn_mask = RingAttention.prepare_varlen_batch(input_embeds, padding_mask, sp_group, bs) + q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)] + ring_out = RingAttention.attention(q_ring, k_ring, v_ring, sp_group, sp_stream, **attn_mask) + + # Check output + colo_out = split_batch_zigzag(colo_out, sp_group) + assert_close(colo_out, ring_out, atol=atol, rtol=rtol) + # Check grads + colo_out.backward() + ring_out.backward() + assert_close(q.grad, q_ring.grad, atol=atol, rtol=rtol) + assert_close(k.grad, k_ring.grad, atol=atol, rtol=rtol) + assert_close(v.grad, v_ring.grad, atol=atol, rtol=rtol) + + +def launch(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + # check_ring_attn() + check_packed_seq() +>>>>>>> add varlen tests def launch_double_ring(rank, world_size, port): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index d4f0968f20bc..0adaccbca68e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -175,7 +175,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, - "zero_stage": 1, + "zero_stage": 2, "precision": "bf16", "initial_scale": 1, }, From cc6472a19aa42cf7ca894cc99843083d56e310e9 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 26 Jul 2024 10:00:09 +0000 Subject: [PATCH 39/71] fix typo --- colossalai/shardformer/layer/attn.py | 19 +- colossalai/shardformer/layer/utils.py | 329 +----------------- colossalai/shardformer/modeling/llama.py | 35 -- .../test_layer/test_ring_attn.py | 104 ++++-- 4 files changed, 97 insertions(+), 390 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index b72d09575f47..67901ac5ac83 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -18,7 +18,7 @@ from .utils import RingComm, get_half_index, split_varlen_zigzag -from .utils import RingComm, split_varlen_zigzag +from .utils import RingComm, get_half_index, split_varlen_zigzag __all__ = [ "AttnMaskType", @@ -449,17 +449,19 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # min_scale = torch.min(lse, block_lse) # max_scale = torch.max(lse, block_lse) - # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + # lse.data = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - assert not (new_lse.isnan().any() or new_lse.isinf().any()), f"lse is nan: {new_lse}" - new_block_lse = torch.exp(block_lse - new_lse) + + new_block_lse = torch.exp(block_lse - lse) out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) lse.copy_(new_lse) - # block_out = block_out.float() - # assert not lse.isnan().any(), lse - # assert not out.isnan().any(), out + # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + # out.data = (out - F.sigmoid(block_lse - lse) * (out - block_out)) + # lse.data = (lse - F.logsigmoid(lse - block_lse)) + + assert not (lse.isnan().any() or lse.isinf().any()), f"lse is nan: {lse}" class RingAttention(torch.autograd.Function): @@ -726,6 +728,9 @@ def forward( if is_packed: t, h, d = q.shape + # half of each seq + half_idx_front = get_half_index(cu_seqlens, front=True) + half_idx_back = get_half_index(cu_seqlens, front=False) else: b, sq, h, d = q.shape t = b * sq diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 8638f84a8c67..b343998082e6 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -1,17 +1,5 @@ from contextlib import contextmanager -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD from typing import List, Optional, Union -======= -from typing import Dict, List ->>>>>>> add basic ring attn; debug cross entropy -======= -from typing import List ->>>>>>> precision tests passed -======= -from typing import List, Optional, Union ->>>>>>> precision tests passed import torch import torch.distributed as dist @@ -303,23 +291,9 @@ def create_randomizer_with_offset( return Randomizer(seed=base_seed) -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD def split_batch_zigzag( batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False ) -> Union[torch.Tensor, List[torch.Tensor]]: -======= -def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen: bool = False): ->>>>>>> precision tests passed -======= -def zigzag_split_batch( -======= -def split_batch_zigzag( ->>>>>>> add varlen tests - batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False -): ->>>>>>> precision tests passed """ Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask in the causal setting will result in the preceding ranks having much less workload. @@ -327,8 +301,6 @@ def split_batch_zigzag( For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |. Args: -<<<<<<< HEAD -<<<<<<< HEAD batch (List[torch.Tensor] or Tensor): The input tensor(s) to split. sp_group (ProcessGroup): The process group for sequence parallelism. seq_dim (int): The sequence dimension to split. @@ -350,33 +322,6 @@ def split_batch_zigzag( assert tensor.dim() == 2, "Label shape should be (B, Seqlen)" tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1) -======= - batch (List[torch.Tensor]): The input tensors to split. -======= - batch (List[torch.Tensor] or Tensor): The input tensor(s) to split. ->>>>>>> precision tests passed - sp_group (ProcessGroup): The process group for sequence parallelism. - seq_dim (int): The sequence dimension to split. - varlen (bool): If the input is padded (aka "packing" mode), such that - sequences in a batch have different lengths, and we need to unpad and - split each sequence evenly by sp_size. - """ - sp_size = dist.get_world_size(sp_group) - sp_rank = dist.get_rank(sp_group) - if isinstance(batch, torch.Tensor): - batch = [batch] - seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1 - - if sp_size > 1: - for idx, tensor in enumerate(batch): - assert ( - tensor.numel() // (sp_size * 2) > 1 - ), f"Bro, the seq length for tensor {idx} in batch is too short to split!" -<<<<<<< HEAD ->>>>>>> precision tests passed -======= - ->>>>>>> precision tests passed tensor = tensor.view( *tensor.shape[:seq_dim], 2 * sp_size, @@ -386,9 +331,7 @@ def split_batch_zigzag( indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device) tensor = tensor.index_select(seq_dim, indices).contiguous() # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...) -<<<<<<< HEAD batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]) -<<<<<<< HEAD if len(batch) == 1: return batch[0] @@ -434,10 +377,7 @@ def split_varlen_zigzag( assert max_seqlen % (sp_size * 2) == 0 # Recreate a padded tensor with the new max seqlen shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) - if is_label: - local_seq = torch.full(shape, -100, dtype=dtype, device=device) - else: - local_seq = torch.zeros(shape, dtype=dtype, device=device) + local_seq = torch.zeros(shape, dtype=dtype, device=device) else: total_seqlen = cu_seqlens[-1] assert ( @@ -473,126 +413,20 @@ def split_varlen_zigzag( batch[i] = local_seq.contiguous() else: batch[i] = torch.cat(local_seq, dim=0) -======= ->>>>>>> precision tests passed if len(batch) == 1: batch = batch[0] -======= - batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]).contiguous() - - if len(batch) == 1: - return batch[0] ->>>>>>> precision tests passed return batch -def split_varlen_zigzag( - batch: Union[List[torch.Tensor], torch.Tensor], - cu_seqlens: torch.Tensor, - sp_group: ProcessGroup, - is_2d: bool = False, - max_seq_len: int = 0, -) -> Union[List[torch.Tensor], torch.Tensor]: - """Split each sequence in a batch of packed sequences/indices in a zigzag fashion. - - Args: - batch (List[torch.Tensor]): Packed sequences of shape (B * Sq), or (B, Sq) if is_2d - cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) - sp_group (ProcessGroup): The process group for sequence parallelism. - is_2d (bool): Whether the input is 2D or 1D. - max_seq_len (int): The maximum sequence length in the batch before splitting. - Returns: - batch (List[torch.Tensor]): Unpacked sequences of shape (B * Sq // sp_size) - """ - sp_size = dist.get_world_size(sp_group) - sp_rank = dist.get_rank(sp_group) - - if isinstance(batch, torch.Tensor): - batch = [batch] - for i, packed_seq in enumerate(batch): - if is_2d: - assert max_seq_len % sp_size == 0 - shape = (packed_seq.shape[0], max_seq_len // sp_size, *packed_seq.shape[2:]) - local_seq = torch.zeros(shape, dtype=packed_seq.dtype, device=packed_seq.device) - else: - local_seq = [] - - for j in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[j], cu_seqlens[j + 1] - seqlen = end - start - assert ( - seqlen % (2 * sp_size) == 0 - ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting" - - if is_2d: - seq = packed_seq[j][:seqlen].chunk(2 * sp_size, dim=0) - local_seq[j][: seqlen // sp_size] = torch.cat([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]], dim=0) - else: - seq = packed_seq[start:end].chunk(2 * sp_size, dim=0) - seq.extend( - [ - seq[sp_rank], - seq[2 * sp_size - 1 - sp_rank], - ] - ) - if is_2d: - batch[i] = local_seq - else: - batch[i] = torch.cat(local_seq, dim=0).contiguous() - - if len(batch) == 1: - batch = batch[0] - return batch - - -class RingComm: - def __init__(self, process_group: dist.ProcessGroup): - self._process_group = process_group - self._ops = [] - self.rank = dist.get_rank(self._process_group) - self.world_size = dist.get_world_size(self._process_group) - self._reqs = [] - - self.send_rank = (self.rank + 1) % self.world_size - self.recv_rank = (self.rank - 1) % self.world_size - - if process_group is not None: - self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) - self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) - - def send_recv(self, send_tensor: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: - if recv_tensor is None: - res = torch.empty_like(send_tensor) - else: - res = recv_tensor - - # NOTE: looks like batch_isend_irecv doesn't deadlock even - # when we never swap send recv ops across ranks - send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group) - recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) - self._ops.append(send_op) - self._ops.append(recv_op) - self._reqs = dist.batch_isend_irecv(self._ops) - return res - - def wait(self): - for req in self._reqs: - req.wait() - self._reqs = [] - self._ops = [] - - def is_share_sp_tp(sp_mode: str): """sp_mode "ring" and "split_gather" use the TP group as SP group to split both the vocab and sequence, so we must gather the sequence to correctly get logits at each positions. """ return sp_mode in ["ring", "split_gather"] -<<<<<<< HEAD -<<<<<<< HEAD class RingComm: def __init__(self, process_group: dist.ProcessGroup): self._process_group = process_group @@ -650,163 +484,4 @@ def get_half_index(cu_seqlens, *, front: bool): else: start = (start + end) // 2 index[start:end] = True - return index -======= -# Copied from https://github.com/zhuzilin/ring-flash-attention/tree/main/ring_flash_attn -# Use Triton kernel if installed else use torch -try: - import triton - import triton.language as tl - - @triton.jit - def flatten_kernel( - # pointers to matrices - OUT, - LSE, - CU_SEQLENS, - # strides - stride_out_nheads, - stride_out_seqlen, - stride_lse_batch, - stride_lse_nheads, - stride_lse_seqlen, - # meta-parameters - BLOCK_M: tl.constexpr, - ): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads - OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - LSE = LSE + rm[:, None] * stride_lse_seqlen - x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) - - OUT = OUT + rm[:, None] * stride_out_seqlen - tl.store(OUT, x, mask=rm[:, None] < seqlen) - - def flatten_varlen_lse(lse, cu_seqlens): - """ - Arguments: - lse: (batch_size, nheads, max_seqlen) - cu_seqlens: (batch_size + 1,) - Return: - flatten_lse: (nheads, total_seqlen) - """ - total_seqlen = cu_seqlens[-1] - batch_size, nheads, max_seqlen = lse.shape - output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) - - grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) - BLOCK_M = 4 - - with torch.cuda.device(lse.device.index): - flatten_kernel[grid]( - output, - lse, - cu_seqlens, - # strides - output.stride(0), - output.stride(1), - lse.stride(0), - lse.stride(1), - lse.stride(2), - BLOCK_M, - ) - return output - - @triton.jit - def unflatten_kernel( - # pointers to matrices - OUT, - LSE, - CU_SEQLENS, - # strides - stride_out_batch, - stride_out_nheads, - stride_out_seqlen, - stride_lse_seqlen, - stride_lse_nheads, - # meta-parameters - BLOCK_M: tl.constexpr, - ): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - - LSE = LSE + rm[:, None] * stride_lse_seqlen - x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) - - OUT = OUT + rm[:, None] * stride_out_seqlen - tl.store(OUT, x, mask=rm[:, None] < seqlen) - - def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - """ - Arguments: - lse: (total_seqlen, nheads, 1) - cu_seqlens: (batch_size + 1,) - max_seqlen: int - Return: - unflatten_lse: (batch_size, nheads, max_seqlen) - """ - lse = lse.unsqueeze(dim=-1) - batch_size = len(cu_seqlens) - 1 - nheads = lse.shape[1] - output = torch.empty( - (batch_size, nheads, max_seqlen), - dtype=lse.dtype, - device=lse.device, - ) - - grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) - BLOCK_M = 4 - - with torch.cuda.device(lse.device.index): - unflatten_kernel[grid]( - output, - lse, - cu_seqlens, - # strides - output.stride(0), - output.stride(1), - output.stride(2), - lse.stride(0), - lse.stride(1), - BLOCK_M, - ) - return output - -except: - # Triton not installed, use torch instead - @torch.jit.script - def flatten_varlen_lse(lse, cu_seqlens): - new_lse = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse.append(lse[i, :, : end - start]) - return torch.cat(new_lse, dim=1) - - @torch.jit.script - def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - num_seq = len(cu_seqlens) - 1 - num_head = lse.shape[-2] - new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) - for i in range(num_seq): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse[i, : end - start] = lse[start:end] - return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() ->>>>>>> precision tests passed -======= ->>>>>>> add varlen tests + return index \ No newline at end of file diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index bf89efdff125..219933c705e9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -154,23 +154,12 @@ def llama_model_forward( # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." -<<<<<<< HEAD if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, hidden_states, position_ids ) else: hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) -======= - if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: - attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info( - attn_mask["attention_mask"].squeeze(1).any(dim=-1) - ) # [B, 1, Sq, Skv] -> [B, Sq] - - batch = [hidden_states, position_ids] - # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) - hidden_states, position_ids = split_batch_zigzag(batch, sp_group) ->>>>>>> add varlen tests elif is_share_sp_tp(sp_mode): hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) @@ -324,7 +313,6 @@ def llama_for_causal_lm_forward( logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False -<<<<<<< HEAD if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: # Split labels in a zigzag fashion too sp_group = shard_config.sequence_parallel_process_group @@ -333,11 +321,6 @@ def llama_for_causal_lm_forward( else: # [B, max_seqlen // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) -======= - if stage_manager.is_first_stage(): - if shard_config.sequence_parallelism_mode == "ring_attn": - labels = split_batch_zigzag(labels, shard_config.sequence_parallel_process_group) ->>>>>>> add varlen tests # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( @@ -582,16 +565,12 @@ def forward( if sp_mode == "ring_attn": attn_output = RingAttention.attention( -<<<<<<< HEAD query_states, key_states, value_states, sp_group, **attention_mask, inner_ring_size=shard_config.inner_ring_size, -======= - query_states, key_states, value_states, sp_group, shard_config.sp_stream, **attention_mask ->>>>>>> add varlen tests ) elif shard_config.enable_flash_attention: @@ -719,7 +698,6 @@ def forward( # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." -<<<<<<< HEAD if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, inputs_embeds, position_ids @@ -727,15 +705,6 @@ def forward( else: inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors -======= - if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: - inputs_embeds, position_ids, attn_mask = RingAttention.prepare_varlen_batch( - inputs_embeds, attn_mask["attention_mask"], sp_group, batch_size, position_ids - ) - else: - inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) - attn_mask = attn_mask["attention_mask_type"] # drop redundant tensors ->>>>>>> add varlen tests elif is_share_sp_tp(sp_mode): inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) @@ -858,7 +827,6 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if shard_config.sequence_parallelism_mode == "ring_attn": -<<<<<<< HEAD labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: @@ -870,9 +838,6 @@ def forward( # [B, max_seq_len // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) -======= - labels = split_batch_zigzag(labels, shard_config.sequence_parallel_process_group) ->>>>>>> add varlen tests # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 1569a50f0b22..36379f199f90 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -12,8 +12,12 @@ ======= from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention +<<<<<<< HEAD from colossalai.shardformer.layer.utils import split_batch_zigzag >>>>>>> add varlen tests +======= +from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag +>>>>>>> fix typo from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -30,12 +34,15 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): @parameterize("bs", [1]) @parameterize("nheads", [5]) @parameterize("d", [128]) -@parameterize("dtype", [torch.bfloat16]) +@parameterize("dtype", [torch.float16, torch.bfloat16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) +<<<<<<< HEAD dist.get_rank() dist.get_world_size() >>>>>>> add varlen tests +======= +>>>>>>> fix typo device = get_current_device() sp_group = dist.group.WORLD sp_size = dist.get_world_size() @@ -53,6 +60,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU +<<<<<<< HEAD ring_out, ring_lse = RingAttention.attention( q, k, @@ -63,6 +71,9 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): inner_ring_size=max(2, sp_size // 2), # inner_ring_size=4 ) +======= + ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) +>>>>>>> fix typo ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True @@ -110,6 +121,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 padding_mask[:, seqlen // 2 :] = 0 +<<<<<<< HEAD input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) # Forward @@ -185,46 +197,92 @@ def launch_single_ring(rank, world_size, port): check_ring_attn() ======= @parameterize("seq_len", [4096]) +======= +@parameterize("seqlen", [16]) +>>>>>>> fix typo @parameterize("bs", [2]) @parameterize("nheads", [5]) @parameterize("d", [128]) -@parameterize("dtype", [torch.bfloat16]) -def check_packed_seq(seq_len, bs, nheads, d, dtype): +@parameterize("dtype", [torch.float16, torch.bfloat16]) +def check_packed_seq(seqlen, bs, nheads, d, dtype): device = get_current_device() sp_group = dist.group.WORLD + sp_size = dist.get_world_size() sp_stream = torch.cuda.Stream() - atol = rtol = 5e-3 + atol = rtol = 7e-3 # Prepare varlen attention mask - padding_mask = torch.ones((bs, seq_len), dtype=torch.int, device=device) - padding_mask[bs // 2 :, seq_len // 2 :] = 0 - padding_mask[: bs // 2, (seq_len // 4) * 3 :] = 0 - attn_mask = ColoAttention.prepare_attn_kwargs( - (bs, 1, seq_len, seq_len), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True + padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) + # padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 + padding_mask[:, seqlen // 2 :] = 0 + mask_info = ColoAttention.prepare_attn_kwargs( + (bs, 1, seqlen, seqlen), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True + ) + # input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + input_embeds = ( + torch.arange(seqlen, device=device, dtype=dtype, requires_grad=True) + .repeat(bs, nheads, d, 1) + .permute(0, 3, 1, 2) + .contiguous() ) - input_embeds = torch.randn(bs, seq_len, nheads, d, device=device, dtype=dtype, requires_grad=True) + q, k, v = [input_embeds.clone().transpose(1, 2) for _ in range(3)] # Forward - q, k, v = [input_embeds.clone().transpose(1, 2) for _ in range(3)] - colo_out = ColoAttention.attention(q, k, v, **attn_mask) + # out = ColoAttention.attention(q, k, v, **mask_info) + flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] + qkv = torch.stack([flat_input] * 3, dim=1) + qkv.retain_grad() + out, lse, _ = flash_attn_varlen_qkvpacked_func( + qkv, mask_info["cu_seqlens_q"], mask_info["max_seqlen_q"], return_attn_probs=True, causal=True + ) + + input_embeds, _, mask_info = RingAttention.prepare_varlen_batch(input_embeds, padding_mask, sp_group) + # Test the splitting function + local_input = split_varlen_zigzag( + flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all() + del local_input, flat_input - input_embeds, _, attn_mask = RingAttention.prepare_varlen_batch(input_embeds, padding_mask, sp_group, bs) q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)] - ring_out = RingAttention.attention(q_ring, k_ring, v_ring, sp_group, sp_stream, **attn_mask) + q_ring.retain_grad() + k_ring.retain_grad() + v_ring.retain_grad() + ring_out, ring_lse = RingAttention.attention( + q_ring, k_ring, v_ring, sp_group, sp_stream, **mask_info, pad_output=False, return_softmax=True + ) # Check output - colo_out = split_batch_zigzag(colo_out, sp_group) - assert_close(colo_out, ring_out, atol=atol, rtol=rtol) + # ring_out, out = [x.transpose(1, 2) for x in (ring_out, out)] # to (B, Sq, nHeads, D) + # out = split_varlen_zigzag(out, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size, is_2d=True) + lse = lse.transpose(0, 1) + out, lse = split_varlen_zigzag( + [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + # assert_close(lse, ring_lse, atol=atol, rtol=rtol) + assert_close(out, ring_out, atol=atol, rtol=rtol) + # Check grads - colo_out.backward() - ring_out.backward() - assert_close(q.grad, q_ring.grad, atol=atol, rtol=rtol) - assert_close(k.grad, k_ring.grad, atol=atol, rtol=rtol) - assert_close(v.grad, v_ring.grad, atol=atol, rtol=rtol) + out.sum().backward() + ring_out.sum().backward() + dq, dk, dv = [ + split_varlen_zigzag( + qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + for i in range(3) + ] + dq_ring, dk_ring, dv_ring = [ + x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] + for x in (q_ring.grad, k_ring.grad, v_ring.grad) + ] + assert_close(dq, dq_ring, atol=atol, rtol=rtol) + assert_close(dk, dk_ring, atol=atol, rtol=rtol) + assert_close(dv, dv_ring, atol=atol, rtol=rtol) def launch(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) +<<<<<<< HEAD # check_ring_attn() check_packed_seq() >>>>>>> add varlen tests @@ -240,6 +298,10 @@ def launch_double_ring(rank, world_size, port): @parameterize("world_size", [2]) def test_ring_attn(world_size): spawn(launch_single_ring, nprocs=world_size) +======= + # check_packed_seq() + check_ring_attn() +>>>>>>> fix typo @rerun_if_address_is_in_use() From ed4ad6d8f025d810793dc7944c8d26e1c1407993 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 1 Aug 2024 03:38:49 +0000 Subject: [PATCH 40/71] all tests passed --- .../booster/plugin/hybrid_parallel_plugin.py | 4 + colossalai/shardformer/layer/attn.py | 27 +++++-- colossalai/shardformer/layer/loss.py | 54 +------------- colossalai/shardformer/layer/utils.py | 5 +- colossalai/shardformer/modeling/llama.py | 5 +- .../test_layer/test_ring_attn.py | 74 +++++++++++++------ 6 files changed, 83 insertions(+), 86 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 421a2bc9712b..539b1586756c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1244,11 +1244,15 @@ def configure( zero_stage = 0 if not isinstance(model, ModelWrapper): +<<<<<<< HEAD <<<<<<< HEAD # Shouldn't use pp (frequent grad accumulation) with torch ddp ======= # Can't use pp (frequent grad accumulation) with torch ddp >>>>>>> add varlen tests +======= + # Shouldn't use pp (frequent grad accumulation) with torch ddp +>>>>>>> all tests passed use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( self.dp_size == 1 and self.pp_size == 1 ) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 67901ac5ac83..d70bb6f17b85 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -169,11 +169,16 @@ def prepare_attn_kwargs( attention_mask = attention_mask.tril(diagonal=0) attention_mask = attention_mask.expand(b, s_q, s_kv) else: + assert q_padding_mask.shape == ( + b, + s_q, + ), f"q_padding_mask shape {q_padding_mask.shape} should be {b, s_q}." max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) if kv_padding_mask is None: # self attention kv_padding_mask = q_padding_mask max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices + attention_mask = q_padding_mask[:, :, None].expand(b, s_q, s_kv).to(dtype=dtype, device=device) else: max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) @@ -449,19 +454,22 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # min_scale = torch.min(lse, block_lse) # max_scale = torch.max(lse, block_lse) - # lse.data = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + # NOTE: directly assigning to .data here is buggy + # probably due to casting dtypes/strides new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - new_block_lse = torch.exp(block_lse - lse) - out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) - lse.copy_(new_lse) + new_block_lse = torch.exp(block_lse - new_lse) + out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out) + lse = new_lse + # Equivalent to the above # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - # out.data = (out - F.sigmoid(block_lse - lse) * (out - block_out)) - # lse.data = (lse - F.logsigmoid(lse - block_lse)) - + # out = (out - F.sigmoid(block_lse - lse) * (out - block_out)) + # lse = (lse - F.logsigmoid(lse - block_lse)) assert not (lse.isnan().any() or lse.isinf().any()), f"lse is nan: {lse}" + return out, lse class RingAttention(torch.autograd.Function): @@ -731,6 +739,11 @@ def forward( # half of each seq half_idx_front = get_half_index(cu_seqlens, front=True) half_idx_back = get_half_index(cu_seqlens, front=False) + RingAttention.HALF_INDICES = (half_idx_front, half_idx_back) + RingAttention.CU_SEQLENS = cu_seqlens + + if is_packed: + t, h, d = q.shape else: b, sq, h, d = q.shape t = b * sq diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 01d914e297fe..952693ec2665 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -1,6 +1,5 @@ import torch import torch.distributed as dist -import torch.nn.functional as F from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss @@ -151,11 +150,7 @@ def cross_entropy_1d( def dist_cross_entropy( -<<<<<<< HEAD labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] -======= - labels: torch.Tensor, # [B, S] ->>>>>>> precision tests passed logits: torch.Tensor, # [B, S, Vocab_size] shard_config: ShardConfig, out_features: int, @@ -173,7 +168,6 @@ def dist_cross_entropy( sp_mode = shard_config.sequence_parallelism_mode parallel_output = shard_config.parallel_output is_tp = shard_config.enable_tensor_parallelism -<<<<<<< HEAD is_packed = labels.dim() == 2 if is_packed: bs, seq_len = labels.shape @@ -183,7 +177,6 @@ def dist_cross_entropy( logits = logits.reshape(-1, *logits.shape[2:]) seq_dim = 0 -<<<<<<< HEAD # Shift labels to predict the next token, and remove the tail logit predicting is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward @@ -200,26 +193,7 @@ def dist_cross_entropy( labels = labels[..., 1:] if split_labels_here: labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] -======= ->>>>>>> precision tests passed -======= - bs, seq_len = labels.shape - - # Shift labels to predict the next token, and remove the tail logit predicting - is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) - split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward - if is_sp: - # shift only once - if split_labels_here or (sp_rank == sp_size - 1): - labels = labels[..., 1:] - # Split labels when logits are split - if split_labels_here: - labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] - -<<<<<<< HEAD - # The rank holding the last seq chunk ->>>>>>> precision tests passed if sp_rank == sp_size - 1: logits = logits[..., :-1, :] # Pad logits and labels to the same shape across all ranks for TP all_reduce @@ -234,37 +208,11 @@ def dist_cross_entropy( labels = torch.cat([labels, padding], dim=seq_dim) else: labels = labels[..., 1:] -<<<<<<< HEAD - logits = logits[..., :-1, :] - labels = labels.contiguous() - logits = logits.contiguous() - num_nonzero = (labels != _IGNORE_IDX).sum() - try: - assert ( - labels.shape == logits.shape[:-1] - ), f"label shape {labels.shape} does not match logit shape {logits.shape}" - except Exception as e: - raise e -======= - logits = logits[..., :-1, :].contiguous() -======= - # Pad to the same shape across all ranks in TP all_reduce - if sp_rank == sp_size - 1: - logits = logits[..., :-1, :] - if is_tp and parallel_output: - pad_shape = [0] * logits.dim() * 2 - pad_shape[-3] = 1 # Right side, dim = -2 - logits = F.pad(logits, pad_shape, value=_IGNORE_IDX) - labels = F.pad(labels, (0, 1, 0, 0), value=_IGNORE_IDX) - else: - labels = labels[..., 1:] logits = logits[..., :-1, :] ->>>>>>> precision tests passed labels = labels.contiguous() logits = logits.contiguous() num_nonzero = (labels != _IGNORE_IDX).sum() assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" ->>>>>>> precision tests passed # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") @@ -295,4 +243,4 @@ def dist_cross_entropy( loss = reduce_forward(loss, sp_group, grad_scale=sp_size) loss, num_nonzero = loss[0], loss[1].detach() loss = (loss / num_nonzero).squeeze() - return loss + return loss \ No newline at end of file diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index b343998082e6..31d0a3b822da 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -377,7 +377,10 @@ def split_varlen_zigzag( assert max_seqlen % (sp_size * 2) == 0 # Recreate a padded tensor with the new max seqlen shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) - local_seq = torch.zeros(shape, dtype=dtype, device=device) + if is_label: + local_seq = torch.full(shape, -100, dtype=dtype, device=device) + else: + local_seq = torch.zeros(shape, dtype=dtype, device=device) else: total_seqlen = cu_seqlens[-1] assert ( diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 219933c705e9..b1d1783967bd 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -553,7 +553,10 @@ def forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + try: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + except: + pass if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 36379f199f90..905fa8bb8abe 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -6,11 +6,15 @@ import colossalai <<<<<<< HEAD +<<<<<<< HEAD from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag ======= from colossalai.shardformer.layer import AttnMaskType, ColoAttention +======= +from colossalai.shardformer.layer import AttnMaskType +>>>>>>> all tests passed from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention <<<<<<< HEAD from colossalai.shardformer.layer.utils import split_batch_zigzag @@ -24,6 +28,7 @@ @parameterize("seq_len", [4096]) <<<<<<< HEAD +<<<<<<< HEAD @parameterize("bs", [2]) @parameterize("nheads", [5]) @parameterize("d", [128]) @@ -32,9 +37,12 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) ======= @parameterize("bs", [1]) +======= +@parameterize("bs", [2]) +>>>>>>> all tests passed @parameterize("nheads", [5]) @parameterize("d", [128]) -@parameterize("dtype", [torch.float16, torch.bfloat16]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) <<<<<<< HEAD @@ -170,6 +178,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): assert_close(lse, ring_lse, atol=atol, rtol=rtol) assert_close(out, ring_out, atol=atol, rtol=rtol) +<<<<<<< HEAD # Check grads labels = torch.ones(out.shape[0], dtype=dtype, device=device) F.mse_loss(out.sum((-2, -1)), labels).backward() @@ -200,43 +209,42 @@ def launch_single_ring(rank, world_size, port): ======= @parameterize("seqlen", [16]) >>>>>>> fix typo +======= +@parameterize("seqlen", [4096]) +>>>>>>> all tests passed @parameterize("bs", [2]) @parameterize("nheads", [5]) @parameterize("d", [128]) -@parameterize("dtype", [torch.float16, torch.bfloat16]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) def check_packed_seq(seqlen, bs, nheads, d, dtype): device = get_current_device() sp_group = dist.group.WORLD sp_size = dist.get_world_size() sp_stream = torch.cuda.Stream() atol = rtol = 7e-3 - + torch.cuda.manual_seed(2) # Prepare varlen attention mask padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) - # padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 + padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 padding_mask[:, seqlen // 2 :] = 0 - mask_info = ColoAttention.prepare_attn_kwargs( - (bs, 1, seqlen, seqlen), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True - ) - # input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) - input_embeds = ( - torch.arange(seqlen, device=device, dtype=dtype, requires_grad=True) - .repeat(bs, nheads, d, 1) - .permute(0, 3, 1, 2) - .contiguous() - ) - q, k, v = [input_embeds.clone().transpose(1, 2) for _ in range(3)] + + input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) # Forward # out = ColoAttention.attention(q, k, v, **mask_info) flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] qkv = torch.stack([flat_input] * 3, dim=1) qkv.retain_grad() + + input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds) out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, mask_info["cu_seqlens_q"], mask_info["max_seqlen_q"], return_attn_probs=True, causal=True + qkv, + mask_info["cu_seqlens"] * sp_size, + mask_info["max_seqlen"] * sp_size, + return_attn_probs=True, + causal=True, + # deterministic=True ) - - input_embeds, _, mask_info = RingAttention.prepare_varlen_batch(input_embeds, padding_mask, sp_group) # Test the splitting function local_input = split_varlen_zigzag( flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size @@ -248,23 +256,31 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): q_ring.retain_grad() k_ring.retain_grad() v_ring.retain_grad() + ring_out, ring_lse = RingAttention.attention( - q_ring, k_ring, v_ring, sp_group, sp_stream, **mask_info, pad_output=False, return_softmax=True + q_ring, + k_ring, + v_ring, + sp_group, + sp_stream, + **mask_info, + pad_output=False, + return_softmax=True, + # deterministic=True ) # Check output - # ring_out, out = [x.transpose(1, 2) for x in (ring_out, out)] # to (B, Sq, nHeads, D) - # out = split_varlen_zigzag(out, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size, is_2d=True) lse = lse.transpose(0, 1) out, lse = split_varlen_zigzag( [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size ) - # assert_close(lse, ring_lse, atol=atol, rtol=rtol) + assert_close(lse, ring_lse, atol=atol, rtol=rtol) assert_close(out, ring_out, atol=atol, rtol=rtol) # Check grads - out.sum().backward() - ring_out.sum().backward() + labels = torch.ones(out.shape[0], dtype=dtype, device=device) + F.mse_loss(out.sum((-2, -1)), labels).backward() + F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward() dq, dk, dv = [ split_varlen_zigzag( qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size @@ -275,6 +291,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] for x in (q_ring.grad, k_ring.grad, v_ring.grad) ] + assert_close(dq, dq_ring, atol=atol, rtol=rtol) assert_close(dk, dk_ring, atol=atol, rtol=rtol) assert_close(dv, dv_ring, atol=atol, rtol=rtol) @@ -282,6 +299,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): def launch(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) +<<<<<<< HEAD <<<<<<< HEAD # check_ring_attn() check_packed_seq() @@ -300,14 +318,22 @@ def test_ring_attn(world_size): spawn(launch_single_ring, nprocs=world_size) ======= # check_packed_seq() +======= + check_packed_seq() +>>>>>>> all tests passed check_ring_attn() >>>>>>> fix typo @rerun_if_address_is_in_use() +<<<<<<< HEAD @parameterize("world_size", [4]) def test_double_ring(world_size): spawn(launch_double_ring, nprocs=world_size) +======= +def test_ring_attn(): + spawn(launch, nprocs=2) +>>>>>>> all tests passed if __name__ == "__main__": From 36f691d11f18dd03e6df52f0a1517efb6d16457f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 1 Aug 2024 08:25:47 +0000 Subject: [PATCH 41/71] add dkv_group; fix mask --- colossalai/pipeline/schedule/interleaved_pp.py | 2 -- colossalai/shardformer/layer/attn.py | 1 - colossalai/shardformer/shard/shard_config.py | 1 + tests/test_shardformer/test_layer/test_ring_attn.py | 2 +- 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 8f26f8cb5bb5..412f3896fb80 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -283,8 +283,6 @@ def forward_step( # Load input ids, attention mask and labels micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) - if input_obj is not None: - assert all(not x.isnan().any() for x in input_obj.values()), "NaN detected in input_obj" # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous stage containing hidden_states etc. # Only attention_mask from micro_batch is used diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index d70bb6f17b85..ff5f0fc3b84e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -178,7 +178,6 @@ def prepare_attn_kwargs( # self attention kv_padding_mask = q_padding_mask max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices - attention_mask = q_padding_mask[:, :, None].expand(b, s_q, s_kv).to(dtype=dtype, device=device) else: max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 589ed730ec79..084c818e18b9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -56,6 +56,7 @@ class ShardConfig: moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None sp_stream: Optional[torch.cuda.Stream] = None + dkv_group: Optional[ProcessGroup] = None # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 905fa8bb8abe..c8ebe8b6b8fb 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -268,7 +268,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): return_softmax=True, # deterministic=True ) - + ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) # Check output lse = lse.transpose(0, 1) out, lse = split_varlen_zigzag( From 6b5d1bf99f033e8d5991923965e383f1c3c219e1 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 1 Aug 2024 09:37:52 +0000 Subject: [PATCH 42/71] remove debug statements --- colossalai/shardformer/modeling/llama.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index b1d1783967bd..219933c705e9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -553,10 +553,7 @@ def forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) - try: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - except: - pass + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} From bdad28ab0b76f66288b93fef36de1e4d9112a6f3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 2 Aug 2024 03:45:23 +0000 Subject: [PATCH 43/71] add comments --- .../booster/plugin/hybrid_parallel_plugin.py | 1629 ++--------------- colossalai/shardformer/layer/attn.py | 15 +- .../test_layer/test_ring_attn.py | 157 +- 3 files changed, 180 insertions(+), 1621 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 539b1586756c..13b40f97d7ee 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,1477 +1,186 @@ -import ctypes -import random -import warnings -from collections import defaultdict -from contextlib import contextmanager, nullcontext -from copy import deepcopy -from functools import partial -from types import MethodType -from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union - -import numpy as np import torch import torch.distributed as dist -from torch import Tensor, inf -from torch.distributed import ProcessGroup, get_world_size -from torch.nn import Module, SyncBatchNorm -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils._pytree import tree_map -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from colossalai.accelerator import get_accelerator -from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO -from colossalai.cluster import ProcessGroupMesh -from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper -from colossalai.interface.optimizer import DistributedOptim -from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.quantization import BnbQuantizationConfig, quantize_model -from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer -from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.tensor.d_tensor.api import is_distributed_tensor -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.zero.low_level import LowLevelZeroOptimizer -from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle - -from .pp_plugin_base import PipelinePluginBase - -SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] - -PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} - - -def _convert_floating_point(x, dtype: torch.dtype = torch.float16): - if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.to(dtype) - return x - - -class HybridParallelModule(ModelWrapper, AMPModelMixin): - def __init__( - self, - module: Module, - precision: str, - shard_config: ShardConfig, - dp_group: ProcessGroup, - tp_group: ProcessGroup, - sp_group: ProcessGroup, - use_ddp: bool, - ddp_config: dict, - custom_policy: Policy, - overlap_allgather: bool = False, - ) -> None: - self.stage_manager = shard_config.pipeline_stage_manager - self.shard_config = shard_config - self.dp_group = dp_group - self.tp_group = tp_group - self.sp_group = sp_group - self.use_ddp = use_ddp - self.require_grad_sync = True - self.overlap_allgather = overlap_allgather - - shardformer = ShardFormer(shard_config) - if custom_policy is not None: - assert isinstance(custom_policy, object) - module, self.shared_params = shardformer.optimize(module, policy=custom_policy) - - # setting process groups for shared parameters - self.shared_param_process_groups = [] - for shared_param in self.shared_params: - if len(shared_param) > 0: - self.shared_param_process_groups.append( - self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) - ) - - # setting mixed_precision - self.mixed_precision = None - if precision == "fp16": - self.mixed_precision = torch.float16 - elif precision == "bf16": - self.mixed_precision = torch.bfloat16 - if self.mixed_precision is not None: - module = module.to(self.mixed_precision) - module = module.to(get_accelerator().get_current_device()) - - # setting input type cast when using mixed precision - self.convert_fn = None - if self.mixed_precision is not None: - self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) - - # setting ddp configs - if use_ddp: - # convert model to sync bn - module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) - # wrap the model with PyTorch DDP - module = DDP(module, process_group=dp_group, **ddp_config) - - super().__init__(module) - if overlap_allgather: - self.op_hook = ZeroOpHook() - for p in module.parameters(): - if p.requires_grad and type(p) is not ColoParameter: - p.__class__ = ColoParameter - p.__init__(p, requires_grad=True) - - def sync_shared_params(self): - for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): - if self.stage_manager.stage in shared_param: - param = shared_param[self.stage_manager.stage] - dist.all_reduce(param.grad, group=group) - dist.barrier() - - @contextmanager - def no_sync(self): - r""" - A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization - when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass - when exiting the context. - """ - - # Store the current value of 'require_grad_sync' to restore it later. - old_require_grad_sync = self.require_grad_sync - # Disable automatic gradient synchronization. - self.require_grad_sync = False - try: - if self.use_ddp: - # If using data parallel processing (use_ddp), disable synchronization too. - with self.module.no_sync(): - yield - else: - yield - finally: - # Restore the original value of 'require_grad_sync'. - self.require_grad_sync = old_require_grad_sync - - def sync_dp_grads(self): - r""" - Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1. - This function performs an all-reduce operation to combine gradients from different devices in the DP group. - - Args: - None - - Returns: - None - """ - - # Check if the DP group size is 1, meaning no synchronization is needed. - if self.dp_group.size() == 1: - return - - # Iterate through the model's parameters and perform gradient synchronization. - for p in self.module.parameters(): - if p.grad is not None: - # Perform all-reduce to combine gradients from different devices. - dist.all_reduce(p.grad, group=self.dp_group) - # Normalize the gradient by dividing it by the DP group size. - p.grad.div_(self.dp_group.size()) - - def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): - r""" - Synchronize gradients that are partially derived within sequence parallelism - if sequence parallelism is enabled. Gradients can be provided explicitly or extracted - from the module. - - Args: - grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not - provided, gradients will be extracted from the model. - - Returns: - None - """ - - if self.shard_config.enable_sequence_parallelism: - if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: - return - - if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized - # across the tensor parallelism group. - group = self.tp_group - else: - raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") - - if grads is not None: - # Synchronize provided gradient tensors across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads) - else: - # Synchronize gradients from the model across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module) - - def forward(self, *args, **kwargs): - if self.convert_fn is not None: - args = tree_map(self.convert_fn, args) - kwargs = tree_map(self.convert_fn, kwargs) - with self._wait_all_gather(): - return super().forward(*args, **kwargs) - - def unwrap(self): - module = super().unwrap() - if isinstance(module, DDP): - module = module.module - return module - - def _force_wait_all_gather(self): - for p in self.module.parameters(): - wait_all_gather_handle(p) - - def _wait_all_gather(self): - return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() - - -def get_param_info(optim: Optimizer): - # Get a backup of necessary information of parameters for future use, which includes: - # 1. A complete param_group, with params in the form of param_id - # 2. A mapping from param address (obtained using id(param)) to integer param_id - # 3. A mapping from integer param_id to param address. - # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. - # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. - - if optim is None: - return {} - param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} - start_index = 0 - for group in optim.param_groups: - packed_group = {k: v for k, v in group.items() if k != "params"} - packed_group["params"] = [] - - for param_id, param in enumerate(group["params"], start_index): - original_shape = param.shape if isinstance(param, torch.Tensor) else None - packed_group["params"].append(param_id) - param_info["param2id"][id(param)] = param_id - param_info["id2param"][param_id] = id(param) - param_info["param2shape"][id(param)] = original_shape - - param_info["param_groups"].append(packed_group) - start_index += len(group["params"]) - - return param_info - - -def reinitialize_optimizer(optim: Optimizer, model: Module): - model_params = set(model.parameters()) - new_param_groups = [] - for group in optim.param_groups: - params = [p for p in group["params"] if p in model_params] - new_param_groups.append({**group, "params": params}) - optim.__setstate__({"param_groups": new_param_groups}) - - -class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__( - self, - optim: Optimizer, - model: HybridParallelModule, - use_pipeline: bool, - param_info: OrderedDict, - max_norm: float = 0, - tp_process_group: Optional[ProcessGroup] = None, # if using tp - pp_process_group: Optional[ProcessGroup] = None, # if using pp - ): - self.param_info = param_info - if use_pipeline: - reinitialize_optimizer(optim, model) - self.model = model - self.stage_manager = model.stage_manager - self.shared_params = model.shared_params - self.max_norm = max_norm - self.tp_pg = tp_process_group - self.pp_pg = pp_process_group - self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 - self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 - super().__init__(optim) - - def backward(self, loss: Tensor, *args, **kwargs): - r""" - Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. - - This method performs backward pass for gradient computation. If sequence parallelism is enabled - and gradient synchronization is required, it will synchronize gradients that are partially derived - within sequence parallelism across tp parallelism groups. - - Args: - loss (Tensor): The loss tensor to compute gradients with respect to. - *args: Additional positional arguments to be passed to the superclass backward method. - **kwargs: Additional keyword arguments to be passed to the superclass backward method. - - Returns: - None - """ - - # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) - - if self.model.require_grad_sync: - # If gradient synchronization is required, sync sequence parallelism gradients. - self.model.sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def backward_by_grad(self, tensor: Tensor, grad: Tensor): - """ - Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. - - This method performs a backward pass for gradient computation using a precomputed gradient tensor. - If sequence parallelism is enabled and gradient synchronization is required, it will synchronize - gradients that are partially derived within sequence parallelism across tp parallelism groups. - - Args: - tensor (Tensor): The input tensor for which gradients are computed. - grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. - - Returns: - None - """ - - # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) - - if self.model.require_grad_sync: - # If gradient synchronization is required, sync sequence parallelism gradients. - self.model.sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def step(self, *args, **kwargs): - r""" - Perform an optimization step. - - Args: - *args: Variable-length positional arguments to be passed to the optimizer's step function. - **kwargs: Keyword arguments to be passed to the optimizer's step function. - """ - - if self.max_norm > 0: - # Compute the total gradient norm. - param_gradient_pairs = [ - (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None - ] - total_norm = self._compute_grad_norm(param_gradient_pairs) - - # Clip the gradients to prevent exploding gradients. - self._clip_grad_norm(total_norm) - - # Perform the optimization step using the underlying optimizer. - self.optim.step(*args, **kwargs) - - def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: - r""" - Compute and return the gradient norm for gradient clipping. - - Args: - param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. - norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. - - Returns: - float: The total norm of the given gradients. - """ - - if len(param_gradient_pairs) == 0: - return 0.0 - - norm_type = float(norm_type) - - # gradients used for norm calculation. - gradients = [grad for param, grad in param_gradient_pairs] - - if norm_type == inf: - total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - if self.tp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) - if self.pp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) - total_norm = total_norm_cuda.item() - else: - # gradients used for norm calculation. - gradients = [grad for param, grad in param_gradient_pairs] - # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'. - grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} - - total_norm_exponentiated = 0.0 - for grad in gradients: - grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type - - # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, - # it indicates that the parameter is not distributed across devices of the 'tp_group'. - # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. - # However, we still perform the 'all_reduce' operation for the sake of good coding practices. - # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' - if self.tp_size > 1: - param_for_grad = grad_to_param_mapping[id(grad)] - if not is_distributed_tensor(param_for_grad): - grad_norm_exponentiated /= self.tp_size - - # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, - # it means that this parameter is used in two different pipeline stages. - # To avoid redundant norm calculations, we divide the exponent of this norm by - # the number of shared stages. - if self.pp_size > 1: - for shared_param in self.shared_params: - if self.stage_manager.stage in shared_param: - stage_shared_param = shared_param[self.stage_manager.stage] - if grad is stage_shared_param.grad: - grad_norm_exponentiated /= len(shared_param) - - total_norm_exponentiated += grad_norm_exponentiated - - total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - if self.tp_size > 1: - # compute norm in tp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) - if self.pp_size > 1: - # compute norm in pp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) - - # compute the total_norm - total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) - - return total_norm - - def _clip_grad_norm(self, total_norm: float) -> None: - r""" - Clips the gradients of the model's parameters to prevent exploding gradients. - - Args: - total_norm (float): The computed total gradient norm. - - Returns: - None - """ - clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6)) - clip_coef_clamped = torch.clamp(clip_coef, max=1.0) - - for group in self.optim.param_groups: - for p in group["params"]: - if p.grad is None: - continue - p.grad.data.mul_(clip_coef_clamped) - - def update_master_params(self, model: Module): - pass - - def get_working_to_master_map(self): - return None - - def get_master_to_working_map(self): - return None - - -class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): - def __init__( - self, - optim: Optimizer, - model: HybridParallelModule, - use_pipeline: bool, - param_info: OrderedDict, - precision: str = "fp16", - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - tp_process_group: Optional[ProcessGroup] = None, # if using tp - pp_process_group: Optional[ProcessGroup] = None, # if using pp - ): - self.model = model - self.param_info = param_info - self.stage_manager = model.stage_manager - self.shared_params = model.shared_params - self.tp_pg = tp_process_group - self.pp_pg = pp_process_group - self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 - self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 - if use_pipeline: - reinitialize_optimizer(optim, model) - super().__init__( - optim, - precision=precision, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - max_norm=max_norm, - ) - - def backward(self, loss: Tensor, *args, **kwargs): - r""" - Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. - - This method performs backward pass for gradient computation. If sequence parallelism is enabled - and gradient synchronization is required, it will synchronize gradients that are partially derived - within sequence parallelism across tp parallelism groups. - - Args: - loss (Tensor): The loss tensor to compute gradients with respect to. - *args: Additional positional arguments to be passed to the superclass backward method. - **kwargs: Additional keyword arguments to be passed to the superclass backward method. - - Returns: - None - """ - # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) - - if self.model.require_grad_sync: - # If gradient synchronization is required, sync sequence parallelism gradients. - self.model.sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def backward_by_grad(self, tensor: Tensor, grad: Tensor): - """ - Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. - - This method performs a backward pass for gradient computation using a precomputed gradient tensor. - If sequence parallelism is enabled and gradient synchronization is required, it will synchronize - gradients that are partially derived within sequence parallelism across tp parallelism groups. - - Args: - tensor (Tensor): The input tensor for which gradients are computed. - grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. - - Returns: - None - """ - # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) - - if self.model.require_grad_sync: - # If gradient synchronization is required, sync sequence parallelism gradients. - self.model.sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: - r""" - Compute and return the gradient norm for gradient clipping. - - Args: - param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. - norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. - - Returns: - float: The total norm of the given gradients. - """ - if len(param_gradient_pairs) == 0: - return 0.0 - - norm_type = float(norm_type) - - if norm_type == inf: - # The parent class calculates the norm of 'dp' gradients, - # so we need to calculate the norm of 'tp' and 'pp' gradients. - total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) - - total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - - if self.tp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) - if self.pp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) - - total_norm = total_norm_cuda.item() - - else: - # gradients used for norm calculation. - gradients = [grad for param, grad in param_gradient_pairs] - # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism. - grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} - - total_norm_exponentiated = 0.0 - for grad in gradients: - grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type - - # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, - # it indicates that the parameter is not distributed across devices of the 'tp_group'. - # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. - # However, we still perform the 'all_reduce' operation for the sake of good coding practices. - # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' - if self.tp_size > 1: - param_for_grad = grad_to_param_mapping[id(grad)] - if not is_distributed_tensor(param_for_grad): - grad_norm_exponentiated /= self.tp_size - - # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, - # it means that this parameter is used in two different pipeline stages. - # To avoid redundant norm calculations, we divide the exponent of this norm by - # the number of shared stages. - if self.pp_size > 1: - for shared_param in self.shared_params: - if self.stage_manager.stage in shared_param: - stage_working_shared_param = shared_param[self.stage_manager.stage] - stage_master_shared_param = self.working_to_master_map[stage_working_shared_param] - if grad is stage_master_shared_param.grad: - grad_norm_exponentiated /= len(shared_param) - - total_norm_exponentiated += grad_norm_exponentiated - - total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - if self.tp_size > 1: - # compute norm in tp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) - if self.pp_size > 1: - # compute norm in pp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) - - # compute the total_norm - total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) - - return total_norm - - -class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): - def __init__( - self, - optimizer: Optimizer, - model: HybridParallelModule, - use_pipeline: bool, - param_info: OrderedDict, - pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2.0, - backoff_factor: float = 0.5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - pp_process_group: Optional[ProcessGroup] = None, # if using pp - forced_dtype: Optional[torch.dtype] = None, - overlap_allgather: bool = False, - ): - self.model = model - self.param_info = param_info - self.stage_manager = model.stage_manager - self.shared_params = model.shared_params - self.tp_pg = tp_process_group - self.pp_pg = pp_process_group - if use_pipeline: - reinitialize_optimizer(optimizer, model) - super().__init__( - optimizer=optimizer, - initial_scale=initial_scale, - min_scale=min_scale, - pg_to_param_list=pg_to_param_list, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - clip_grad_norm=clip_grad_norm, - verbose=verbose, - reduce_bucket_size=reduce_bucket_size, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - partition_grad=partition_grad, - cpu_offload=cpu_offload, - dp_process_group=dp_process_group, - forced_dtype=forced_dtype, - overlap_allgather=overlap_allgather, +import torch.nn.functional as F +from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import AttnMaskType +from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention +from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +@parameterize("seq_len", [4096]) +@parameterize("bs", [2]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_ring_attn(seq_len, bs, nheads, d, dtype): + torch.cuda.manual_seed(2) + device = get_current_device() + sp_group = dist.group.WORLD + sp_size = dist.get_world_size() + # Some outliers may seem large, but our errors are still lower than + # than Megatron-LM context parallel's + # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) + # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main) + atol = rtol = 7e-3 + + # Setup inputs + qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) + local_qkv = split_batch_zigzag(qkv, sp_group) + q, k, v = local_qkv.unbind(dim=-3) + q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D) + q.requires_grad = k.requires_grad = v.requires_grad = True + + # Ring attention vs single GPU + ring_out, ring_lse = RingAttention.attention( + q, + k, + v, + sp_group, + AttnMaskType.CAUSAL, + return_softmax=True, + inner_ring_size=max(2, sp_size // 2), + # inner_ring_size=4 + ) + ring_out = ring_out.transpose(1, 2) + out, lse, _ = flash_attn_qkvpacked_func( + qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True + ) + + # Checkout out and softmax denominator + local_out = split_batch_zigzag(out, sp_group) + local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) + local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) + assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) + assert_close(ring_out, local_out, atol=atol, rtol=rtol) + + # Check grads + ring_out.sum().backward() + out.sum().backward() + ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] + dqkv = qkv.grad + local_dqkv = split_batch_zigzag(dqkv, sp_group) + + assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) + assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) + assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) + if dist.get_rank() == 0: + print( + f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed." ) - def sync_dp_grads(self): - r""" - Synchronize gradients in the data parallelism dimension. - - This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients - in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions, - namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization - and readability. - - Args: - None - - Returns: - None - """ - # Call the superclass `_sync_grad` method to synchronize gradients. - super()._sync_grad() - - def _sync_sp_grads(self): - r""" - Synchronize gradients that are partially derived within sequence parallelism. - - This method is responsible for synchronizing partially derived gradients across tp parallelism groups. - It identifies gradients that ara partially derived or not and synchronizes them. - If synchronization is required and gradients are found to be synchronized, - it performs the synchronization. - - Args: - None - - Returns: - None - """ - - def _get_all_working_grads() -> List[Tensor]: - """Retrieve all working gradients from different parameter groups.""" - all_working_grads = [] - for group_id in range(self.num_param_groups): - working_grads = self.get_working_grads_by_group_id(group_id) - all_working_grads.extend(working_grads) - return all_working_grads - - def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: - """Identify gradients to be synchronized in the sequence parallelism.""" - grads_to_sync = [] - for grad in all_working_grads: - param_id_for_grad = self.get_param_id_for_grad(grad) - param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value - if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad): - grads_to_sync.append(grad) - - if len(grads_to_sync) > 0: - return grads_to_sync - else: - return None - - # Get all working gradients and gradients to be synchronized. - all_working_grads = _get_all_working_grads() - grads_to_sync = _get_grads_to_sync(all_working_grads) - if self.require_grad_sync and grads_to_sync is not None: - # Synchronize sequence parallelism gradients if required. - SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) - else: - return - - def backward(self, loss, retain_graph=False): - """ - Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. - - This method performs the backward pass for gradient computation based on a given loss tensor. - If sequence parallelism is enabled and gradient synchronization is required, it will synchronize - gradients that are partially derived within sequence parallelism across TP parallelism groups. - - Args: - loss: The loss tensor to compute gradients with respect to. - retain_graph (bool): Whether to retain the computation graph. - - Returns: - None - """ - # Call the superclass backward method to compute gradients. - super().backward(loss, retain_graph) - - if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: - # If gradient synchronization is required, sync sequence parallelism gradients. - self._sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def backward_by_grad(self, tensor, grad): - """ - Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. - - This method performs a backward pass for gradient computation based on a precomputed gradient tensor. - If sequence parallelism is enabled and gradient synchronization is required, it will synchronize - gradients that are partially derived within sequence parallelism across TP parallelism groups. - - Args: - tensor: The input tensor for which gradients are computed. - grad: The precomputed gradient tensor to compute gradients with respect to the input tensor. - - Returns: - None - """ - # Call the superclass backward_by_grad method to compute gradients. - super().backward_by_grad(tensor, grad) - - if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: - # If gradient synchronization is required, sync sequence parallelism gradients. - self._sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float: - r""" - Compute and return the gradient norm for gradient clipping. - - Args: - gradients (List[Tensor]): A list of tensors containing gradients. - norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2. - - Returns: - float: The computed gradient norm. - """ - - # Check if the list of gradients is empty - if len(gradients) == 0: - return 0.0 - - dp_size = get_world_size(dp_pg) if dp_pg is not None else 1 - tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 - pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 - norm_type = float(norm_type) - - if norm_type == inf: - # The parent class calculates the norm of 'dp' gradients, - # so we only need to calculate the norm 'tp' of 'pp' gradients. - total_norm = super()._compute_grad_norm(gradients, norm_type) - - total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - - if tp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) - if pp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) - - total_norm = total_norm_cuda.item() - else: - total_norm_exponentiated = 0.0 - for grad in gradients: - grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type - - # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, - # it indicates that the parameter is not distributed across devices of the 'tp_group'. - # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. - # However, we still perform the 'all_reduce' operation for the sake of good coding practices. - # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' - if tp_size > 1: - param_id_for_grad = self.get_param_id_for_grad(grad) - param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value - - if not is_distributed_tensor(param_for_grad): - grad_norm_exponentiated /= tp_size - - # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, - # it means that this parameter is used in two different pipeline stages. - # To avoid redundant norm calculations, we divide the exponent of this norm by - # the number of shared stages. - if pp_size > 1: - for shared_param in self.shared_params: - if self.stage_manager.stage in shared_param: - stage_shared_param = shared_param[self.stage_manager.stage] - working_grad = self.get_working_grad_by_param_id(id(stage_shared_param)) - if grad is working_grad: - grad_norm_exponentiated /= len(shared_param) - - total_norm_exponentiated += grad_norm_exponentiated - - total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - if dp_size > 1: - # compute norm in dp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg) - if tp_size > 1: - # compute norm in tp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) - if pp_size > 1: - # compute norm in pp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) - - # Compute the 'total_norm' from 'total_norm_exponentiated' - total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) - - return total_norm - -class HybridParallelPlugin(PipelinePluginBase): - """ - Plugin for Hybrid Parallel Training. - Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. - The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). - - ```python - from colossalai.booster import Booster - from colossalai.booster.plugin import HybridParallelPlugin - - model, train_dataset, optimizer, criterion = ... - plugin = HybridParallelPlugin(tp_size=2, pp_size=2) - - train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - booster = Booster(plugin=plugin) - model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) - ``` - - Args: - tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. - pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. - sp_size (int): The size of sequence parallelism. - precision (str, optional): Specifies the precision of parameters during training. - Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. - Defaults to 'fp16'. - zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. - When set to 0, ZeRO will not be used. Defaults to 0. - enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. - Currently all the optimization methods include fused normalization, flash attention and JIT. - Defaults to False. - enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. - enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. - enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. - enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. - sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". - enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. - parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. - num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. - microbatch_size (int, optional): Microbatch size when using pipeline parallelism. - Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. - If ``num_microbatches`` is provided, this will be ignored. Defaults to None. - initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. - min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. - growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. - backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. - growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. - hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. - max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. - max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. - broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. - ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. - find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. - check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. - gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. - static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. - zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. - cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. - communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. - overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. - custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. - pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. - num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. - gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. - enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. - make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. - overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism - inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn". - It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default. - - """ - - def __init__( - self, - tp_size: int, - pp_size: int, - sp_size: int = None, - precision: str = "fp16", - zero_stage: int = 0, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - sequence_parallelism_mode: str = None, - enable_sequence_overlap: bool = False, - parallel_output: bool = True, - num_microbatches: Optional[int] = None, - microbatch_size: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - broadcast_buffers: bool = True, - ddp_bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False, - zero_bucket_size_in_m: int = 12, - cpu_offload: bool = False, - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - custom_policy: Policy = None, - pp_style: str = "1f1b", - num_model_chunks: int = 1, - num_layers_per_stage: Optional[List[int]] = None, - gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, - enable_metadata_cache: bool = True, - make_vocab_size_divisible_by: int = 64, - dp_outside: bool = True, - overlap_p2p: bool = True, - overlap_allgather: bool = False, - inner_ring_size: int = None, - ) -> None: - super().__init__() - - assert ( - dist.get_world_size() % (tp_size * pp_size) == 0 - ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" - - if enable_sequence_parallelism: - self.sequence_parallelism_mode = ( - sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" - ) - assert ( - self.sequence_parallelism_mode in SUPPORT_SP_MODE - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" - if self.sequence_parallelism_mode in ["split_gather", "ring"]: - assert ( - tp_size > 1 - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" - if sp_size != 1: - warnings.warn( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." - ) - self.sp_size = 1 - self.dp_size = dist.get_world_size() // (tp_size * pp_size) - elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: - self.sp_size = 1 if sp_size is None else sp_size - self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) - if self.sequence_parallelism_mode == "ring_attn": - enable_flash_attention = True - else: - self.dp_size = dist.get_world_size() // (tp_size * pp_size) - assert ( - sp_size == 1 or sp_size is None - ), f"You should not set sp_size when sequence parallelism is not enabled." - self.sp_size = 1 - - self.tp_size = tp_size - self.pp_size = pp_size - self.precision = precision - self.zero_stage = zero_stage - self.cpu_offload = cpu_offload - self.enable_all_optimization = enable_all_optimization - self.enable_fused_normalization = enable_fused_normalization - self.enable_flash_attention = enable_flash_attention - self.enable_jit_fused = enable_jit_fused - self.enable_sequence_parallelism = enable_sequence_parallelism - if dp_outside: - self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - if sequence_parallelism_mode == "ring_attn": - # Swap tp and sp since 2D Ring has better inter-node latency - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) - self.sp_axis = 2 - self.tp_axis = 3 - else: - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) - else: - self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - if sequence_parallelism_mode == "ring_attn": - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size) - self.sp_axis = 2 - self.tp_axis = 3 - else: - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) - - self.stage_manager = None - self.schedule = None - self.custom_policy = custom_policy - assert zero_stage in (0, 1, 2) - if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" - assert ( - num_microbatches is not None or microbatch_size is not None - ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" - assert ( - self.zero_stage <= 1 - ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager( - self.pg_mesh, - pipeline_axis=self.pp_axis, - enable_interleave=pp_style == "interleaved", - num_model_chunks=num_model_chunks, - num_layers_per_stage=num_layers_per_stage, - ) - - if pp_style == "interleaved": - assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( - stage_manager=self.stage_manager, - num_model_chunks=num_model_chunks, - num_microbatch=num_microbatches, - microbatch_size=microbatch_size, - enable_metadata_cache=enable_metadata_cache, - overlap_p2p=overlap_p2p, - ) - elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( - stage_manager=self.stage_manager, - num_microbatches=num_microbatches, - microbatch_size=microbatch_size, - enable_metadata_cache=enable_metadata_cache, - ) - else: - raise NotImplementedError() - if sequence_parallelism_mode == "ring_attn": - assert parallel_output, "Ring Attention doesn't support gathering output yet." - - self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) - self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) - self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) - if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: - self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) - else: - self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) - # According to https://github.com/InternLM/InternEvo/blob/a53a4ff4fc45761f80d7fe8e9188bc2e02d487fc/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L405 - # and https://zhuanlan.zhihu.com/p/706805407 - # using a different proc group may put p2p comm on a new - # NCCL stream :) - dkv_group = None - if sequence_parallelism_mode == "ring_attn": - sp_ranks = dist.get_process_group_ranks(self.sp_group) - dkv_group = dist.new_group(ranks=sp_ranks) - - self.shard_config = ShardConfig( - tensor_parallel_process_group=self.tp_group, - sequence_parallel_process_group=self.sp_group, - pipeline_stage_manager=self.stage_manager, - enable_tensor_parallelism=self.tp_size > 1, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism, - sequence_parallelism_mode=sequence_parallelism_mode, - enable_sequence_overlap=enable_sequence_overlap, - parallel_output=parallel_output, - make_vocab_size_divisible_by=make_vocab_size_divisible_by, - gradient_checkpoint_config=gradient_checkpoint_config, - inner_ring_size=inner_ring_size, - ) - self.amp_config = dict( - initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - ) - - self.ddp_config = dict( - broadcast_buffers=broadcast_buffers, - bucket_cap_mb=ddp_bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph, - ) - - self.zero_config = dict( - reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - cpu_offload=cpu_offload, - partition_grad=(self.zero_stage == 2), - forced_dtype=PRECISION_TORCH_TYPE[precision], - overlap_allgather=overlap_allgather, +@parameterize("seqlen", [4096]) +@parameterize("bs", [2]) +@parameterize("nheads", [5]) +@parameterize("d", [128]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_packed_seq(seqlen, bs, nheads, d, dtype): + device = get_current_device() + sp_group = dist.group.WORLD + sp_size = dist.get_world_size() + atol = rtol = 7e-3 + torch.cuda.manual_seed(2) + # Prepare varlen attention mask + padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) + padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 + padding_mask[:, seqlen // 2 :] = 0 + + input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) + + # Forward + # out = ColoAttention.attention(q, k, v, **mask_info) + flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] + qkv = torch.stack([flat_input] * 3, dim=1) + qkv.retain_grad() + + input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds) + out, lse, _ = flash_attn_varlen_qkvpacked_func( + qkv, + mask_info["cu_seqlens"] * sp_size, + mask_info["max_seqlen"] * sp_size, + return_attn_probs=True, + causal=True, + # deterministic=True + ) + # Test the splitting function + local_input = split_varlen_zigzag( + flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all() + del local_input, flat_input + + q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)] + q_ring.retain_grad() + k_ring.retain_grad() + v_ring.retain_grad() + + ring_out, ring_lse = RingAttention.attention( + q_ring, + k_ring, + v_ring, + sp_group, + **mask_info, + pad_output=False, + return_softmax=True, + # deterministic=True + ) + ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) + # Check output + lse = lse.transpose(0, 1) + out, lse = split_varlen_zigzag( + [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + ) + assert_close(lse, ring_lse, atol=atol, rtol=rtol) + assert_close(out, ring_out, atol=atol, rtol=rtol) + + # Check grads + labels = torch.ones(out.shape[0], dtype=dtype, device=device) + F.mse_loss(out.sum((-2, -1)), labels).backward() + F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward() + dq, dk, dv = [ + split_varlen_zigzag( + qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size ) + for i in range(3) + ] + dq_ring, dk_ring, dv_ring = [ + x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] + for x in (q_ring.grad, k_ring.grad, v_ring.grad) + ] - self.max_norm = max_norm - - def __del__(self): - """Destroy the process groups in ProcessGroupMesh""" - self.pg_mesh.destroy_mesh_process_groups() - - @property - def enable_pipeline_parallelism(self) -> bool: - return self.pp_size > 1 - - def supported_devices(self) -> List[str]: - return ["cuda", "npu"] - - def supported_precisions(self) -> List[str]: - return ["fp16", "bf16", "fp32"] - - def control_device(self) -> bool: - return True - - def control_precision(self) -> bool: - return True + assert_close(dq, dq_ring, atol=atol, rtol=rtol) + assert_close(dk, dk_ring, atol=atol, rtol=rtol) + assert_close(dv, dv_ring, atol=atol, rtol=rtol) - def support_no_sync(self) -> bool: - return True - def support_lora(self) -> bool: - return True +def launch_single_ring(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + check_packed_seq() + check_ring_attn() - def control_checkpoint_io(self) -> bool: - return True - def configure( - self, - model: Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - param_info = get_param_info(optimizer) - - # TODO: Support Galore + ZeRO - zero_stage = self.zero_stage - zero_config = deepcopy(self.zero_config) - - # Replace with distributed implementation if exists - optimizer = cast_to_distributed(optimizer) - - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: - warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") - zero_config["partition_grad"] = False - zero_stage = 0 - - if not isinstance(model, ModelWrapper): -<<<<<<< HEAD -<<<<<<< HEAD - # Shouldn't use pp (frequent grad accumulation) with torch ddp -======= - # Can't use pp (frequent grad accumulation) with torch ddp ->>>>>>> add varlen tests -======= - # Shouldn't use pp (frequent grad accumulation) with torch ddp ->>>>>>> all tests passed - use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( - self.dp_size == 1 and self.pp_size == 1 - ) - - # Apply Hybrid ZeRO across DP * SP ranks - if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): - dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) - self.dp_size = get_world_size(dp_group) - else: - dp_group = self.dp_group - model = HybridParallelModule( - model, - precision=self.precision, - shard_config=self.shard_config, - dp_group=dp_group, - tp_group=self.tp_group, - sp_group=self.sp_group, - use_ddp=use_ddp, - ddp_config=self.ddp_config, - custom_policy=self.custom_policy, - overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), - ) - if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): - if zero_stage == 0: - is_zero = False - if self.precision in ["fp16", "bf16"]: - optimizer = HybridParallelAMPOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - precision=self.precision, - max_norm=self.max_norm, - pp_process_group=self.pp_group, - tp_process_group=self.tp_group, - **self.amp_config, - ) - else: - optimizer = HybridParallelNaiveOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - max_norm=self.max_norm, - pp_process_group=self.pp_group, - tp_process_group=self.tp_group, - ) - else: - is_zero = self.dp_size > 1 - if self.dp_size == 1: - warnings.warn( - "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." - ) - - assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." - optimizer = HybridParallelZeroOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - dp_process_group=dp_group, - tp_process_group=self.tp_group, - pp_process_group=self.pp_group, - verbose=True, - clip_grad_norm=self.max_norm, - **zero_config, - **self.amp_config, - ) - # inject update_master_params - model.update_master_params = MethodType(optimizer.update_master_params, model) - - # Setup optimizers that require global states - optim = optimizer.optim - if isinstance(optim, DistributedOptim): - shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} - padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) - optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) - - return model, optimizer, criterion, dataloader, lr_scheduler - - def execute_pipeline( - self, - data_iter: Iterator, - model: HybridParallelModule, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[ - Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer] - ] = None, - return_loss: bool = True, - return_outputs: bool = False, - ) -> dict: - assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" - - if return_outputs: - warnings.warn("return_outputs may lead to significant extra memory consumption.") - - # Create a context for gradient synchronization based on the optimizer type. - # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). - # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once), - # so we disable it, performing manual reduction instead. - ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() - - with ctx, model._wait_all_gather(): - outputs = self.schedule.forward_backward_step( - model, data_iter, criterion, optimizer, return_loss, return_outputs - ) - - # run with gradients accumulation - if ( - model.require_grad_sync == False - or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) - or not torch.is_grad_enabled() - ): - return outputs - - # Synchronize the grads of shared parameters of the model. - model.sync_shared_params() - # Synchronize sequence parallelism gradients of the model. - model.sync_sp_grads() - - # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so. - # Otherwise, synchronize data parallelism gradients of the model. - # This is because these are two different forms of data parallelism. - if isinstance(optimizer, HybridParallelZeroOptimizer): - optimizer.sync_dp_grads() - else: - model.sync_dp_grads() - - return outputs - - def prepare_dataloader( - self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - distributed_sampler_cls=None, - **kwargs, - ): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns:` - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - distributed_sampler_cls = distributed_sampler_cls or DistributedSampler - sampler = distributed_sampler_cls( - dataset, - num_replicas=self.dp_group.size(), - rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()), - shuffle=shuffle, - ) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs, - ) +def launch_double_ring(rank, world_size, port): + colossalai.launch(rank, world_size, "localhost", port) + check_ring_attn() - def get_checkpoint_io(self) -> CheckpointIO: - return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) - def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: - assert ( - self.zero_stage != 2 - ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." - return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() +@rerun_if_address_is_in_use() +@parameterize("world_size", [2]) +def test_ring_attn(world_size): + spawn(launch_single_ring, nprocs=world_size) - def enable_lora( - self, - model: Module, - pretrained_dir: Optional[str] = None, - lora_config: Optional[Dict] = None, - bnb_quantization_config: Optional[BnbQuantizationConfig] = None, - ) -> Module: - from peft import PeftModel, get_peft_model - assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." - assert self.pp_size == 1 and self.tp_size == 1 - self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") +@rerun_if_address_is_in_use() +@parameterize("world_size", [4]) +def test_double_ring(world_size): + spawn(launch_double_ring, nprocs=world_size) - if bnb_quantization_config is not None: - model = quantize_model(model, bnb_quantization_config) - if pretrained_dir is None: - peft_model = get_peft_model(model, lora_config) - else: - peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) - return peft_model +if __name__ == "__main__": + test_ring_attn() + test_double_ring() \ No newline at end of file diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index ff5f0fc3b84e..4426e9ea3f1b 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -445,10 +445,10 @@ def _rescale_out_lse(out, block_out, lse, block_lse): Compute the new attention denominator: exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1) Args: - out: (B, Sq, H, D) - block_out: (B, Sq, H, D) - lse: (B, H, Sq, 1) - block_lse: (B, H, Sq, 1) + out: (T, H, D) + block_out: (T, H, D) + lse: (H, T, 1) + block_lse: (H, T, 1) """ # min_scale = torch.min(lse, block_lse) @@ -984,7 +984,6 @@ def backward(ctx, dout, _): cu_seqlens_half = cu_seqlens_q // 2 max_seqlen_half = max_seqlen_q // 2 misc_kwargs = ctx.misc_kwargs - dout = dout.contiguous() del misc_kwargs["block_table"] assert ( @@ -1025,6 +1024,12 @@ def backward(ctx, dout, _): softmax_lse1 = softmax_lse[:, half_idx_back] dout = dout.contiguous() + # Non-contiguous indexing creates a new contiguous tensor, + # so only do it once + if sp_rank != sp_size - 1: + softmax_lse1 = softmax_lse[:, half_idx_back] + dout = dout.contiguous() + # Double comm buffers for sending and receiving kv kv_buffers = [torch.stack((k, v))] # (2, T, H, D) kv_buffers.append(torch.empty_like(kv_buffers[0])) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index c8ebe8b6b8fb..13b40f97d7ee 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -5,52 +5,20 @@ from torch.testing import assert_close import colossalai -<<<<<<< HEAD -<<<<<<< HEAD from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag -======= -from colossalai.shardformer.layer import AttnMaskType, ColoAttention -======= -from colossalai.shardformer.layer import AttnMaskType ->>>>>>> all tests passed -from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention -<<<<<<< HEAD -from colossalai.shardformer.layer.utils import split_batch_zigzag ->>>>>>> add varlen tests -======= -from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag ->>>>>>> fix typo from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @parameterize("seq_len", [4096]) -<<<<<<< HEAD -<<<<<<< HEAD @parameterize("bs", [2]) @parameterize("nheads", [5]) @parameterize("d", [128]) @parameterize("dtype", [torch.bfloat16, torch.float16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) -======= -@parameterize("bs", [1]) -======= -@parameterize("bs", [2]) ->>>>>>> all tests passed -@parameterize("nheads", [5]) -@parameterize("d", [128]) -@parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_ring_attn(seq_len, bs, nheads, d, dtype): - torch.cuda.manual_seed(2) -<<<<<<< HEAD - dist.get_rank() - dist.get_world_size() ->>>>>>> add varlen tests -======= ->>>>>>> fix typo device = get_current_device() sp_group = dist.group.WORLD sp_size = dist.get_world_size() @@ -68,7 +36,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): q.requires_grad = k.requires_grad = v.requires_grad = True # Ring attention vs single GPU -<<<<<<< HEAD ring_out, ring_lse = RingAttention.attention( q, k, @@ -79,9 +46,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): inner_ring_size=max(2, sp_size // 2), # inner_ring_size=4 ) -======= - ring_out, ring_lse = RingAttention.attention(q, k, v, sp_group, sp_stream, AttnMaskType.CAUSAL, return_softmax=True) ->>>>>>> fix typo ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True @@ -100,10 +64,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] dqkv = qkv.grad local_dqkv = split_batch_zigzag(dqkv, sp_group) -<<<<<<< HEAD -======= ->>>>>>> add varlen tests assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) @@ -129,7 +90,6 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 padding_mask[:, seqlen // 2 :] = 0 -<<<<<<< HEAD input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) # Forward @@ -178,7 +138,6 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): assert_close(lse, ring_lse, atol=atol, rtol=rtol) assert_close(out, ring_out, atol=atol, rtol=rtol) -<<<<<<< HEAD # Check grads labels = torch.ones(out.shape[0], dtype=dtype, device=device) F.mse_loss(out.sum((-2, -1)), labels).backward() @@ -199,116 +158,14 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): assert_close(dv, dv_ring, atol=atol, rtol=rtol) -<<<<<<< HEAD def launch_single_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) check_packed_seq() check_ring_attn() -======= -@parameterize("seq_len", [4096]) -======= -@parameterize("seqlen", [16]) ->>>>>>> fix typo -======= -@parameterize("seqlen", [4096]) ->>>>>>> all tests passed -@parameterize("bs", [2]) -@parameterize("nheads", [5]) -@parameterize("d", [128]) -@parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_packed_seq(seqlen, bs, nheads, d, dtype): - device = get_current_device() - sp_group = dist.group.WORLD - sp_size = dist.get_world_size() - sp_stream = torch.cuda.Stream() - atol = rtol = 7e-3 - torch.cuda.manual_seed(2) - # Prepare varlen attention mask - padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) - padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 - padding_mask[:, seqlen // 2 :] = 0 - - input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) - - # Forward - # out = ColoAttention.attention(q, k, v, **mask_info) - flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] - qkv = torch.stack([flat_input] * 3, dim=1) - qkv.retain_grad() - - input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds) - out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, - mask_info["cu_seqlens"] * sp_size, - mask_info["max_seqlen"] * sp_size, - return_attn_probs=True, - causal=True, - # deterministic=True - ) - # Test the splitting function - local_input = split_varlen_zigzag( - flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size - ) - assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all() - del local_input, flat_input - - q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)] - q_ring.retain_grad() - k_ring.retain_grad() - v_ring.retain_grad() - - ring_out, ring_lse = RingAttention.attention( - q_ring, - k_ring, - v_ring, - sp_group, - sp_stream, - **mask_info, - pad_output=False, - return_softmax=True, - # deterministic=True - ) - ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) - # Check output - lse = lse.transpose(0, 1) - out, lse = split_varlen_zigzag( - [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size - ) - assert_close(lse, ring_lse, atol=atol, rtol=rtol) - assert_close(out, ring_out, atol=atol, rtol=rtol) - - # Check grads - labels = torch.ones(out.shape[0], dtype=dtype, device=device) - F.mse_loss(out.sum((-2, -1)), labels).backward() - F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward() - dq, dk, dv = [ - split_varlen_zigzag( - qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size - ) - for i in range(3) - ] - dq_ring, dk_ring, dv_ring = [ - x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] - for x in (q_ring.grad, k_ring.grad, v_ring.grad) - ] - - assert_close(dq, dq_ring, atol=atol, rtol=rtol) - assert_close(dk, dk_ring, atol=atol, rtol=rtol) - assert_close(dv, dv_ring, atol=atol, rtol=rtol) - - -def launch(rank, world_size, port): - colossalai.launch(rank, world_size, "localhost", port) -<<<<<<< HEAD -<<<<<<< HEAD - # check_ring_attn() - check_packed_seq() ->>>>>>> add varlen tests def launch_double_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - check_packed_seq() check_ring_attn() @@ -316,26 +173,14 @@ def launch_double_ring(rank, world_size, port): @parameterize("world_size", [2]) def test_ring_attn(world_size): spawn(launch_single_ring, nprocs=world_size) -======= - # check_packed_seq() -======= - check_packed_seq() ->>>>>>> all tests passed - check_ring_attn() ->>>>>>> fix typo @rerun_if_address_is_in_use() -<<<<<<< HEAD @parameterize("world_size", [4]) def test_double_ring(world_size): spawn(launch_double_ring, nprocs=world_size) -======= -def test_ring_attn(): - spawn(launch, nprocs=2) ->>>>>>> all tests passed if __name__ == "__main__": test_ring_attn() - test_double_ring() + test_double_ring() \ No newline at end of file From 89343fd24d05de2a26f5197bbd4e56416133b750 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 5 Aug 2024 10:38:31 +0000 Subject: [PATCH 44/71] q1 index only once --- colossalai/shardformer/layer/attn.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 4426e9ea3f1b..d218fad553f2 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -768,6 +768,11 @@ def forward( if sp_rank != sp_size - 1: q1 = q[half_idx_back] + # Non-contiguous indexing creates a new contiguous tensor, + # so only do it once + if sp_rank != sp_size - 1: + q1 = q[half_idx_back] + # Pre-allocate double buffer for overlapping and receiving next step's inputs kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) kv_buffers.append(torch.empty_like(kv_buffers[0])) @@ -1024,8 +1029,6 @@ def backward(ctx, dout, _): softmax_lse1 = softmax_lse[:, half_idx_back] dout = dout.contiguous() - # Non-contiguous indexing creates a new contiguous tensor, - # so only do it once if sp_rank != sp_size - 1: softmax_lse1 = softmax_lse[:, half_idx_back] dout = dout.contiguous() From 551aaeccb23b3a60d82c6885f91bd5c0381e804e Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 6 Aug 2024 00:57:00 +0000 Subject: [PATCH 45/71] remove events to simplify stream sync --- colossalai/shardformer/layer/attn.py | 155 +++------------------------ 1 file changed, 16 insertions(+), 139 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index d218fad553f2..e4640b01021b 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -2,6 +2,7 @@ from typing import Callable, Dict, Optional, Tuple import torch +import torch.distributed import torch.distributed as dist import torch.nn.functional as F from einops import rearrange @@ -16,10 +17,6 @@ from .utils import RingComm, get_half_index, split_varlen_zigzag -from .utils import RingComm, get_half_index, split_varlen_zigzag - -from .utils import RingComm, get_half_index, split_varlen_zigzag - __all__ = [ "AttnMaskType", "ColoAttention", @@ -40,7 +37,7 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor: """Invert the mask tensor. Args: - mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Sq] + mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv] Returns: torch.Tensor: Inverted mask tensor. @@ -169,10 +166,6 @@ def prepare_attn_kwargs( attention_mask = attention_mask.tril(diagonal=0) attention_mask = attention_mask.expand(b, s_q, s_kv) else: - assert q_padding_mask.shape == ( - b, - s_q, - ), f"q_padding_mask shape {q_padding_mask.shape} should be {b, s_q}." max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) if kv_padding_mask is None: # self attention @@ -180,7 +173,6 @@ def prepare_attn_kwargs( max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices else: max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) - attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) assert kv_padding_mask.shape == ( b, s_kv, @@ -269,6 +261,18 @@ def attention( ) if attention_mask_type == AttnMaskType.CUSTOM: assert not torch.all(attention_mask != 0, dim=-1).any() + elif attention_mask_type in ( + AttnMaskType.PADDED, + AttnMaskType.PADDED_CAUSAL, + ): + assert ( + cu_seqlens_q is not None + and cu_seqlens_kv is not None + and max_seqlen_q is not None + and max_seqlen_kv is not None + and q_indices is not None + and kv_indices is not None + ) else: # if attention_mask is None, attention_mask_type should be the default value assert attention_mask_type == AttnMaskType.CUSTOM @@ -369,108 +373,6 @@ def _rescale_out_lse(out, block_out, lse, block_lse): return out, lse -def _not_nan(x): - return not (x.isnan().any() or x.isinf().any()) - - -@triton.jit -def _rescale_out_lse_kernel( - out_ptr, - out_per_step_ptr, - lse_ptr, - lse_step_ptr, - D, # Each thread handles D elements - stride_out_0, - stride_out_1, - stride_out_2, - stride_out_per_step_0, - stride_out_per_step_1, - stride_out_per_step_2, - stride_lse_0, - stride_lse_1, - BLOCK_M: tl.constexpr, -): - batch_id = tl.program_id(0) - sq_id = tl.program_id(1) - h_id = tl.program_id(2) - d_id = tl.arange(0, D) - - out_idx = batch_id * stride_out_0 + sq_id * stride_out_1 + h_id * stride_out_2 + d_id - out_per_step_idx = batch_id * stride_out_per_step_0 + sq_id * stride_out_per_step_1 + h_id * stride_out_per_step_2 - lse_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 - lse_step_idx = batch_id * stride_lse_0 + h_id * stride_lse_1 - - # Load inputs - out = tl.load(out_ptr + out_idx) - out_per_step = tl.load(out_per_step_ptr + out_per_step_idx) - lse = tl.load(lse_ptr + lse_idx) - lse_step = tl.load(lse_step_ptr + lse_step_idx) - - # Element-wise rescale - new_lse = lse + tl.log(1 + tl.exp(lse_step - lse)) - out = tl.exp(lse - new_lse) * out + tl.exp(lse_step - new_lse) * out_per_step - - tl.store(out_ptr + out_idx, out) - tl.store(lse_ptr + lse_idx, new_lse) - - -def _rescale_out_lse_triton(out, block_out, lse, block_lse): - T, H, D = out.shape - - assert out.is_contiguous() and block_out.is_contiguous() and lse.is_contiguous() and block_lse.is_contiguous() - - # TODO: use 1d kernel? - grid = lambda META: (triton.cdiv(Sq, META["BLOCK_M"]), B, H) - _rescale_out_lse_kernel[grid]( - out, - block_out, - lse, - block_lse, - T, - H, - D, - out.stride(0), - out.stride(1), - out.stride(2), - block_out.stride(0), - block_out.stride(1), - block_out.stride(2), - lse.stride(0), - lse.stride(1), - ) - - -def _rescale_out_lse(out, block_out, lse, block_lse): - """ - Compute the new attention denominator: - exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1) - Args: - out: (T, H, D) - block_out: (T, H, D) - lse: (H, T, 1) - block_lse: (H, T, 1) - """ - - # min_scale = torch.min(lse, block_lse) - # max_scale = torch.max(lse, block_lse) - # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) - - # NOTE: directly assigning to .data here is buggy - # probably due to casting dtypes/strides - new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - - new_block_lse = torch.exp(block_lse - new_lse) - out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out) - lse = new_lse - - # Equivalent to the above - # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - # out = (out - F.sigmoid(block_lse - lse) * (out - block_out)) - # lse = (lse - F.logsigmoid(lse - block_lse)) - assert not (lse.isnan().any() or lse.isinf().any()), f"lse is nan: {lse}" - return out, lse - - class RingAttention(torch.autograd.Function): """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context` (https://arxiv.org/abs/2310.01889). @@ -577,6 +479,7 @@ def attention( """ Ring Attention forward pass supporting variable-length sequences. When using varlen mode, each sequence in the batch should have length divisible by sp_size * 2. + Args: q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D] @@ -725,22 +628,6 @@ def forward( RingAttention.HALF_INDICES = (half_idx_front, half_idx_back) RingAttention.CU_SEQLENS = cu_seqlens - if is_packed: - t, h, d = q.shape - # half of each seq - half_idx_front = get_half_index(cu_seqlens, front=True) - half_idx_back = get_half_index(cu_seqlens, front=False) - RingAttention.HALF_INDICES = (half_idx_front, half_idx_back) - RingAttention.CU_SEQLENS = cu_seqlens - - if is_packed: - t, h, d = q.shape - # half of each seq - half_idx_front = get_half_index(cu_seqlens, front=True) - half_idx_back = get_half_index(cu_seqlens, front=False) - RingAttention.HALF_INDICES = (half_idx_front, half_idx_back) - RingAttention.CU_SEQLENS = cu_seqlens - if is_packed: t, h, d = q.shape else: @@ -768,11 +655,6 @@ def forward( if sp_rank != sp_size - 1: q1 = q[half_idx_back] - # Non-contiguous indexing creates a new contiguous tensor, - # so only do it once - if sp_rank != sp_size - 1: - q1 = q[half_idx_back] - # Pre-allocate double buffer for overlapping and receiving next step's inputs kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) kv_buffers.append(torch.empty_like(kv_buffers[0])) @@ -953,7 +835,6 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse): del misc_kwargs["return_softmax"] ctx.misc_kwargs = misc_kwargs ctx.is_packed = is_packed - ctx.dkv_group = dkv_group ctx.kv_group = inner_ring_group ctx.inter_kv_group = inter_ring_group @@ -1029,10 +910,6 @@ def backward(ctx, dout, _): softmax_lse1 = softmax_lse[:, half_idx_back] dout = dout.contiguous() - if sp_rank != sp_size - 1: - softmax_lse1 = softmax_lse[:, half_idx_back] - dout = dout.contiguous() - # Double comm buffers for sending and receiving kv kv_buffers = [torch.stack((k, v))] # (2, T, H, D) kv_buffers.append(torch.empty_like(kv_buffers[0])) @@ -1276,4 +1153,4 @@ def prepare_varlen_batch( mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() mask_info["cu_seqlens"] //= sp_size mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL - return inputs_embeds, mask_info, position_ids + return inputs_embeds, mask_info, position_ids \ No newline at end of file From 43c0b652b48f39dfe4dc4eed05aa71009a4a01cd Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 9 Aug 2024 11:51:39 +0000 Subject: [PATCH 46/71] simplify forward/backward logic --- colossalai/shardformer/layer/attn.py | 9 +++++++++ colossalai/shardformer/shard/shard_config.py | 1 - 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index e4640b01021b..8a02d121c95e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -517,6 +517,14 @@ def attention( assert ( attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES ), f"Mask type {attention_mask_type} is not supported yet." + if dkv_group is None: + if RingAttention.DKV_GROUP is None or dist.get_process_group_ranks( + sp_group + ) != dist.get_process_group_ranks(RingAttention.DKV_GROUP): + ranks = dist.get_process_group_ranks(sp_group) + RingAttention.DKV_GROUP = dkv_group = dist.new_group(ranks) + else: + dkv_group = RingAttention.DKV_GROUP clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) @@ -823,6 +831,7 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse): else: out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse) + # torch.cuda.current_stream().wait_stream(sp_stream) out = out.to(q.dtype) if not is_packed: out = out.view(b, sq, h, d) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 084c818e18b9..589ed730ec79 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -56,7 +56,6 @@ class ShardConfig: moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None sp_stream: Optional[torch.cuda.Stream] = None - dkv_group: Optional[ProcessGroup] = None # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] From fb4e9050bf0e94977e6375a97445f00b72470843 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 12 Aug 2024 11:10:24 +0000 Subject: [PATCH 47/71] 2d ring forward passed --- .../booster/plugin/hybrid_parallel_plugin.py | 1613 +++++++++++++++-- colossalai/shardformer/layer/attn.py | 9 - colossalai/shardformer/shard/shard_config.py | 2 - examples/language/llama/benchmark.py | 3 - .../test_layer/test_ring_attn.py | 6 +- .../test_model/test_shard_llama.py | 15 +- 6 files changed, 1459 insertions(+), 189 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 13b40f97d7ee..6c785d4aed4d 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,186 +1,1461 @@ +import ctypes +import random +import warnings +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from copy import deepcopy +from functools import partial +from types import MethodType +from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union + +import numpy as np import torch import torch.distributed as dist -import torch.nn.functional as F -from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func -from torch.testing import assert_close - -import colossalai -from colossalai.shardformer.layer import AttnMaskType -from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention -from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device - - -@parameterize("seq_len", [4096]) -@parameterize("bs", [2]) -@parameterize("nheads", [5]) -@parameterize("d", [128]) -@parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_ring_attn(seq_len, bs, nheads, d, dtype): - torch.cuda.manual_seed(2) - device = get_current_device() - sp_group = dist.group.WORLD - sp_size = dist.get_world_size() - # Some outliers may seem large, but our errors are still lower than - # than Megatron-LM context parallel's - # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) - # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main) - atol = rtol = 7e-3 - - # Setup inputs - qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) - local_qkv = split_batch_zigzag(qkv, sp_group) - q, k, v = local_qkv.unbind(dim=-3) - q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D) - q.requires_grad = k.requires_grad = v.requires_grad = True - - # Ring attention vs single GPU - ring_out, ring_lse = RingAttention.attention( - q, - k, - v, - sp_group, - AttnMaskType.CAUSAL, - return_softmax=True, - inner_ring_size=max(2, sp_size // 2), - # inner_ring_size=4 - ) - ring_out = ring_out.transpose(1, 2) - out, lse, _ = flash_attn_qkvpacked_func( - qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True - ) - - # Checkout out and softmax denominator - local_out = split_batch_zigzag(out, sp_group) - local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) - local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) - assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) - assert_close(ring_out, local_out, atol=atol, rtol=rtol) - - # Check grads - ring_out.sum().backward() - out.sum().backward() - ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)] - dqkv = qkv.grad - local_dqkv = split_batch_zigzag(dqkv, sp_group) - - assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol) - assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol) - assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) - if dist.get_rank() == 0: - print( - f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed." +from torch import Tensor, inf +from torch.distributed import ProcessGroup, get_world_size +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.accelerator import get_accelerator +from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer +from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.interface.optimizer import DistributedOptim +from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer +from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.d_tensor.api import is_distributed_tensor +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.zero.low_level import LowLevelZeroOptimizer +from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle + +from .pp_plugin_base import PipelinePluginBase + +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] + +PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + + +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + +class HybridParallelModule(ModelWrapper, AMPModelMixin): + def __init__( + self, + module: Module, + precision: str, + shard_config: ShardConfig, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + sp_group: ProcessGroup, + use_ddp: bool, + ddp_config: dict, + custom_policy: Policy, + overlap_allgather: bool = False, + ) -> None: + self.stage_manager = shard_config.pipeline_stage_manager + self.shard_config = shard_config + self.dp_group = dp_group + self.tp_group = tp_group + self.sp_group = sp_group + self.use_ddp = use_ddp + self.require_grad_sync = True + self.overlap_allgather = overlap_allgather + + shardformer = ShardFormer(shard_config) + if custom_policy is not None: + assert isinstance(custom_policy, object) + module, self.shared_params = shardformer.optimize(module, policy=custom_policy) + + # setting process groups for shared parameters + self.shared_param_process_groups = [] + for shared_param in self.shared_params: + if len(shared_param) > 0: + self.shared_param_process_groups.append( + self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) + ) + + # setting mixed_precision + self.mixed_precision = None + if precision == "fp16": + self.mixed_precision = torch.float16 + elif precision == "bf16": + self.mixed_precision = torch.bfloat16 + if self.mixed_precision is not None: + module = module.to(self.mixed_precision) + module = module.to(get_accelerator().get_current_device()) + + # setting input type cast when using mixed precision + self.convert_fn = None + if self.mixed_precision is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) + + # setting ddp configs + if use_ddp: + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + + super().__init__(module) + if overlap_allgather: + self.op_hook = ZeroOpHook() + for p in module.parameters(): + if p.requires_grad and type(p) is not ColoParameter: + p.__class__ = ColoParameter + p.__init__(p, requires_grad=True) + + def sync_shared_params(self): + for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): + if self.stage_manager.stage in shared_param: + param = shared_param[self.stage_manager.stage] + dist.all_reduce(param.grad, group=group) + dist.barrier() + + @contextmanager + def no_sync(self): + r""" + A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization + when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass + when exiting the context. + """ + + # Store the current value of 'require_grad_sync' to restore it later. + old_require_grad_sync = self.require_grad_sync + # Disable automatic gradient synchronization. + self.require_grad_sync = False + try: + if self.use_ddp: + # If using data parallel processing (use_ddp), disable synchronization too. + with self.module.no_sync(): + yield + else: + yield + finally: + # Restore the original value of 'require_grad_sync'. + self.require_grad_sync = old_require_grad_sync + + def sync_dp_grads(self): + r""" + Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1. + This function performs an all-reduce operation to combine gradients from different devices in the DP group. + + Args: + None + + Returns: + None + """ + + # Check if the DP group size is 1, meaning no synchronization is needed. + if self.dp_group.size() == 1: + return + + # Iterate through the model's parameters and perform gradient synchronization. + for p in self.module.parameters(): + if p.grad is not None: + # Perform all-reduce to combine gradients from different devices. + dist.all_reduce(p.grad, group=self.dp_group) + # Normalize the gradient by dividing it by the DP group size. + p.grad.div_(self.dp_group.size()) + + def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): + r""" + Synchronize gradients that are partially derived within sequence parallelism + if sequence parallelism is enabled. Gradients can be provided explicitly or extracted + from the module. + + Args: + grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not + provided, gradients will be extracted from the model. + + Returns: + None + """ + + if self.shard_config.enable_sequence_parallelism: + if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: + return + + if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized + # across the tensor parallelism group. + group = self.tp_group + else: + raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") + + if grads is not None: + # Synchronize provided gradient tensors across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads) + else: + # Synchronize gradients from the model across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module) + + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + with self._wait_all_gather(): + return super().forward(*args, **kwargs) + + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + + def _force_wait_all_gather(self): + for p in self.module.parameters(): + wait_all_gather_handle(p) + + def _wait_all_gather(self): + return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + + +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. + + if optim is None: + return {} + param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} + start_index = 0 + for group in optim.param_groups: + packed_group = {k: v for k, v in group.items() if k != "params"} + packed_group["params"] = [] + + for param_id, param in enumerate(group["params"], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group["params"].append(param_id) + param_info["param2id"][id(param)] = param_id + param_info["id2param"][param_id] = id(param) + param_info["param2shape"][id(param)] = original_shape + + param_info["param_groups"].append(packed_group) + start_index += len(group["params"]) + + return param_info + + +def reinitialize_optimizer(optim: Optimizer, model: Module): + model_params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group["params"] if p in model_params] + new_param_groups.append({**group, "params": params}) + optim.__setstate__({"param_groups": new_param_groups}) + + +class HybridParallelNaiveOptimizer(OptimizerWrapper): + def __init__( + self, + optim: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + ): + self.param_info = param_info + if use_pipeline: + reinitialize_optimizer(optim, model) + self.model = model + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.max_norm = max_norm + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + super().__init__(optim) + + def backward(self, loss: Tensor, *args, **kwargs): + r""" + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs backward pass for gradient computation. If sequence parallelism is enabled + and gradient synchronization is required, it will synchronize gradients that are partially derived + within sequence parallelism across tp parallelism groups. + + Args: + loss (Tensor): The loss tensor to compute gradients with respect to. + *args: Additional positional arguments to be passed to the superclass backward method. + **kwargs: Additional keyword arguments to be passed to the superclass backward method. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward(loss, *args, **kwargs) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation using a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across tp parallelism groups. + + Args: + tensor (Tensor): The input tensor for which gradients are computed. + grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def step(self, *args, **kwargs): + r""" + Perform an optimization step. + + Args: + *args: Variable-length positional arguments to be passed to the optimizer's step function. + **kwargs: Keyword arguments to be passed to the optimizer's step function. + """ + + if self.max_norm > 0: + # Compute the total gradient norm. + param_gradient_pairs = [ + (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None + ] + total_norm = self._compute_grad_norm(param_gradient_pairs) + + # Clip the gradients to prevent exploding gradients. + self._clip_grad_norm(total_norm) + + # Perform the optimization step using the underlying optimizer. + self.optim.step(*args, **kwargs) + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + + if len(param_gradient_pairs) == 0: + return 0.0 + + norm_type = float(norm_type) + + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if self.pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + total_norm = total_norm_cuda.item() + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if self.tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= self.tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if self.pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + if grad is stage_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if self.pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + def _clip_grad_norm(self, total_norm: float) -> None: + r""" + Clips the gradients of the model's parameters to prevent exploding gradients. + + Args: + total_norm (float): The computed total gradient norm. + + Returns: + None + """ + clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6)) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + + for group in self.optim.param_groups: + for p in group["params"]: + if p.grad is None: + continue + p.grad.data.mul_(clip_coef_clamped) + + def update_master_params(self, model: Module): + pass + + def get_working_to_master_map(self): + return None + + def get_master_to_working_map(self): + return None + + +class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): + def __init__( + self, + optim: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + ): + self.model = model + self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + if use_pipeline: + reinitialize_optimizer(optim, model) + super().__init__( + optim, + precision=precision, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + max_norm=max_norm, + ) + + def backward(self, loss: Tensor, *args, **kwargs): + r""" + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs backward pass for gradient computation. If sequence parallelism is enabled + and gradient synchronization is required, it will synchronize gradients that are partially derived + within sequence parallelism across tp parallelism groups. + + Args: + loss (Tensor): The loss tensor to compute gradients with respect to. + *args: Additional positional arguments to be passed to the superclass backward method. + **kwargs: Additional keyword arguments to be passed to the superclass backward method. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + super().backward(loss, *args, **kwargs) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation using a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across tp parallelism groups. + + Args: + tensor (Tensor): The input tensor for which gradients are computed. + grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + if len(param_gradient_pairs) == 0: + return 0.0 + + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we need to calculate the norm of 'tp' and 'pp' gradients. + total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) + + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + + if self.tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if self.pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if self.tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= self.tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if self.pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_working_shared_param = shared_param[self.stage_manager.stage] + stage_master_shared_param = self.working_to_master_map[stage_working_shared_param] + if grad is stage_master_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if self.pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + def __init__( + self, + optimizer: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + forced_dtype: Optional[torch.dtype] = None, + overlap_allgather: bool = False, + ): + self.model = model + self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + if use_pipeline: + reinitialize_optimizer(optimizer, model) + super().__init__( + optimizer=optimizer, + initial_scale=initial_scale, + min_scale=min_scale, + pg_to_param_list=pg_to_param_list, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, + dp_process_group=dp_process_group, + forced_dtype=forced_dtype, + overlap_allgather=overlap_allgather, ) + def sync_dp_grads(self): + r""" + Synchronize gradients in the data parallelism dimension. + + This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients + in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions, + namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization + and readability. + + Args: + None + + Returns: + None + """ + # Call the superclass `_sync_grad` method to synchronize gradients. + super()._sync_grad() + + def _sync_sp_grads(self): + r""" + Synchronize gradients that are partially derived within sequence parallelism. + + This method is responsible for synchronizing partially derived gradients across tp parallelism groups. + It identifies gradients that ara partially derived or not and synchronizes them. + If synchronization is required and gradients are found to be synchronized, + it performs the synchronization. + + Args: + None + + Returns: + None + """ + + def _get_all_working_grads() -> List[Tensor]: + """Retrieve all working gradients from different parameter groups.""" + all_working_grads = [] + for group_id in range(self.num_param_groups): + working_grads = self.get_working_grads_by_group_id(group_id) + all_working_grads.extend(working_grads) + return all_working_grads + + def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: + """Identify gradients to be synchronized in the sequence parallelism.""" + grads_to_sync = [] + for grad in all_working_grads: + param_id_for_grad = self.get_param_id_for_grad(grad) + param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value + if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad): + grads_to_sync.append(grad) + + if len(grads_to_sync) > 0: + return grads_to_sync + else: + return None + + # Get all working gradients and gradients to be synchronized. + all_working_grads = _get_all_working_grads() + grads_to_sync = _get_grads_to_sync(all_working_grads) + if self.require_grad_sync and grads_to_sync is not None: + # Synchronize sequence parallelism gradients if required. + SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) + else: + return + + def backward(self, loss, retain_graph=False): + """ + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs the backward pass for gradient computation based on a given loss tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across TP parallelism groups. + + Args: + loss: The loss tensor to compute gradients with respect to. + retain_graph (bool): Whether to retain the computation graph. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + super().backward(loss, retain_graph) + + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + # If gradient synchronization is required, sync sequence parallelism gradients. + self._sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor, grad): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation based on a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across TP parallelism groups. + + Args: + tensor: The input tensor for which gradients are computed. + grad: The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + # Call the superclass backward_by_grad method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + # If gradient synchronization is required, sync sequence parallelism gradients. + self._sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + gradients (List[Tensor]): A list of tensors containing gradients. + norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2. + + Returns: + float: The computed gradient norm. + """ + + # Check if the list of gradients is empty + if len(gradients) == 0: + return 0.0 + + dp_size = get_world_size(dp_pg) if dp_pg is not None else 1 + tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we only need to calculate the norm 'tp' of 'pp' gradients. + total_norm = super()._compute_grad_norm(gradients, norm_type) + + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + + if tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if tp_size > 1: + param_id_for_grad = self.get_param_id_for_grad(grad) + param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value + + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + working_grad = self.get_working_grad_by_param_id(id(stage_shared_param)) + if grad is working_grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if dp_size > 1: + # compute norm in dp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg) + if tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # Compute the 'total_norm' from 'total_norm_exponentiated' + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm -@parameterize("seqlen", [4096]) -@parameterize("bs", [2]) -@parameterize("nheads", [5]) -@parameterize("d", [128]) -@parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_packed_seq(seqlen, bs, nheads, d, dtype): - device = get_current_device() - sp_group = dist.group.WORLD - sp_size = dist.get_world_size() - atol = rtol = 7e-3 - torch.cuda.manual_seed(2) - # Prepare varlen attention mask - padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) - padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 - padding_mask[:, seqlen // 2 :] = 0 - - input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) - - # Forward - # out = ColoAttention.attention(q, k, v, **mask_info) - flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] - qkv = torch.stack([flat_input] * 3, dim=1) - qkv.retain_grad() - - input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds) - out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, - mask_info["cu_seqlens"] * sp_size, - mask_info["max_seqlen"] * sp_size, - return_attn_probs=True, - causal=True, - # deterministic=True - ) - # Test the splitting function - local_input = split_varlen_zigzag( - flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size - ) - assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all() - del local_input, flat_input - - q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)] - q_ring.retain_grad() - k_ring.retain_grad() - v_ring.retain_grad() - - ring_out, ring_lse = RingAttention.attention( - q_ring, - k_ring, - v_ring, - sp_group, - **mask_info, - pad_output=False, - return_softmax=True, - # deterministic=True - ) - ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) - # Check output - lse = lse.transpose(0, 1) - out, lse = split_varlen_zigzag( - [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size - ) - assert_close(lse, ring_lse, atol=atol, rtol=rtol) - assert_close(out, ring_out, atol=atol, rtol=rtol) - - # Check grads - labels = torch.ones(out.shape[0], dtype=dtype, device=device) - F.mse_loss(out.sum((-2, -1)), labels).backward() - F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward() - dq, dk, dv = [ - split_varlen_zigzag( - qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size + +class HybridParallelPlugin(PipelinePluginBase): + """ + Plugin for Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import HybridParallelPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + ``` + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + sp_size (int): The size of sequence parallelism. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + microbatch_size (int, optional): Microbatch size when using pipeline parallelism. + Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. + If ``num_microbatches`` is provided, this will be ignored. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. + ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. + zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. + pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. + num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. + gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. + enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. + make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. + overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism + inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn". + It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default. + + """ + + def __init__( + self, + tp_size: int, + pp_size: int, + sp_size: int = None, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, + enable_sequence_overlap: bool = False, + parallel_output: bool = True, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, + enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, + dp_outside: bool = True, + overlap_p2p: bool = True, + overlap_allgather: bool = False, + inner_ring_size: int = None, + ) -> None: + super().__init__() + + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + + if enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + warnings.warn( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + ) + self.sp_size = 1 + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: + self.sp_size = 1 if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) + if self.sequence_parallelism_mode == "ring_attn": + enable_flash_attention = True + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert ( + sp_size == 1 or sp_size is None + ), f"You should not set sp_size when sequence parallelism is not enabled." + self.sp_size = 1 + + self.tp_size = tp_size + self.pp_size = pp_size + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + if dp_outside: + self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + if sequence_parallelism_mode == "ring_attn": + # Swap tp and sp since 2D Ring has better inter-node latency + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + else: + self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + if sequence_parallelism_mode == "ring_attn": + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) + + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" + assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" + assert ( + self.zero_stage <= 1 + ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + self.stage_manager = PipelineStageManager( + self.pg_mesh, + pipeline_axis=self.pp_axis, + enable_interleave=pp_style == "interleaved", + num_model_chunks=num_model_chunks, + num_layers_per_stage=num_layers_per_stage, + ) + + if pp_style == "interleaved": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, + ) + elif pp_style == "1f1b": + self.schedule = OneForwardOneBackwardSchedule( + stage_manager=self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + ) + else: + raise NotImplementedError() + if sequence_parallelism_mode == "ring_attn": + assert parallel_output, "Ring Attention doesn't support gathering output yet." + + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + else: + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, + enable_sequence_overlap=enable_sequence_overlap, + parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, + gradient_checkpoint_config=gradient_checkpoint_config, + inner_ring_size=inner_ring_size, + ) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + forced_dtype=PRECISION_TORCH_TYPE[precision], + overlap_allgather=overlap_allgather, ) - for i in range(3) - ] - dq_ring, dk_ring, dv_ring = [ - x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] - for x in (q_ring.grad, k_ring.grad, v_ring.grad) - ] - assert_close(dq, dq_ring, atol=atol, rtol=rtol) - assert_close(dk, dk_ring, atol=atol, rtol=rtol) - assert_close(dv, dv_ring, atol=atol, rtol=rtol) + self.max_norm = max_norm + + def __del__(self): + """Destroy the process groups in ProcessGroupMesh""" + self.pg_mesh.destroy_mesh_process_groups() + + @property + def enable_pipeline_parallelism(self) -> bool: + return self.pp_size > 1 + + def supported_devices(self) -> List[str]: + return ["cuda", "npu"] + + def supported_precisions(self) -> List[str]: + return ["fp16", "bf16", "fp32"] + + def control_device(self) -> bool: + return True + + def control_precision(self) -> bool: + return True + def support_no_sync(self) -> bool: + return True -def launch_single_ring(rank, world_size, port): - colossalai.launch(rank, world_size, "localhost", port) - check_packed_seq() - check_ring_attn() + def support_lora(self) -> bool: + return True + def control_checkpoint_io(self) -> bool: + return True -def launch_double_ring(rank, world_size, port): - colossalai.launch(rank, world_size, "localhost", port) - check_ring_attn() + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + + # TODO: Support Galore + ZeRO + zero_stage = self.zero_stage + zero_config = deepcopy(self.zero_config) + + # Replace with distributed implementation if exists + optimizer = cast_to_distributed(optimizer) + + if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: + warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") + zero_config["partition_grad"] = False + zero_stage = 0 + + if not isinstance(model, ModelWrapper): + # Shouldn't use pp (frequent grad accumulation) with torch ddp + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 and self.pp_size == 1 + ) + + # Apply Hybrid ZeRO across DP * SP ranks + if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): + dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.dp_size = get_world_size(dp_group) + else: + dp_group = self.dp_group + model = HybridParallelModule( + model, + precision=self.precision, + shard_config=self.shard_config, + dp_group=dp_group, + tp_group=self.tp_group, + sp_group=self.sp_group, + use_ddp=use_ddp, + ddp_config=self.ddp_config, + custom_policy=self.custom_policy, + overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), + ) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if zero_stage == 0: + is_zero = False + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, + **self.amp_config, + ) + else: + optimizer = HybridParallelNaiveOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, + ) + else: + is_zero = self.dp_size > 1 + if self.dp_size == 1: + warnings.warn( + "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " + "If you do not intend to use cpu_offload, please consider set zero_stage=0." + ) + + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." + optimizer = HybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=dp_group, + tp_process_group=self.tp_group, + pp_process_group=self.pp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **zero_config, + **self.amp_config, + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) + + # Setup optimizers that require global states + optim = optimizer.optim + if isinstance(optim, DistributedOptim): + shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} + padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) + optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def execute_pipeline( + self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[ + Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer] + ] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: + assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" + + if return_outputs: + warnings.warn("return_outputs may lead to significant extra memory consumption.") + + # Create a context for gradient synchronization based on the optimizer type. + # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). + # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once), + # so we disable it, performing manual reduction instead. + ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + + with ctx, model._wait_all_gather(): + outputs = self.schedule.forward_backward_step( + model, data_iter, criterion, optimizer, return_loss, return_outputs + ) + + # run with gradients accumulation + if ( + model.require_grad_sync == False + or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) + or not torch.is_grad_enabled() + ): + return outputs + + # Synchronize the grads of shared parameters of the model. + model.sync_shared_params() + # Synchronize sequence parallelism gradients of the model. + model.sync_sp_grads() + + # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so. + # Otherwise, synchronize data parallelism gradients of the model. + # This is because these are two different forms of data parallelism. + if isinstance(optimizer, HybridParallelZeroOptimizer): + optimizer.sync_dp_grads() + else: + model.sync_dp_grads() + + return outputs + + def prepare_dataloader( + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns:` + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( + dataset, + num_replicas=self.dp_group.size(), + rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()), + shuffle=shuffle, + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + def get_checkpoint_io(self) -> CheckpointIO: + return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) -@rerun_if_address_is_in_use() -@parameterize("world_size", [2]) -def test_ring_attn(world_size): - spawn(launch_single_ring, nprocs=world_size) + def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert ( + self.zero_stage != 2 + ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." + return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + def enable_lora( + self, + model: Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, + ) -> Module: + from peft import PeftModel, get_peft_model -@rerun_if_address_is_in_use() -@parameterize("world_size", [4]) -def test_double_ring(world_size): - spawn(launch_double_ring, nprocs=world_size) + assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." + assert self.pp_size == 1 and self.tp_size == 1 + self.lora_enabled = True + warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) -if __name__ == "__main__": - test_ring_attn() - test_double_ring() \ No newline at end of file + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model \ No newline at end of file diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 8a02d121c95e..e4640b01021b 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -517,14 +517,6 @@ def attention( assert ( attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES ), f"Mask type {attention_mask_type} is not supported yet." - if dkv_group is None: - if RingAttention.DKV_GROUP is None or dist.get_process_group_ranks( - sp_group - ) != dist.get_process_group_ranks(RingAttention.DKV_GROUP): - ranks = dist.get_process_group_ranks(sp_group) - RingAttention.DKV_GROUP = dkv_group = dist.new_group(ranks) - else: - dkv_group = RingAttention.DKV_GROUP clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) @@ -831,7 +823,6 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse): else: out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse) - # torch.cuda.current_stream().wait_stream(sp_stream) out = out.to(q.dtype) if not is_packed: out = out.view(b, sq, h, d) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 589ed730ec79..70eb271c9b69 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from typing import Any, Dict, Optional -import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -55,7 +54,6 @@ class ShardConfig: # for moe related moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None - sp_stream: Optional[torch.cuda.Stream] = None # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index cbf497c1f8c5..093377e7a034 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -332,11 +332,8 @@ def empty_init(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] -<<<<<<< HEAD del outputs # free memory -======= ->>>>>>> precision tests passed if dist.get_rank() == dist.get_world_size() - 1: print(f"Step {step} loss: {loss}") booster.backward(loss, optimizer) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 13b40f97d7ee..ce5820cf48d0 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -39,7 +39,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): ring_out, ring_lse = RingAttention.attention( q, k, - v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, @@ -47,9 +46,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) - out, lse, _ = flash_attn_qkvpacked_func( - qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True - ) # Checkout out and softmax denominator local_out = split_batch_zigzag(out, sp_group) @@ -183,4 +179,4 @@ def test_double_ring(world_size): if __name__ == "__main__": test_ring_attn() - test_double_ring() \ No newline at end of file + test_double_ring() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 0adaccbca68e..81a7b62fe57e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -153,7 +153,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # Zigzag Ring Attention + PP + # # Double Ring Attention + # { + # "tp_size": 1, + # "pp_size": 1, + # "sp_size": 4, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "bf16", + # "initial_scale": 1, + # }, + # Ring Attention + PP { "tp_size": 1, "pp_size": 2, From 8bc062d0c3203641e1d0be7f928c4d7b5725e131 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 13 Aug 2024 14:47:26 +0000 Subject: [PATCH 48/71] 2d ring backward passed --- colossalai/shardformer/layer/attn.py | 3 +++ tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index e4640b01021b..0f9fb0555e40 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -24,6 +24,7 @@ _flash_attn_forward = _flash_attn_backward = None _unpad_input = _pad_input = None +logger = get_dist_logger() class AttnMaskType(Enum): @@ -569,6 +570,7 @@ def attention( attention_mask_type == AttnMaskType.PADDED_CAUSAL, inner_ring_group, inter_ring_group, + inter_ring_group_copy, ) if attention_mask_type == AttnMaskType.PADDED_CAUSAL: @@ -599,6 +601,7 @@ def forward( is_packed: Optional[bool] = False, inner_ring_group: Optional[dist.ProcessGroup] = None, inter_ring_group: Optional[dist.ProcessGroup] = None, + inter_ring_group_copy: Optional[dist.ProcessGroup] = None, ): cu_seqlens_q = cu_seqlens_kv = cu_seqlens diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 81a7b62fe57e..7c93717ceee4 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -163,7 +163,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # "sequence_parallelism_mode": "ring_attn", # "use_lazy_init": True, # "zero_stage": 2, - # "precision": "bf16", + # "precision": "fp32", # "initial_scale": 1, # }, # Ring Attention + PP From d844ded39b42743a402f418dea0f7a0dc49aa2b3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 14 Aug 2024 06:03:00 +0000 Subject: [PATCH 49/71] fixes --- colossalai/shardformer/modeling/llama.py | 2 -- tests/test_shardformer/test_model/test_shard_llama.py | 10 +++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 219933c705e9..11875e90e595 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -826,8 +826,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if shard_config.sequence_parallelism_mode == "ring_attn": - labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: # Special processing: Split labels in a zigzag fashion too diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 7c93717ceee4..1a2665a0b9a4 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -153,7 +153,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # # Double Ring Attention + # Double Ring Attention # { # "tp_size": 1, # "pp_size": 1, @@ -162,8 +162,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # "enable_sequence_parallelism": True, # "sequence_parallelism_mode": "ring_attn", # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp32", + # "zero_stage": 0, + # "precision": "fp16", # "initial_scale": 1, # }, # Ring Attention + PP @@ -176,7 +176,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, "zero_stage": 1, - "precision": "bf16", + "precision": "fp16", "initial_scale": 1, }, # Ring Attention + TP @@ -189,7 +189,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, "zero_stage": 2, - "precision": "bf16", + "precision": "fp16", "initial_scale": 1, }, { # Ulysess + TP From 8c012231ec6db66e6a690b601688cbec60e479c8 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 14 Aug 2024 08:02:01 +0000 Subject: [PATCH 50/71] fix ring attn loss --- colossalai/shardformer/layer/attn.py | 2 -- colossalai/shardformer/layer/utils.py | 5 +---- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 0f9fb0555e40..30e74e0ca797 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -570,7 +570,6 @@ def attention( attention_mask_type == AttnMaskType.PADDED_CAUSAL, inner_ring_group, inter_ring_group, - inter_ring_group_copy, ) if attention_mask_type == AttnMaskType.PADDED_CAUSAL: @@ -601,7 +600,6 @@ def forward( is_packed: Optional[bool] = False, inner_ring_group: Optional[dist.ProcessGroup] = None, inter_ring_group: Optional[dist.ProcessGroup] = None, - inter_ring_group_copy: Optional[dist.ProcessGroup] = None, ): cu_seqlens_q = cu_seqlens_kv = cu_seqlens diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 31d0a3b822da..b343998082e6 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -377,10 +377,7 @@ def split_varlen_zigzag( assert max_seqlen % (sp_size * 2) == 0 # Recreate a padded tensor with the new max seqlen shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) - if is_label: - local_seq = torch.full(shape, -100, dtype=dtype, device=device) - else: - local_seq = torch.zeros(shape, dtype=dtype, device=device) + local_seq = torch.zeros(shape, dtype=dtype, device=device) else: total_seqlen = cu_seqlens[-1] assert ( From 7a7fb1f5d5a870dd4bd239188629dd8e31a1bac0 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 14 Aug 2024 09:13:31 +0000 Subject: [PATCH 51/71] 2D ring backward + llama passed --- .../test_model/test_shard_llama.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 1a2665a0b9a4..904d7281f9d3 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -154,18 +154,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "test_config", [ # Double Ring Attention - # { - # "tp_size": 1, - # "pp_size": 1, - # "sp_size": 4, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 4, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, # Ring Attention + PP { "tp_size": 1, From 78ed55dc1c709b2ebd07783a45b147dd9165ada3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 14 Aug 2024 13:15:03 +0000 Subject: [PATCH 52/71] merge --- colossalai/shardformer/layer/attn.py | 10 +++++ colossalai/shardformer/layer/loss.py | 5 +-- colossalai/shardformer/modeling/bloom.py | 25 ++++++------ colossalai/shardformer/modeling/chatglm2.py | 39 +++++++++++++++---- colossalai/shardformer/modeling/command.py | 35 ++++++++++------- colossalai/shardformer/modeling/gpt2.py | 16 +++++--- colossalai/shardformer/modeling/llama.py | 12 +++--- colossalai/shardformer/modeling/mistral.py | 14 +++---- colossalai/shardformer/modeling/opt.py | 23 +++++------ colossalai/shardformer/modeling/qwen2.py | 12 +++--- colossalai/shardformer/policies/chatglm2.py | 2 +- .../test_model/test_shard_chatglm2.py | 30 ++++++-------- .../test_model/test_shard_command.py | 6 ++- 13 files changed, 137 insertions(+), 92 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 30e74e0ca797..276c7c576c5d 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -656,6 +656,11 @@ def forward( if sp_rank != sp_size - 1: q1 = q[half_idx_back] + # Non-contiguous indexing creates a new contiguous tensor, + # so only do it once + if sp_rank != sp_size - 1: + q1 = q[half_idx_back] + # Pre-allocate double buffer for overlapping and receiving next step's inputs kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) kv_buffers.append(torch.empty_like(kv_buffers[0])) @@ -891,6 +896,7 @@ def backward(ctx, dout, _): local_sp_rank = dist.get_rank(sp_group) sp_size = dist.get_world_size(sp_group) + sp_rank = dist.get_rank(sp_group) # Using separate streams (pg) for concurrent kv and dkv comm may # cause NCCL "software caused connection abort" here... local_kv_comm = RingComm(local_kv_group) @@ -911,6 +917,10 @@ def backward(ctx, dout, _): softmax_lse1 = softmax_lse[:, half_idx_back] dout = dout.contiguous() + if sp_rank != sp_size - 1: + softmax_lse1 = softmax_lse[:, half_idx_back] + dout = dout.contiguous() + # Double comm buffers for sending and receiving kv kv_buffers = [torch.stack((k, v))] # (2, T, H, D) kv_buffers.append(torch.empty_like(kv_buffers[0])) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 952693ec2665..952551634eca 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -153,7 +153,6 @@ def dist_cross_entropy( labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] logits: torch.Tensor, # [B, S, Vocab_size] shard_config: ShardConfig, - out_features: int, vocab_size: int, dtype: torch.dtype, seq_dim: int = 1, @@ -226,13 +225,13 @@ def dist_cross_entropy( logits, labels, process_group=shard_config.tensor_parallel_process_group, - vocab_size=out_features, + vocab_size=vocab_size, dtype=dtype, mode="sum", ) else: # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D - logits = logits.view(-1, vocab_size) + logits = logits.view(-1, logits.size(-1)) loss = loss_fct(logits, labels) # Reduce loss instead of gathering logits over seq dim for savings diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 26ffef6c5ee0..daa2296dd338 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -359,14 +359,15 @@ def bloom_for_causal_lm_forward( hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states).contiguous() - loss = dist_cross_entropy( - labels, - lm_logits, - shard_config, - self.lm_head.out_features, - self.config.vocab_size, - self.transformer.dtype, - ) + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, + lm_logits, + shard_config, + self.lm_head.out_features, + self.transformer.dtype, + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -1024,9 +1025,11 @@ def forward( hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 5be4b9d78e11..2e12e78378ef 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -183,6 +183,15 @@ def chatglm_model_forward( if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Support SP + PP + sp_size = shard_config.sequence_parallel_size + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + # For generating full positions ids (the states will be gathered along the seq dim before attention fwd). + if sp_mode != "ring_attn" and not stage_manager.is_first_stage(): + seq_length *= sp_size + # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) if position_ids is not None: @@ -206,11 +215,11 @@ def chatglm_model_forward( # Keep the input split across all PP stages if stage_manager.is_first_stage(): if shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": + if sp_mode == "split_gather": hidden_states = split_forward_gather_backward( hidden_states, dim=0, - process_group=shard_config.tensor_parallel_process_group, + process_group=sp_group, ) elif shard_config.sequence_parallelism_mode == "all_to_all": hidden_states = split_forward_gather_backward( @@ -255,7 +264,9 @@ def chatglm_model_forward( # Gather seq-wise in the final output stage sp_mode = shard_config.sequence_parallelism_mode if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) + hidden_states = gather_sp_output( + hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0 + ) if not return_dict: return tuple( @@ -321,9 +332,21 @@ def chatglm_for_conditional_generation_forward( hidden_states = hidden_states[-1:] lm_logits = self.transformer.output_layer(hidden_states) lm_logits = lm_logits.transpose(0, 1).contiguous() - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, lm_logits.dtype - ) + + loss = None + if labels is not None: + # ChatGLM doesn't have lm_head split + enable_tp = shard_config.enable_tensor_parallelism + shard_config.enable_tensor_parallelism = False + loss = dist_cross_entropy( + labels, + lm_logits, + shard_config, + self.transformer.output_layer.out_features, + lm_logits.dtype, + ) + shard_config.enable_tensor_parallelism = enable_tp + if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output @@ -424,7 +447,9 @@ def forward( ) if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) + hidden_states = gather_sp_output( + hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0 + ) if not return_dict: return tuple( diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index cac325dcbea6..bdcf6f0a2f69 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -93,10 +93,16 @@ def command_model_forward( if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() + + # NOTE: For generating full positions ids + # (the states will be gathered along the seq dim before attention fwd). + if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage(): + seq_length *= shard_config.sequence_parallel_size + if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device) seq_length_with_past = seq_length + past_seen_tokens @@ -136,7 +142,7 @@ def command_model_forward( ) use_cache = False - if shard_config.enable_sequence_parallelism: + if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: hidden_states = split_forward_gather_backward( hidden_states, @@ -320,9 +326,10 @@ def command_for_causal_lm_forward( logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] @@ -659,14 +666,16 @@ def forward( logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - loss = dist_cross_entropy( - labels, - logits, - shard_config, - self.lm_head.out_features, - self.config.vocab_size, - self.model.dtype, - ) + + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.model.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 6ecda91c4d35..db38c9a0ec33 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -372,9 +372,11 @@ def gpt2_lmhead_model_forward( hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1264,9 +1266,11 @@ def forward( hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 11875e90e595..55103bded04e 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -346,9 +346,9 @@ def llama_for_causal_lm_forward( if stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] @@ -859,9 +859,9 @@ def forward( else: logits = self.lm_head(hidden_states) logits = logits.float() - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index ec1a8a00a58a..7fc6a1062037 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -274,10 +274,9 @@ def mistral_for_causal_lm_forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] @@ -687,10 +686,9 @@ def forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 636b46cc461d..3ea4db9e2f70 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -330,14 +330,15 @@ def opt_for_causal_lm_forward( ) if stage_manager.is_last_stage(): logits = self.lm_head(outputs[0]).contiguous() - loss = dist_cross_entropy( - labels, - logits, - shard_config, - self.lm_head.out_features, - self.config.vocab_size, - self.model.decoder.dtype, - ) + loss = None + if labels is not None: + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.model.decoder.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] @@ -955,9 +956,9 @@ def forward( ) logits = self.lm_head(outputs[0]).contiguous() - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.decoder.dtype) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index d44c7382fdf6..353ccc2f5947 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -350,9 +350,9 @@ def qwen2_for_causal_lm_forward( if hidden_states.shape[1] == 2: pass logits = self.lm_head(hidden_states) - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) if not return_dict: output = (logits,) + outputs[1:] @@ -824,9 +824,9 @@ def forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - loss = dist_cross_entropy( - labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype - ) + loss = None + if labels is not None: + loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 3877bdac3ae2..f99d1ef819b7 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -64,7 +64,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if sp_mode == "ring": warnings.warn( - f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather" ) sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 92c077950ecc..17a8bf318976 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -136,26 +136,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { # Ulysess + Flash attention - "tp_size": 1, + { + "tp_size": 2, "pp_size": 2, - "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", + "sequence_parallelism_mode": "split_gather", "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 2, + { # Ulysess + Flash attention + "tp_size": 1, "pp_size": 2, "sp_size": 2, "num_microbatches": 2, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", + "sequence_parallelism_mode": "all_to_all", "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, @@ -174,17 +173,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp32", - "initial_scale": 1, - }, { "tp_size": 4, "pp_size": 1, @@ -248,7 +236,11 @@ def run_chatglm_test(test_config): loss_fn, _, ) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Test config failed for model {name}: {test_config}") + raise e clear_layout_converter() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 2e6997597928..9435ef84bfa8 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -281,7 +281,11 @@ def run_command_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed test config: {test_config}") + raise e clear_layout_converter() Randomizer.reset_index() From 70b1f5d84783057316e9a4a94883b8ce8503ecd0 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 15 Aug 2024 03:29:42 +0000 Subject: [PATCH 53/71] update logger --- colossalai/shardformer/layer/attn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 276c7c576c5d..56a039b3f9f9 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -24,7 +24,6 @@ _flash_attn_forward = _flash_attn_backward = None _unpad_input = _pad_input = None -logger = get_dist_logger() class AttnMaskType(Enum): @@ -434,7 +433,6 @@ def get_double_ring_groups(sp_group, inner_ring_size=None): assert ( sp_size % inner_ring_size == 0 ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" - logger = get_dist_logger() logger.info( f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", From f91586f05207e6bce4afdefee392fb2e109c2812 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 15 Aug 2024 03:52:39 +0000 Subject: [PATCH 54/71] fix typo --- colossalai/shardformer/modeling/llama.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 55103bded04e..6ba6665cfe7b 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -827,15 +827,6 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: - # Special processing: Split labels in a zigzag fashion too - sp_group = shard_config.sequence_parallel_process_group - if attention_mask.bool().all(): - labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) - else: - # [B, max_seq_len // sp_size] - labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, From 2ce53f13f0d7c5f438adbe67e4cc717ffcc71ae1 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 16 Aug 2024 09:28:18 +0000 Subject: [PATCH 55/71] rebase --- tests/kit/model_zoo/transformers/llama.py | 3 --- .../test_model/test_shard_llama.py | 19 ------------------- 2 files changed, 22 deletions(-) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index db69c9818411..943c5cf1c58e 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -80,7 +80,6 @@ def data_gen_for_causal_lm(): data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, loss_fn=loss_fn_for_causal_lm, -<<<<<<< HEAD model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( @@ -97,8 +96,6 @@ def data_gen_for_causal_lm(): data_gen_fn=data_gen, output_transform_fn=output_transform_fn, loss_fn=loss_fn, -======= ->>>>>>> precision tests passed model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 904d7281f9d3..e8f7916972a5 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -214,22 +214,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "sequence_parallelism_mode": "all_to_all", "enable_all_optimization": True, "use_lazy_init": True, -<<<<<<< HEAD "zero_stage": 1, -======= - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, ->>>>>>> precision tests passed "precision": "fp16", "initial_scale": 1, }, @@ -385,8 +370,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() -<<<<<<< HEAD test_llama_3d() -======= - # test_llama_3d() ->>>>>>> precision tests passed From 5fed9dac3aaa1668bb485b25bd3a55a714e745b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Aug 2024 09:30:41 +0000 Subject: [PATCH 56/71] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 2 +- colossalai/shardformer/layer/attn.py | 2 +- colossalai/shardformer/layer/loss.py | 2 +- colossalai/shardformer/layer/utils.py | 2 +- tests/test_shardformer/test_layer/test_ring_attn.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 6c785d4aed4d..63427192f482 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1458,4 +1458,4 @@ def enable_lora( peft_model = get_peft_model(model, lora_config) else: peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) - return peft_model \ No newline at end of file + return peft_model diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 56a039b3f9f9..35e85c454fbb 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -1162,4 +1162,4 @@ def prepare_varlen_batch( mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() mask_info["cu_seqlens"] //= sp_size mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL - return inputs_embeds, mask_info, position_ids \ No newline at end of file + return inputs_embeds, mask_info, position_ids diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 952551634eca..0e2241af9fc9 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -242,4 +242,4 @@ def dist_cross_entropy( loss = reduce_forward(loss, sp_group, grad_scale=sp_size) loss, num_nonzero = loss[0], loss[1].detach() loss = (loss / num_nonzero).squeeze() - return loss \ No newline at end of file + return loss diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index b343998082e6..c1a73ce05c97 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -484,4 +484,4 @@ def get_half_index(cu_seqlens, *, front: bool): else: start = (start + end) // 2 index[start:end] = True - return index \ No newline at end of file + return index diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index ce5820cf48d0..bd499d954fef 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -1,7 +1,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func +from flash_attn import flash_attn_varlen_qkvpacked_func from torch.testing import assert_close import colossalai From 8ec9009c88964f43a3b9310891d461ba5c5b70eb Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 16 Aug 2024 09:38:03 +0000 Subject: [PATCH 57/71] fix typo --- colossalai/shardformer/layer/_operation.py | 4 ---- colossalai/shardformer/layer/attn.py | 11 +---------- colossalai/shardformer/policies/command.py | 8 -------- colossalai/shardformer/policies/llama.py | 11 ----------- examples/language/opt/opt_benchmark.py | 1 + tests/kit/model_zoo/transformers/llama.py | 16 ---------------- .../test_layer/test_ring_attn.py | 6 +++++- 7 files changed, 7 insertions(+), 50 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e031fecc15e0..efe4d80babbb 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -812,11 +812,7 @@ def backward(ctx, *grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim - if torch.distributed.get_rank() == 0: - print(f"shape before A2A: {grad_output[0].shape}") return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) - if torch.distributed.get_rank() == 0: - print(f"shape after A2A: {return_grad.shape}") return (return_grad, None, None, None) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 35e85c454fbb..b11052257c5e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -654,11 +654,6 @@ def forward( if sp_rank != sp_size - 1: q1 = q[half_idx_back] - # Non-contiguous indexing creates a new contiguous tensor, - # so only do it once - if sp_rank != sp_size - 1: - q1 = q[half_idx_back] - # Pre-allocate double buffer for overlapping and receiving next step's inputs kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) kv_buffers.append(torch.empty_like(kv_buffers[0])) @@ -894,7 +889,7 @@ def backward(ctx, dout, _): local_sp_rank = dist.get_rank(sp_group) sp_size = dist.get_world_size(sp_group) - sp_rank = dist.get_rank(sp_group) + dist.get_rank(sp_group) # Using separate streams (pg) for concurrent kv and dkv comm may # cause NCCL "software caused connection abort" here... local_kv_comm = RingComm(local_kv_group) @@ -915,10 +910,6 @@ def backward(ctx, dout, _): softmax_lse1 = softmax_lse[:, half_idx_back] dout = dout.contiguous() - if sp_rank != sp_size - 1: - softmax_lse1 = softmax_lse[:, half_idx_back] - dout = dout.contiguous() - # Double comm buffers for sending and receiving kv kv_buffers = [torch.stack((k, v))] # (2, T, H, D) kv_buffers.append(torch.empty_like(kv_buffers[0])) diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 23e31f2e5dac..1efd3d0179af 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -292,15 +292,7 @@ class CommandForCausalLMPolicy(CommandPolicy): def module_policy(self): from transformers import CohereForCausalLM -<<<<<<< HEAD -<<<<<<< HEAD self.is_causal = True -======= - self.is_casual = True ->>>>>>> precision tests passed -======= - self.is_causal = True ->>>>>>> precision tests passed policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 5483719aa212..76b824d8dd14 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -71,22 +71,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_partial_derived = sp_mode in ["split_gather", "ring"] if sp_mode == "ring_attn" and not self.is_causal: raise ValueError("Ring attention is only meant for causal language modeling.") -<<<<<<< HEAD -<<<<<<< HEAD tp_size = self.shard_config.tensor_parallel_size # Modified by SP and TP num_q_heads = self.model.config.num_attention_heads num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) -======= ->>>>>>> precision tests passed - - tp_size = self.shard_config.tensor_parallel_size - # Modified by SP and TP - num_q_heads = self.model.config.num_attention_heads - num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) -======= ->>>>>>> precision tests passed tp_size = self.shard_config.tensor_parallel_size # Modified by SP and TP diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 5e5971d9f560..ca9b63d1a14a 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -96,6 +96,7 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, _, _ = booster.boost(model, optimizer) + SEQ_LEN = 1024 VOCAB_SIZE = 50257 diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 943c5cf1c58e..05ac9d8d24ed 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -90,22 +90,6 @@ def data_gen_for_causal_lm(): loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True), ) - model_zoo.register( - name="transformers_llama", - model_fn=lambda: transformers.LlamaModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), - ) - model_zoo.register( - name="transformers_llama", - model_fn=lambda: transformers.LlamaModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), - ) model_zoo.register( name="transformers_llama_for_sequence_classification", model_fn=lambda: transformers.LlamaForSequenceClassification(config), diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index bd499d954fef..1c7647a7d560 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -1,7 +1,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from flash_attn import flash_attn_varlen_qkvpacked_func +from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func from torch.testing import assert_close import colossalai @@ -39,6 +39,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): ring_out, ring_lse = RingAttention.attention( q, k, + v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, @@ -46,6 +47,9 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) + out, lse, _ = flash_attn_qkvpacked_func( + qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True + ) # Checkout out and softmax denominator local_out = split_batch_zigzag(out, sp_group) From d0aeec9ed58f007656adf383bfa75f9ad16eeadc Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 16 Aug 2024 11:35:32 +0000 Subject: [PATCH 58/71] remove typos --- colossalai/shardformer/layer/attn.py | 2 +- colossalai/shardformer/modeling/qwen2.py | 5 +---- colossalai/shardformer/policies/llama.py | 5 ----- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index b11052257c5e..0e9f373553ac 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -889,7 +889,7 @@ def backward(ctx, dout, _): local_sp_rank = dist.get_rank(sp_group) sp_size = dist.get_world_size(sp_group) - dist.get_rank(sp_group) + # Using separate streams (pg) for concurrent kv and dkv comm may # cause NCCL "software caused connection abort" here... local_kv_comm = RingComm(local_kv_group) diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 353ccc2f5947..3e04f3e103de 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -534,10 +534,7 @@ def forward( # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - try: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - except Exception as e: - raise e + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 76b824d8dd14..f72a72df0b1b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -77,11 +77,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: num_q_heads = self.model.config.num_attention_heads num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) - tp_size = self.shard_config.tensor_parallel_size - # Modified by SP and TP - num_q_heads = self.model.config.num_attention_heads - num_kv_heads = getattr(self.model.config, "num_key_value_heads", None) - if sp_mode == "all_to_all": num_q_heads //= sp_size decoder_attribute_replacement = {"num_heads": num_q_heads} From de0afd17493bd4c591f82c96ec87c4019964b7db Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 18 Aug 2024 09:35:38 +0000 Subject: [PATCH 59/71] fixes --- colossalai/shardformer/modeling/llama.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 6ba6665cfe7b..e978f7c558c5 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -826,6 +826,14 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: + # Special processing: Split labels in a zigzag fashion too + sp_group = shard_config.sequence_parallel_process_group + if attention_mask.bool().all(): + labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) + else: + # [B, max_seq_len // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( From 2fb7db6983e7d6fa62c25f91458671892d659381 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 19 Aug 2024 07:16:27 +0000 Subject: [PATCH 60/71] support GPT --- colossalai/shardformer/modeling/chatglm2.py | 21 +- colossalai/shardformer/modeling/command.py | 12 +- colossalai/shardformer/modeling/gpt2.py | 256 +++--------------- colossalai/shardformer/modeling/llama.py | 12 +- colossalai/shardformer/modeling/qwen2.py | 12 +- colossalai/shardformer/policies/gpt2.py | 20 +- tests/test_shardformer/test_model/_utils.py | 1 - .../test_model/test_shard_gpt2.py | 24 +- 8 files changed, 93 insertions(+), 265 deletions(-) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 2e12e78378ef..16e12312b89a 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -262,11 +262,12 @@ def chatglm_model_forward( hidden_states = self.encoder.final_layernorm(hidden_states) # Gather seq-wise in the final output stage - sp_mode = shard_config.sequence_parallelism_mode - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output( - hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0 - ) + if shard_config.enable_sequence_parallelism: + sp_mode = shard_config.sequence_parallelism_mode + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output( + hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0 + ) if not return_dict: return tuple( @@ -445,11 +446,11 @@ def forward( use_cache=use_cache, output_hidden_states=output_hidden_states, ) - - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output( - hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0 - ) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output( + hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0 + ) if not return_dict: return tuple( diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index bdcf6f0a2f69..42cad19f375d 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -213,8 +213,11 @@ def command_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) sp_mode = shard_config.sequence_parallelism_mode - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output( + hidden_states, shard_config.sequence_parallel_process_group, sp_mode + ) # add hidden states from the last decoder layer if output_hidden_states: @@ -572,8 +575,9 @@ def forward( hidden_states = self.norm(hidden_states) # Cases that don't support parallelizing cross entropy computation along sequence - if shard_config and (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: - hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index db38c9a0ec33..4392b87b0fbb 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -22,7 +22,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention -from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward +from colossalai.shardformer.layer.utils import is_share_sp_tp from colossalai.shardformer.shard import ShardConfig from ..layer import dist_cross_entropy @@ -123,6 +124,7 @@ def gpt2_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + force_sp_output_gather: Optional[bool] = True, ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. @@ -152,10 +154,8 @@ def gpt2_model_forward( elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -190,8 +190,6 @@ def gpt2_model_forward( hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) - attention_mask, encoder_attention_mask = _get_attention_mask( self, shard_config, @@ -215,8 +213,9 @@ def gpt2_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": + sp_mode = shard_config.sequence_parallelism_mode + if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism: + if sp_mode == "split_gather": hidden_states = split_forward_gather_backward( hidden_states, dim=1, @@ -269,18 +268,21 @@ def gpt2_model_forward( if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = gather_forward_split_backward( + # When sequence parallelism is done, gather the output tensor in forward and split it in backward + if stage_manager.is_last_stage() and shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output( hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, + sp_dim=1, + sp_group=shard_config.tensor_parallel_process_group, + sp_mode=shard_config.sequence_parallelism_mode, ) + # gather_sp_output could've changed seq length. + input_shape = (*input_shape[:-1], hidden_states.size(-2)) + output_shape = input_shape + (hidden_states.size(-1),) if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) # Add last hidden state @@ -364,6 +366,7 @@ def gpt2_lmhead_model_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, + force_sp_output_gather=False, ) # If not at the last stage, return hidden_states as in GPT2Model @@ -372,12 +375,7 @@ def gpt2_lmhead_model_forward( hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype - ) - + loss = dist_cross_entropy(labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype) if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -828,196 +826,7 @@ def forward( return forward -def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): - def forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - - attention_mask, encoder_attention_mask = _get_attention_mask( - self, - shard_config, - hidden_states, - past_key_values, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if torch.is_tensor(attention_mask): - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward - - -def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): +def get_gpt2_flash_attn_model_forward(shard_config: ShardConfig): def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1033,6 +842,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + force_sp_output_gather: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1046,10 +856,8 @@ def forward( elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -1184,11 +992,18 @@ def custom_forward(*inputs): hidden_states = hidden_states.to("cuda:" + str(k + 1)) # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - ) + if shard_config.enable_sequence_parallelism: + if ( + (not shard_config.parallel_output) + or force_sp_output_gather + or is_share_sp_tp(shard_config.sequence_parallelism_mode) + ): + hidden_states = gather_sp_output( + hidden_states, + sp_dim=1, + sp_group=shard_config.sequence_parallel_process_group, + sp_mode=shard_config.sequence_parallelism_mode, + ) hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) @@ -1262,15 +1077,12 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + force_sp_output_gather=False, ) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - loss = dist_cross_entropy( - labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype - ) + loss = dist_cross_entropy(labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index e978f7c558c5..86c8d0a2796d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -226,8 +226,9 @@ def llama_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: @@ -752,9 +753,10 @@ def forward( all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) - # Cases that don't support parallelizing cross entropy computation along sequence - if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: - hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) + if shard_config.enable_sequence_parallelism: + # Cases that don't support parallelizing cross entropy computation along sequence + if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 3e04f3e103de..279af942becd 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -245,8 +245,11 @@ def qwen2_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output( + hidden_states, shard_config.sequence_parallel_process_group, sp_mode + ) # add hidden states from the last decoder layer if output_hidden_states: @@ -737,8 +740,9 @@ def forward( hidden_states = self.norm(hidden_states) - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) + if shard_config.enable_sequence_parallelism: + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index cfe20000a2bf..83829757f804 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -9,10 +9,9 @@ from ..modeling.gpt2 import ( GPT2PipelineForwards, get_gpt2_flash_attention_forward, - get_gpt_model_forward_for_flash_attn, + get_gpt2_flash_attn_model_forward, get_jit_fused_gpt2_mlp_forward, get_lm_forward_with_dist_cross_entropy, - gpt2_sequence_parallel_forward_fn, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -75,14 +74,6 @@ def module_policy(self): overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention - # todo: currently sp cannot be used with flashattention - if sp_mode in ["split_gather", "ring", "all_to_all"]: - if use_flash_attention: - warnings.warn( - f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically." - ) - self.shard_config.enable_flash_attention = False - use_flash_attention = False if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -211,13 +202,10 @@ def module_policy(self): policy=policy, target_key=attn_cls, ) - if not self.shard_config.pipeline_stage_manager: - policy[GPT2Model].method_replacement = { - "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) - } - if sp_mode is not None: - policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + # This supports SP + flash attn + if not self.shard_config.pipeline_stage_manager: + policy[GPT2Model].method_replacement = {"forward": get_gpt2_flash_attn_model_forward(self.shard_config)} return policy diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 9ad84341ac9e..c94d049fc13d 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -323,7 +323,6 @@ def check_output_hidden_state( sp_size = shard_config.sequence_parallel_size if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size: org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)] - assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index f9e368c0ebf3..9db5c0c680c9 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -100,7 +100,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "GPT2Model": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_output_hidden_state( + org_output, + sharded_output, + stage_manager, + atol=atol, + rtol=rtol, + shard_config=booster.plugin.shard_config, + ) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -137,7 +144,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring", - "enable_flash_attention": False, + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp32", "initial_scale": 1, @@ -148,7 +155,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, From 63fd07581618f315eb3e660b21952eee2ac77e7f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 20 Aug 2024 07:40:44 +0000 Subject: [PATCH 61/71] fix gpt2 --- colossalai/shardformer/modeling/gpt2.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 4392b87b0fbb..12aa5c74d5f8 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -925,11 +925,13 @@ def forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel From 17002b6a7742987fd8bfcef1bf372db3c7490f67 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 22 Aug 2024 11:33:49 +0000 Subject: [PATCH 62/71] gpt ring attn + TP passed --- colossalai/shardformer/layer/attn.py | 11 +- .../shardformer/layer/qkv_fused_linear.py | 21 +- colossalai/shardformer/layer/utils.py | 6 + colossalai/shardformer/modeling/gpt2.py | 384 ++++-------------- colossalai/shardformer/modeling/llama.py | 282 +------------ colossalai/shardformer/policies/gpt2.py | 18 +- colossalai/shardformer/policies/llama.py | 18 +- tests/kit/model_zoo/transformers/gpt.py | 11 +- tests/test_shardformer/test_model/_utils.py | 1 - .../test_model/test_shard_gpt2.py | 29 +- 10 files changed, 167 insertions(+), 614 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 08dee96613e0..bf4fa77c6c23 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -1119,9 +1119,14 @@ def prepare_varlen_batch( the batch dim to a packed 1d sequence. Contingent on model forward shape definitions. Returns: - inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...]. - mask_info: A dictionary of mask info. - position_ids: Packed position ids of shape [..., Sq // sp_size]. + torch.Tensor: + Packed input embeddings of shape [B, Sq // sp_size, ...]. + + Dict[str, Any]: + A dictionary containing mask info. + + torch.Tensor: + Packed position ids of shape [..., Sq // sp_size]. """ _load_varlen_helpers() diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 000934ad91a2..88a32121f0b8 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -311,14 +311,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - - if self.seq_parallel_mode is None: - # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group) - output_parallel = matmul_with_async_comm( - input_parallel, self.weight, bias, self.process_group, self.async_communication - ) - elif self.seq_parallel_mode == "split_gather": + if self.seq_parallel_mode == "split_gather": input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap @@ -328,6 +321,14 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True ) + elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + output_parallel = matmul_with_async_comm( + input_parallel, self.weight, bias, self.process_group, self.async_communication + ) + else: + raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") if self.gather_output: # All-gather across the partitions. @@ -533,7 +534,7 @@ def forward(self, input_: Tensor) -> Tensor: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - if self.seq_parallel_mode is None: + if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "split_gather": @@ -542,6 +543,8 @@ def forward(self, input_: Tensor) -> Tensor: elif self.seq_parallel_mode == "ring": output_parallel = torch.matmul(input_, self.weight) output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index c1a73ce05c97..4512e0c680f3 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -309,6 +309,9 @@ def split_batch_zigzag( """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) + if sp_size == 1: + return batch + if isinstance(batch, torch.Tensor): batch = [batch] seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1 @@ -364,6 +367,9 @@ def split_varlen_zigzag( """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) + if sp_size == 1: + return batch + if is_2d: assert max_seqlen > 0, "max_seqlen must be provided for 2D input" diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 12aa5c74d5f8..65c26a71207a 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -21,9 +21,10 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer import ColoAttention +from colossalai.shardformer.layer import ColoAttention, RingAttention from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward -from colossalai.shardformer.layer.utils import is_share_sp_tp +from colossalai.shardformer.layer.attn import AttnMaskType +from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig from ..layer import dist_cross_entropy @@ -40,10 +41,16 @@ def _get_attention_mask( encoder_hidden_states: Optional[torch.Tensor], encoder_attention_mask: Optional[torch.FloatTensor], ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: - batch_size, seq_len = hidden_states.shape[:2] + # Received input is already split for non-first pipeline stages, + # but attn mask isn't + batch_size = hidden_states.size(0) + seq_len = attention_mask.size(-1) + + sp_mode = shard_config.sequence_parallelism_mode # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.add_cross_attention and encoder_hidden_states is not None: + assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only." encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() if shard_config.enable_flash_attention: encoder_attention_mask = ColoAttention.prepare_attn_kwargs( @@ -63,6 +70,7 @@ def _get_attention_mask( encoder_attention_mask = {"attention_mask": None} else: encoder_attention_mask = None + # GPT2Attention mask. past_key_values_length = 0 if past_key_values is not None and past_key_values[0] is not None: @@ -70,6 +78,7 @@ def _get_attention_mask( if shard_config.enable_flash_attention: if attention_mask is not None: attention_mask = attention_mask.view(batch_size, -1) + attention_mask = ColoAttention.prepare_attn_kwargs( (batch_size, 1, seq_len, seq_len + past_key_values_length), hidden_states.dtype, @@ -148,7 +157,8 @@ def gpt2_model_forward( logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - if stage_manager.is_first_stage(): + disable_pp = stage_manager is None + if disable_pp or stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -176,7 +186,7 @@ def gpt2_model_forward( # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if stage_manager.is_first_stage(): + if disable_pp or stage_manager.is_first_stage(): if position_ids is None: position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) @@ -190,7 +200,7 @@ def gpt2_model_forward( hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) - attention_mask, encoder_attention_mask = _get_attention_mask( + attn_kwargs, encoder_attention_mask = _get_attention_mask( self, shard_config, hidden_states, @@ -214,22 +224,40 @@ def gpt2_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] sp_mode = shard_config.sequence_parallelism_mode - if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism: - if sp_mode == "split_gather": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + sp_group = shard_config.sequence_parallel_process_group + if shard_config.enable_sequence_parallelism: + # Ring Attention's special zigzag batch processing + if sp_mode == "ring_attn": + assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." + # Get cu_seqlens + if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( + attention_mask, sp_group, hidden_states, position_ids + ) + else: + hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) + # Other sp modes + elif disable_pp or stage_manager.is_first_stage(): + if sp_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + del attention_mask # Going through held blocks. - start_idx, end_idx = stage_index[0], stage_index[1] + if disable_pp: + start_idx, end_idx = 0, len(self.h) + else: + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): block = self.h[i] torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if torch.is_tensor(attention_mask): - attention_mask = attention_mask.to(hidden_states.device) + if torch.is_tensor(attn_kwargs): + attn_kwargs = attn_kwargs.to(hidden_states.device) if isinstance(head_mask, torch.Tensor): head_mask = head_mask.to(hidden_states.device) if output_hidden_states: @@ -240,7 +268,7 @@ def gpt2_model_forward( block.__call__, hidden_states, None, - attention_mask, + attn_kwargs, head_mask[i], encoder_hidden_states, encoder_attention_mask, @@ -251,7 +279,7 @@ def gpt2_model_forward( outputs = block( hidden_states, layer_past=None, - attention_mask=attention_mask, + attention_mask=attn_kwargs, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, @@ -269,7 +297,7 @@ def gpt2_model_forward( all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) # When sequence parallelism is done, gather the output tensor in forward and split it in backward - if stage_manager.is_last_stage() and shard_config.enable_sequence_parallelism: + if (disable_pp or stage_manager.is_last_stage()) and shard_config.enable_sequence_parallelism: if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): hidden_states = gather_sp_output( hidden_states, @@ -277,11 +305,12 @@ def gpt2_model_forward( sp_group=shard_config.tensor_parallel_process_group, sp_mode=shard_config.sequence_parallelism_mode, ) + # gather_sp_output could've changed seq length. input_shape = (*input_shape[:-1], hidden_states.size(-2)) output_shape = input_shape + (hidden_states.size(-1),) - if stage_manager.is_last_stage(): + if disable_pp or stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) @@ -289,7 +318,7 @@ def gpt2_model_forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if stage_manager.is_last_stage(): + if disable_pp or stage_manager.is_last_stage(): if not return_dict: return tuple( v @@ -370,12 +399,26 @@ def gpt2_lmhead_model_forward( ) # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): + disable_pp = stage_manager is None + if (not disable_pp) and (not stage_manager.is_last_stage()): return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) - loss = dist_cross_entropy(labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype) + if shard_config.sequence_parallelism_mode == "ring_attn": + # Split labels in a zigzag fashion too + sp_group = shard_config.sequence_parallel_process_group + if not attention_mask.bool().all(): + # [B, max_seqlen // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + else: + labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) + + if labels is not None: + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype + ) + if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -768,7 +811,7 @@ def gpt2_for_sequence_classification_forward( ) -def get_gpt2_flash_attention_forward(): +def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention def forward( @@ -815,7 +858,22 @@ def forward( if self.scale_attn_by_inverse_layer_idx: scale /= float(self.layer_idx + 1) dropout_p = self.attn_dropout.p if self.training else 0.0 - attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) + + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + if sp_mode == "ring_attn": + attn_output = RingAttention.attention( + query, + key, + value, + sp_group, + **attention_mask, + dropout_p=dropout_p, + scale=scale, + inner_ring_size=shard_config.inner_ring_size, + ) + else: + attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -826,282 +884,6 @@ def forward( return forward -def get_gpt2_flash_attn_model_forward(shard_config: ShardConfig): - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - force_sp_output_gather: Optional[bool] = True, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - attention_mask, encoder_attention_mask = _get_attention_mask( - self, - shard_config, - hidden_states, - past_key_values, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger = logging.get_logger(__name__) - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.sequence_parallel_process_group, - ) - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if torch.is_tensor(attention_mask): - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config.enable_sequence_parallelism: - if ( - (not shard_config.parallel_output) - or force_sp_output_gather - or is_share_sp_tp(shard_config.sequence_parallelism_mode) - ): - hidden_states = gather_sp_output( - hidden_states, - sp_dim=1, - sp_group=shard_config.sequence_parallel_process_group, - sp_mode=shard_config.sequence_parallelism_mode, - ) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward - - -def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): - from transformers import GPT2LMHeadModel - - def forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - force_sp_output_gather=False, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - loss = dist_cross_entropy(labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - return forward - - def get_jit_fused_gpt2_mlp_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2MLP diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 86c8d0a2796d..697b88870bf1 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -58,10 +58,7 @@ def llama_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, - # Split output only when computing cross entropy using llama_for_causal_lm_forward - # or get_lm_forward_with_dist_cross_entropy - # Default to True to avoid bug when calling classification forward from huggingface - force_sp_output_gather: bool = True, + force_sp_output_gather: bool = True, # Set to false only when computing cross entropy ): logger = logging.get_logger(__name__) @@ -78,8 +75,9 @@ def llama_model_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict + disable_pp = stage_manager is None # retrieve input_ids and inputs_embeds - if stage_manager.is_first_stage(): + if disable_pp or stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -101,8 +99,8 @@ def llama_model_forward( sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group sp_size = shard_config.sequence_parallel_size - # For generating full positions ids (the states will be gathered along the seq dim before attention fwd). - if sp_mode != "ring_attn" and not stage_manager.is_first_stage(): + # Generating full positions ids for seq that's gathered before attn + if not disable_pp and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()): seq_length *= sp_size past_seen_tokens = 0 @@ -117,7 +115,6 @@ def llama_model_forward( seq_length_with_past = seq_length + past_seen_tokens - # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False @@ -132,7 +129,7 @@ def llama_model_forward( position_ids = cache_position.unsqueeze(0) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - if not stage_manager.is_first_stage() and sp_mode == "ring_attn": + if (disable_pp or not stage_manager.is_first_stage()) and sp_mode == "ring_attn": _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) elif shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor @@ -150,7 +147,7 @@ def llama_model_forward( # Support SP + PP # TODO: support padded casual cu_seqlens across stages - if stage_manager.is_first_stage(): + if disable_pp or stage_manager.is_first_stage(): # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." @@ -177,8 +174,8 @@ def llama_model_forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1]) - start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: num_ckpt_layers = end_idx - start_idx @@ -224,7 +221,7 @@ def llama_model_forward( if output_attentions: all_self_attns += (layer_outputs[1],) - if stage_manager.is_last_stage(): + if disable_pp or stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) if shard_config.enable_sequence_parallelism: if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): @@ -234,7 +231,7 @@ def llama_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - if stage_manager.is_last_stage(): + if disable_pp or stage_manager.is_last_stage(): if not return_dict: return tuple( v @@ -318,7 +315,7 @@ def llama_for_causal_lm_forward( # Split labels in a zigzag fashion too sp_group = shard_config.sequence_parallel_process_group if attention_mask.bool().all(): - labels = split_batch_zigzag(labels, sp_group, seq_dim=1) + labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) else: # [B, max_seqlen // sp_size] labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) @@ -344,7 +341,8 @@ def llama_for_causal_lm_forward( ) past_key_values = None - if stage_manager.is_last_stage(): + disable_pp = stage_manager is None + if disable_pp or stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None @@ -622,257 +620,3 @@ def forward( return attn_output, attn_weights, past_key_value return forward - - -def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): - logger = logging.get_logger(__name__) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - # Split output only when computing cross entropy using llama_for_causal_lm_forward - # or get_lm_forward_with_dist_cross_entropy - # Default to True to avoid bug when calling classification forward from huggingface - force_sp_output_gather: bool = True, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - past_seen_tokens = 0 - seq_len = inputs_embeds.shape[1] - batch_size = inputs_embeds.shape[0] - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - - if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - if shard_config.enable_flash_attention: - mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len) - attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( - mask_shape, - inputs_embeds.dtype, - inputs_embeds.device, - q_padding_mask=attention_mask, - is_causal=True, - invert=(sp_mode != "ring_attn"), - ) - - else: - attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) - - # Ring Attention zigzag batch processing - if sp_mode == "ring_attn": - assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." - if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: - inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( - attention_mask, sp_group, inputs_embeds, position_ids - ) - else: - inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) - attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors - - elif is_share_sp_tp(sp_mode): - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attn_kwargs, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attn_kwargs, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - if shard_config.enable_sequence_parallelism: - # Cases that don't support parallelizing cross entropy computation along sequence - if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: - hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - return forward - - -def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): - from transformers import LlamaForCausalLM - - def forward( - self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output: - # Special processing: Split labels in a zigzag fashion too - sp_group = shard_config.sequence_parallel_process_group - if attention_mask.bool().all(): - labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True) - else: - # [B, max_seq_len // sp_size] - labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - force_sp_output_gather=False, - ) - - hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - loss = None - if labels is not None: - loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - return forward diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 83829757f804..03ab477c7e8d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -6,13 +6,7 @@ import colossalai.shardformer.layer as col_nn -from ..modeling.gpt2 import ( - GPT2PipelineForwards, - get_gpt2_flash_attention_forward, - get_gpt2_flash_attn_model_forward, - get_jit_fused_gpt2_mlp_forward, - get_lm_forward_with_dist_cross_entropy, -) +from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, get_jit_fused_gpt2_mlp_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -70,7 +64,7 @@ def module_policy(self): warnings.warn( f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" ) - sp_mode = "split_gather" + self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention @@ -197,7 +191,7 @@ def module_policy(self): if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_gpt2_flash_attention_forward(), + "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config), }, policy=policy, target_key=attn_cls, @@ -205,7 +199,9 @@ def module_policy(self): # This supports SP + flash attn if not self.shard_config.pipeline_stage_manager: - policy[GPT2Model].method_replacement = {"forward": get_gpt2_flash_attn_model_forward(self.shard_config)} + policy[GPT2Model].method_replacement = { + "forward": partial(GPT2PipelineForwards.gpt2_model_forward, shard_config=self.shard_config) + } return policy @@ -329,7 +325,7 @@ def module_policy(self): } if self.shard_config.parallel_output: addon_module[GPT2LMHeadModel].method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config) } else: addon_module = { diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index f72a72df0b1b..ec04a55b3412 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -16,12 +16,7 @@ VocabParallelLMHead1D, ) -from ..modeling.llama import ( - LlamaPipelineForwards, - get_llama_flash_attention_forward, - get_llama_flash_attention_model_forward, - get_lm_forward_with_dist_cross_entropy, -) +from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] @@ -98,11 +93,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.pipeline_stage_manager is None: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_model_forward( - self.shard_config, - sp_mode=sp_mode, - sp_size=sp_size, - sp_group=sp_group, + "forward": partial( + LlamaPipelineForwards.llama_model_forward, + shard_config=self.shard_config, ), }, policy=policy, @@ -342,7 +335,8 @@ def module_policy(self): elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism: # Compute loss distributedly along the sequence dimension new_item[LlamaForCausalLM].method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + # "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + "forward": partial(LlamaPipelineForwards.llama_for_causal_lm_forward, shard_config=self.shard_config) } return policy diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index f71776b6b4e0..f2b139beca83 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -27,7 +27,16 @@ def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data["labels"] = data["input_ids"].clone() + + # Test padded sequence for Ring Attention + padding = torch.zeros(1, data["input_ids"].shape[1] // 2, dtype=torch.long) + data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1) + data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1) + + ignore_idx = -100 + labels = data["input_ids"].clone() + labels[~data["attention_mask"].bool()] = ignore_idx + data["labels"] = labels return data diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index c94d049fc13d..0f9ec601387b 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin( sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) criterion = loss_fn - plugin = pluggin_cls(**test_config) booster = Booster(plugin=plugin) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 9db5c0c680c9..6a2515a99b6a 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -138,15 +138,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + # TODO: Ring Attention + PP seems to have some precision issue to be resolved { - "tp_size": 4, + "sp_size": 2, + "tp_size": 2, "pp_size": 1, - "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, + "sequence_parallelism_mode": "ring_attn", + "num_microbatches": 1, + "enable_all_optimization": True, "use_lazy_init": True, - "precision": "fp32", + "precision": "fp16", "initial_scale": 1, }, { @@ -203,7 +205,16 @@ def run_gpt2_test(test_config): loss_fn, _, ) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and name != "transformers_gpt_lm": + # Only wrote zigzag splitting for cross entropy loss + continue + + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config} for model {name}") + raise (e) clear_layout_converter() torch.cuda.empty_cache() @@ -244,7 +255,11 @@ def run_gpt2_3d_test(test_config): loss_fn, _, ) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config} for model {name}") + raise (e) clear_layout_converter() torch.cuda.empty_cache() From c6067fe3ee1f290f02ad43c44b4d81ccd4ee17af Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 22 Aug 2024 11:42:42 +0000 Subject: [PATCH 63/71] trim llama forward logic --- colossalai/shardformer/layer/_operation.py | 2 ++ colossalai/shardformer/modeling/llama.py | 12 +++++------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index efe4d80babbb..1419af49c418 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1003,6 +1003,8 @@ def gather_sp_output(hidden_states, sp_group, sp_mode, sp_dim=1): """ Gather the output of the last layer for cross entropy computation """ + if dist.get_world_size(sp_group) == 1: + return hidden_states # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) hidden_states = gather_forward_split_backward(hidden_states, sp_dim, sp_group, grad_scale=scale) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 697b88870bf1..9803b14b2185 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -86,10 +86,10 @@ def llama_model_forward( batch_size, seq_length, _ = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds + device = hidden_states.device else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape @@ -99,8 +99,8 @@ def llama_model_forward( sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group sp_size = shard_config.sequence_parallel_size - # Generating full positions ids for seq that's gathered before attn - if not disable_pp and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()): + # Generating full positions ids for modes that gather sequence before attn + if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()): seq_length *= sp_size past_seen_tokens = 0 @@ -146,7 +146,6 @@ def llama_model_forward( attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position) # Support SP + PP - # TODO: support padded casual cu_seqlens across stages if disable_pp or stage_manager.is_first_stage(): # Ring Attention zigzag batch processing if sp_mode == "ring_attn": @@ -223,9 +222,8 @@ def llama_model_forward( if disable_pp or stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if shard_config.enable_sequence_parallelism: - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) + if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer if output_hidden_states: From 051590d1a599223361e1d0733f18fcaff26054c7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 22 Aug 2024 22:46:57 +0000 Subject: [PATCH 64/71] GPT support sp + pp --- colossalai/shardformer/modeling/gpt2.py | 6 +- colossalai/shardformer/modeling/llama.py | 10 ++- colossalai/shardformer/policies/gpt2.py | 65 +++++++++---------- .../gpt/hybridparallelism/benchmark.py | 7 +- .../test_model/test_shard_gpt2.py | 13 +++- 5 files changed, 59 insertions(+), 42 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 65c26a71207a..019adc9095aa 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -225,7 +225,7 @@ def gpt2_model_forward( # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group - if shard_config.enable_sequence_parallelism: + if disable_pp or stage_manager.is_first_stage(): # Ring Attention's special zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." @@ -244,6 +244,10 @@ def gpt2_model_forward( dim=1, process_group=shard_config.tensor_parallel_process_group, ) + elif sp_mode == "ring_attn": + # Later stages already received split hidden states + _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) + del attention_mask # Going through held blocks. diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 9803b14b2185..1aed7d9d4906 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -127,14 +127,12 @@ def llama_model_forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - # embed positions, for the first stage, hidden_states is the input embeddings, - # for the other stages, hidden_states is the output of the previous stage + if (disable_pp or not stage_manager.is_first_stage()) and sp_mode == "ring_attn": _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) elif shard_config.enable_flash_attention: - # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attn_kwargs = ColoAttention.prepare_attn_kwargs( + attn_kwargs: dict = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, hidden_states.device, @@ -143,9 +141,9 @@ def llama_model_forward( invert=(sp_mode != "ring_attn"), ) else: - attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position) - # Support SP + PP + # Support SP + PP. Later stages have already received the split input. if disable_pp or stage_manager.is_first_stage(): # Ring Attention zigzag batch processing if sp_mode == "ring_attn": diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 03ab477c7e8d..eb826305dec8 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -197,8 +197,7 @@ def module_policy(self): target_key=attn_cls, ) - # This supports SP + flash attn - if not self.shard_config.pipeline_stage_manager: + if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = { "forward": partial(GPT2PipelineForwards.gpt2_model_forward, shard_config=self.shard_config) } @@ -307,39 +306,39 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel module_policy = super().module_policy() - + module_policy[GPT2LMHeadModel] = ModulePolicyDescription() if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2LMHeadModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.VocabParallelLMHead1D, - kwargs={ - "gather_output": False, - "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, - }, - ) - ], - ) - } - if self.shard_config.parallel_output: - addon_module[GPT2LMHeadModel].method_replacement = { - "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config) - } + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": False, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=module_policy, + target_key=GPT2LMHeadModel, + ) else: - addon_module = { - GPT2LMHeadModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.PaddingLMHead, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, - ) - ] - ) - } - module_policy.update(addon_module) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=module_policy, + target_key=GPT2LMHeadModel, + ) + + if self.shard_config.parallel_output: + self.append_or_create_method_replacement( + description={ + "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config) + }, + policy=module_policy, + target_key=GPT2LMHeadModel, + ) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( diff --git a/examples/language/gpt/hybridparallelism/benchmark.py b/examples/language/gpt/hybridparallelism/benchmark.py index 8c236b524c26..2a801586e56a 100644 --- a/examples/language/gpt/hybridparallelism/benchmark.py +++ b/examples/language/gpt/hybridparallelism/benchmark.py @@ -28,7 +28,7 @@ "118M": GPT2Config(activation_function="gelu"), "338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"), "738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"), - "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"), + "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=16384, activation_function="gelu"), } @@ -60,6 +60,8 @@ def main(): parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size") + parser.add_argument("--sp_mode", type=str, default="ring_attn", help="Sequence parallel mode") parser.add_argument("--mbs", type=int, default=1) parser.add_argument("--zero", type=int, default=0) parser.add_argument("--pp_style", type=str, default="1f1b") @@ -129,6 +131,9 @@ def empty_init(): tp_size=args.tp, pp_size=args.pp, pp_style=args.pp_style, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + enable_sequence_parallelism=True, zero_stage=args.zero, num_model_chunks=args.num_model_chunks, enable_all_optimization=True, diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 6a2515a99b6a..f5b7cf5823db 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -138,7 +138,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # TODO: Ring Attention + PP seems to have some precision issue to be resolved + { + "sp_size": 2, + "tp_size": 1, + "pp_size": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "sp_size": 2, "tp_size": 2, From ce1184c78360a851fca638dc985df0a654590200 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 23 Aug 2024 07:23:39 +0000 Subject: [PATCH 65/71] attempt to simplify code --- colossalai/shardformer/modeling/gpt2.py | 16 +++++++--------- colossalai/shardformer/modeling/llama.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 019adc9095aa..16b2526cf9d2 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -23,7 +23,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention, RingAttention from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward -from colossalai.shardformer.layer.attn import AttnMaskType from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig @@ -133,7 +132,7 @@ def gpt2_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, - force_sp_output_gather: Optional[bool] = True, + force_sp_gather: Optional[bool] = True, ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. @@ -229,15 +228,14 @@ def gpt2_model_forward( # Ring Attention's special zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." - # Get cu_seqlens - if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + if not attention_mask.bool().all(): hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, hidden_states, position_ids ) else: hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) # Other sp modes - elif disable_pp or stage_manager.is_first_stage(): + else: if sp_mode == "split_gather": hidden_states = split_forward_gather_backward( hidden_states, @@ -247,7 +245,6 @@ def gpt2_model_forward( elif sp_mode == "ring_attn": # Later stages already received split hidden states _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) - del attention_mask # Going through held blocks. @@ -301,8 +298,9 @@ def gpt2_model_forward( all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) # When sequence parallelism is done, gather the output tensor in forward and split it in backward - if (disable_pp or stage_manager.is_last_stage()) and shard_config.enable_sequence_parallelism: - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode) + if disable_pp or stage_manager.is_last_stage(): + if gather_output: hidden_states = gather_sp_output( hidden_states, sp_dim=1, @@ -399,7 +397,7 @@ def gpt2_lmhead_model_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, - force_sp_output_gather=False, + force_sp_gather=False, ) # If not at the last stage, return hidden_states as in GPT2Model diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 1aed7d9d4906..32b5402b8326 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -25,7 +25,6 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig @@ -58,7 +57,7 @@ def llama_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, - force_sp_output_gather: bool = True, # Set to false only when computing cross entropy + force_sp_gather: bool = True, # Set to false only when computing cross entropy ): logger = logging.get_logger(__name__) @@ -128,7 +127,8 @@ def llama_model_forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - if (disable_pp or not stage_manager.is_first_stage()) and sp_mode == "ring_attn": + no_split_input = disable_pp or not stage_manager.is_first_stage() + if no_split_input and sp_mode == "ring_attn": _, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) elif shard_config.enable_flash_attention: mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) @@ -144,11 +144,12 @@ def llama_model_forward( attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position) # Support SP + PP. Later stages have already received the split input. - if disable_pp or stage_manager.is_first_stage(): + split_input = disable_pp or stage_manager.is_first_stage() + if split_input: # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." - if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + if not attention_mask.bool().all(): hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch( attention_mask, sp_group, hidden_states, position_ids ) @@ -218,9 +219,10 @@ def llama_model_forward( if output_attentions: all_self_attns += (layer_outputs[1],) + gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode) if disable_pp or stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): + if gather_output: hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer @@ -333,7 +335,7 @@ def llama_for_causal_lm_forward( hidden_states=hidden_states, stage_index=stage_index, shard_config=shard_config, - force_sp_output_gather=False, + force_sp_gather=False, ) past_key_values = None From 6d5fc3aa25156dfe4d135cfc2f803989f385111d Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 23 Aug 2024 10:01:54 +0000 Subject: [PATCH 66/71] debug --- colossalai/shardformer/modeling/gpt2.py | 4 ++-- colossalai/shardformer/modeling/llama.py | 3 +-- colossalai/shardformer/policies/llama.py | 8 +++----- tests/test_shardformer/test_model/test_shard_gpt2.py | 2 +- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 16b2526cf9d2..da28ea8a7d0e 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -304,8 +304,8 @@ def gpt2_model_forward( hidden_states = gather_sp_output( hidden_states, sp_dim=1, - sp_group=shard_config.tensor_parallel_process_group, - sp_mode=shard_config.sequence_parallelism_mode, + sp_group=sp_group, + sp_mode=sp_mode, ) # gather_sp_output could've changed seq length. diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 32b5402b8326..2db8e5ca13b9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -219,10 +219,9 @@ def llama_model_forward( if output_attentions: all_self_attns += (layer_outputs[1],) - gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode) if disable_pp or stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - if gather_output: + if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode): # noqa hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) # add hidden states from the last decoder layer diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 47f8677db31f..bc9bf3326e97 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -94,11 +94,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.pipeline_stage_manager is None: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_model_forward( - self.shard_config, - sp_mode=sp_mode, - sp_size=sp_size, - sp_group=sp_group, + "forward": partial( + LlamaPipelineForwards.llama_model_forward, + shard_config=self.shard_config, ), }, policy=policy, diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index f5b7cf5823db..393f7ffca7d3 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -187,7 +187,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 2, "pp_size": 2, - "num_microbatches": 4, + "num_microbatches": 2, "enable_all_optimization": True, "use_lazy_init": True, "precision": "fp16", From 4a32c681386c5faf48cf33d49e0d9ea034dc3985 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 27 Aug 2024 12:26:24 +0800 Subject: [PATCH 67/71] fix all-reduce elapsed time --- examples/language/gpt/hybridparallelism/benchmark.py | 2 ++ examples/language/performance_evaluator.py | 8 +++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/language/gpt/hybridparallelism/benchmark.py b/examples/language/gpt/hybridparallelism/benchmark.py index 2a801586e56a..ea4a53363d43 100644 --- a/examples/language/gpt/hybridparallelism/benchmark.py +++ b/examples/language/gpt/hybridparallelism/benchmark.py @@ -219,6 +219,8 @@ def empty_init(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] + del outputs + booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index f5ad1d23d2a7..65c7e49a2f03 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -6,7 +6,6 @@ from torch import Tensor from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler -from colossalai.accelerator import get_accelerator from colossalai.cluster import DistCoordinator @@ -22,8 +21,11 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - tensor = torch.tensor([x], device=get_accelerator().get_current_device()) - dist.all_reduce(tensor) + + # Use CPU tensor to avoid OOM/weird NCCl error + gloo_group = dist.new_group(backend="gloo") + tensor = torch.tensor([x], device="cpu") + dist.all_reduce(tensor, group=gloo_group) tensor = tensor / world_size return tensor.item() From 5365117788cb17a1b4e7cd30f8cd42c5cd6b6f77 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 28 Aug 2024 10:44:20 +0800 Subject: [PATCH 68/71] update gpt max seqlen to 32k --- examples/language/gpt/hybridparallelism/benchmark.py | 2 +- tests/test_shardformer/test_layer/ncu_ring_attn.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 tests/test_shardformer/test_layer/ncu_ring_attn.py diff --git a/examples/language/gpt/hybridparallelism/benchmark.py b/examples/language/gpt/hybridparallelism/benchmark.py index ea4a53363d43..91b9e6c04950 100644 --- a/examples/language/gpt/hybridparallelism/benchmark.py +++ b/examples/language/gpt/hybridparallelism/benchmark.py @@ -28,7 +28,7 @@ "118M": GPT2Config(activation_function="gelu"), "338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"), "738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"), - "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=16384, activation_function="gelu"), + "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=32768, activation_function="gelu"), } diff --git a/tests/test_shardformer/test_layer/ncu_ring_attn.py b/tests/test_shardformer/test_layer/ncu_ring_attn.py new file mode 100644 index 000000000000..a94d27df3962 --- /dev/null +++ b/tests/test_shardformer/test_layer/ncu_ring_attn.py @@ -0,0 +1,6 @@ +import torch +from flash_attn import flash_attn_qkvpacked_func + +bs, seq_len, nheads, d = 4, 4096, 32, 128 +qkv = torch.randn(bs, seq_len, 3, nheads, d, device="cuda:0", dtype=torch.bfloat16) +out, lse, _ = flash_attn_qkvpacked_func(qkv, causal=True, return_attn_probs=True) From 177142aa79b17590f2d6f400fa0c2852ccfdc2c1 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 28 Aug 2024 21:54:28 +0000 Subject: [PATCH 69/71] fix typos --- colossalai/shardformer/layer/_operation.py | 6 ++++-- colossalai/shardformer/modeling/llama.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 3bb0796ba37c..e2409ca94eff 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1101,10 +1101,12 @@ def gather_sp_output(hidden_states, shard_config, sp_dim=1): """ Gather the output of the last layer for cross entropy computation """ - sp_group = shard_config.sequence_parallel_group - sp_mode = shard_config.sequence_parallel_mode + sp_group = shard_config.sequence_parallel_process_group + sp_mode = shard_config.sequence_parallelism_mode if dist.get_world_size(sp_group) == 1: return hidden_states + # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) hidden_states = gather_forward_split_backward(hidden_states, sp_dim, sp_group, grad_scale=scale) + return hidden_states diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a84a1238f694..47c17e7494f2 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -250,7 +250,7 @@ def llama_model_forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - # always return dict for imediate stage + # always return dict for intermediate stage return {"hidden_states": hidden_states} @staticmethod From fc798f443397e315b0bd75bf7352452266015f68 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 29 Aug 2024 16:48:46 +0800 Subject: [PATCH 70/71] fix typos --- colossalai/shardformer/layer/_operation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e2409ca94eff..f970d8ccc85d 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1103,10 +1103,13 @@ def gather_sp_output(hidden_states, shard_config, sp_dim=1): """ sp_group = shard_config.sequence_parallel_process_group sp_mode = shard_config.sequence_parallelism_mode + fp8_comm = shard_config.fp8_communication if dist.get_world_size(sp_group) == 1: return hidden_states # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) - hidden_states = gather_forward_split_backward(hidden_states, sp_dim, sp_group, grad_scale=scale) + hidden_states = gather_forward_split_backward( + hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm + ) return hidden_states From 04e1c1e7c4d2f7a1a4eed77691e694a103346fde Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 2 Sep 2024 09:31:50 +0000 Subject: [PATCH 71/71] remove --- tests/test_shardformer/test_layer/ncu_ring_attn.py | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 tests/test_shardformer/test_layer/ncu_ring_attn.py diff --git a/tests/test_shardformer/test_layer/ncu_ring_attn.py b/tests/test_shardformer/test_layer/ncu_ring_attn.py deleted file mode 100644 index a94d27df3962..000000000000 --- a/tests/test_shardformer/test_layer/ncu_ring_attn.py +++ /dev/null @@ -1,6 +0,0 @@ -import torch -from flash_attn import flash_attn_qkvpacked_func - -bs, seq_len, nheads, d = 4, 4096, 32, 128 -qkv = torch.randn(bs, seq_len, 3, nheads, d, device="cuda:0", dtype=torch.bfloat16) -out, lse, _ = flash_attn_qkvpacked_func(qkv, causal=True, return_attn_probs=True)