q1 index only once
Edenzzzz committed Aug 5, 2024
1 parent 224ac3d commit 07b4fb4
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions colossalai/shardformer/layer/attn.py
@@ -620,6 +620,11 @@ def forward(
         sp_size = kv_comms[0].world_size
         sp_rank = kv_comms[0].rank
 
+        # Non-contiguous indexing creates a new contiguous tensor,
+        # so only do it once
+        if sp_rank != sp_size - 1:
+            q1 = q[half_idx_back]
+
         # Pre-allocate double buffer for overlapping and receiving next step's inputs
         kv_buffers = [torch.stack((k, v))]  # (2, B, Sq, H, D)
         kv_buffers.append(torch.empty_like(kv_buffers[0]))
@@ -700,7 +705,7 @@ def forward(
                 # Received the inner kv chunks
                 # Drop the first half of q
                 kv_block = kv_buffers[i % 2]
-                q_block = q[half_idx_back]
+                q_block = q1
 
                 (
                     _,
@@ -814,8 +819,6 @@ def backward(ctx, dout, _):
         else:
             dkv_comm = RingComm(sp_group)
 
-        # Non-contiguous indexing creates a new contiguous tensor,
-        # so only do it once
         if sp_rank != sp_size - 1:
             softmax_lse1 = softmax_lse[:, half_idx_back]
         dout = dout.contiguous()
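
Why this helps, as a minimal sketch assuming standard PyTorch semantics (the shapes and index tensor below are illustrative, not taken from attn.py): basic slicing returns a view of the same storage, while advanced indexing with an index tensor such as q[half_idx_back] allocates a fresh contiguous tensor and copies into it, so leaving it inside the ring loop pays that allocation and copy on every step.

import torch

q = torch.randn(1024, 8, 64)               # (Sq, H, D); illustrative shape
half_idx_back = torch.arange(512, 1024)    # stand-in for the real index tensor

q_view = q[512:]             # basic slicing: a view, no data copied
q_copy = q[half_idx_back]    # advanced indexing: new contiguous tensor + copy

q[512, 0, 0] = 42.0
print(q_view[0, 0, 0].item())  # 42.0 -- the slice aliases q's storage
print(q_copy[0, 0, 0].item())  # old value -- the indexed result is a copy
print(q_copy.is_contiguous())  # True -- freshly allocated, row-major

# Hence the commit hoists the copy out of the loop:
q1 = q[half_idx_back]          # pay the allocation + copy once
for i in range(4):             # hypothetical ring steps
    q_block = q1               # reuse the same tensor each step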
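The unchanged context lines also show the double-buffering pattern the loop relies on: one kv slot is used for the current step's compute while the other receives the next step's chunk. Below is a self-contained sketch of that pattern with a synchronous, hypothetical stand-in for the ring communication (the real code uses ColossalAI's RingComm and overlaps the receive with compute); shapes and names are illustrative.

import torch

B, Sq, H, D = 2, 128, 8, 64
k = torch.randn(B, Sq, H, D)
v = torch.randn(B, Sq, H, D)

def recv_next_kv(step):
    # Hypothetical stand-in for the ring recv; the real code receives the
    # neighbor's kv chunk asynchronously while computing on the current one.
    return torch.randn(2, B, Sq, H, D)

kv_buffers = [torch.stack((k, v))]                   # (2, B, Sq, H, D): current slot
kv_buffers.append(torch.empty_like(kv_buffers[0]))   # idle slot for the next chunk

num_steps = 4                                        # hypothetical ring size
for i in range(num_steps):
    if i < num_steps - 1:
        kv_buffers[(i + 1) % 2].copy_(recv_next_kv(i))  # fill the idle slot
    kv_block = kv_buffers[i % 2]                     # compute with the current slot
    k_block, v_block = kv_block[0], kv_block[1]
    # ... attention on (q_block, k_block, v_block) would run here ...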
