diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 9c4056305bdb..3de32812fdb9 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -620,6 +620,11 @@ def forward(
         sp_size = kv_comms[0].world_size
         sp_rank = kv_comms[0].rank
 
+        # Non-contiguous indexing creates a new contiguous tensor,
+        # so only do it once
+        if sp_rank != sp_size - 1:
+            q1 = q[half_idx_back]
+
         # Pre-allocate double buffer for overlapping and receiving next step's inputs
         kv_buffers = [torch.stack((k, v))]  # (2, B, Sq, H, D)
         kv_buffers.append(torch.empty_like(kv_buffers[0]))
@@ -700,7 +705,7 @@ def forward(
                     # Received the inner kv chunks
                     # Drop the first half of q
                     kv_block = kv_buffers[i % 2]
-                    q_block = q[half_idx_back]
+                    q_block = q1
 
                     (
                         _,
@@ -814,8 +819,6 @@ def backward(ctx, dout, _):
         else:
             dkv_comm = RingComm(sp_group)
 
-        # Non-contiguous indexing creates a new contiguous tensor,
-        # so only do it once
         if sp_rank != sp_size - 1:
             softmax_lse1 = softmax_lse[:, half_idx_back]
             dout = dout.contiguous()
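
Note on why the patch hoists the indexing out of the ring loop: advanced (non-contiguous) indexing in PyTorch materializes a new contiguous tensor each time it is evaluated, so re-running q[half_idx_back] inside the loop repeats the same copy on every step. The lines below are a minimal standalone sketch of that effect; the shapes and the stand-in half_idx_back index are illustrative assumptions, not taken from the patch.

import torch

# Stand-in packed q and a hypothetical index selecting the back half of the
# sequence; shapes are illustrative only.
q = torch.randn(8, 2, 4)            # (T, H, D)
half_idx_back = torch.arange(4, 8)  # assumed index of the second half

# Advanced indexing copies the selected rows into a fresh contiguous tensor
# every time it runs.
q1 = q[half_idx_back]
assert q1.is_contiguous()
assert q1.data_ptr() != q.data_ptr()

# Doing the copy once before the loop, as the patch does, lets every ring step
# reuse the same buffer instead of re-copying q on each iteration.
for step in range(4):
    q_block = q1  # reuse the single precomputed contiguous copy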