Commit
Merge branch 'ring-attn' into split_ce
Edenzzzz committed Aug 7, 2024
2 parents 0345fc3 + 5c4b445 commit 842136b
Showing 2 changed files with 22 additions and 16 deletions.
37 changes: 21 additions & 16 deletions colossalai/shardformer/layer/attn.py
@@ -415,10 +415,10 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
Compute the new attention denominator:
exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1)
Args:
out: (B, Sq, H, D)
block_out: (B, Sq, H, D)
lse: (B, H, Sq, 1)
block_lse: (B, H, Sq, 1)
out: (T, H, D)
block_out: (T, H, D)
lse: (H, T, 1)
block_lse: (H, T, 1)
"""

# min_scale = torch.min(lse, block_lse)
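The docstring above describes merging two partial attention results by combining their log-sum-exp denominators. A minimal standalone sketch of that merge, assuming the new packed shapes — out/block_out as (T, H, D) and lse/block_lse as (H, T, 1) — with all names illustrative rather than taken from the repository:

```python
import torch

def merge_out_lse(out, block_out, lse, block_lse):
    # out, block_out: (T, H, D); lse, block_lse: (H, T, 1)
    # New denominator: exp(lse) + exp(block_lse), computed stably as
    # exp(max_scale) * (exp(min_scale - max_scale) + 1).
    max_scale = torch.maximum(lse, block_lse)
    min_scale = torch.minimum(lse, block_lse)
    new_lse = max_scale + torch.log1p(torch.exp(min_scale - max_scale))

    # Rescale both partial outputs onto the common denominator exp(new_lse).
    # Transpose the (H, T, 1) weights to (T, H, 1) so they broadcast over D.
    w_old = torch.exp(lse - new_lse).transpose(0, 1)
    w_new = torch.exp(block_lse - new_lse).transpose(0, 1)
    return out * w_old + block_out * w_new, new_lse
```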
@@ -455,8 +455,7 @@ class RingAttention(torch.autograd.Function):
TOTAL_SEQLEN: int = None
HALF_INDICES: Tuple = None
SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
CORRECTION_DONE = torch.cuda.Event()
ATTN_DONE = torch.cuda.Event()
ATTN_DONE: torch.cuda.Event = None

@staticmethod
def attention(
@@ -504,6 +503,8 @@ def attention(
Shape should be [total_q_seqlen, nHeads]
"""
_load_flash_attn()
if RingAttention.ATTN_DONE is None:
RingAttention.ATTN_DONE = torch.cuda.Event()
assert (
q.shape[2] == k.shape[2]
), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -620,6 +621,11 @@ def forward(
sp_size = kv_comms[0].world_size
sp_rank = kv_comms[0].rank

# Non-contiguous indexing creates a new contiguous tensor,
# so only do it once
if sp_rank != sp_size - 1:
q1 = q[half_idx_back]

# Pre-allocate double buffer for overlapping and receiving next step's inputs
kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D)
kv_buffers.append(torch.empty_like(kv_buffers[0]))
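The two kv_buffers entries above implement double buffering: attention for step i reads one buffer while the next rank's KV chunk is received into the other. A schematic sketch of that pattern, with send_recv and attend as placeholder callables rather than the actual RingComm API:

```python
import torch

def ring_forward_sketch(k, v, sp_size, send_recv, attend):
    # Two buffers: compute on kv_buffers[i % 2] while the async send/recv
    # fills kv_buffers[(i + 1) % 2] with the next rank's chunk.
    kv_buffers = [torch.stack((k, v))]
    kv_buffers.append(torch.empty_like(kv_buffers[0]))

    block_outs = []
    for i in range(sp_size):
        if i < sp_size - 1:
            # Post communication for step i + 1 into the idle buffer.
            send_recv(send_buf=kv_buffers[i % 2],
                      recv_buf=kv_buffers[(i + 1) % 2])
        # Attention against the chunk that is already resident locally.
        block_outs.append(attend(*kv_buffers[i % 2]))
    return block_outs
```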
@@ -700,7 +706,7 @@ def forward(
# Received the inner kv chunks
# Drop the first half of q
kv_block = kv_buffers[i % 2]
q_block = q[half_idx_back]
q_block = q1

(
_,
@@ -722,17 +728,15 @@
causal=False,
**misc_kwargs,
)
RingAttention.ATTN_DONE.record(sp_streams[i % 2])
RingAttention.ATTN_DONE.record()

block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
) # (H, T) -> (T, H, 1)
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
# Output and log sum exp correction
if i > 0:
sp_streams[i % 2].wait_event(RingAttention.CORRECTION_DONE)

# Overlap output correction with next flash attn kernel
# In reality this always finishes before next flash attn
if i == 0:
out = block_out[0]
softmax_lse = block_softmax_lse[0]
@@ -742,9 +746,7 @@
out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
)

RingAttention.CORRECTION_DONE.record(sp_streams[i % 2])
torch.cuda.current_stream().wait_event(RingAttention.CORRECTION_DONE)
torch.cuda.current_stream().wait_stream(sp_stream)

out = out.to(q.dtype)
if not is_packed:
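This hunk drops the per-iteration CORRECTION_DONE record/wait_event pair in favour of a single wait_stream after the loop: the default stream blocks until everything already queued on the side stream has executed. A toy sketch of that synchronization, with placeholder work standing in for the output/LSE correction:

```python
import torch

sp_stream = torch.cuda.Stream()
out = torch.zeros(8, 8, device="cuda")

# Make the side stream respect the allocation/initialization of `out`.
sp_stream.wait_stream(torch.cuda.current_stream())

for i in range(4):
    with torch.cuda.stream(sp_stream):
        out += 1.0   # stand-in for the per-step output/LSE correction

# One stream-level barrier instead of an event per iteration: the default
# stream waits for all work queued on sp_stream before reading `out`.
torch.cuda.current_stream().wait_stream(sp_stream)
total = out.sum()
```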
@@ -790,7 +792,6 @@ def backward(ctx, dout, _):
max_seqlen_kv = ctx.max_seqlen_kv
dkv_group = ctx.dkv_group
misc_kwargs = ctx.misc_kwargs
dout = dout.contiguous()
del misc_kwargs["block_table"]

assert (
@@ -815,6 +816,10 @@ def backward(ctx, dout, _):
else:
dkv_comm = RingComm(sp_group)

if sp_rank != sp_size - 1:
softmax_lse1 = softmax_lse[:, half_idx_back]
dout = dout.contiguous()

# Double comm buffers for sending and receiving kv
kv_buffers = [torch.stack((k, v))] # (2, T, H, D)
kv_buffers.append(torch.empty_like(kv_buffers[0]))
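Both new backward lines hoist work that is invariant across ring steps out of the loop: advanced indexing such as softmax_lse[:, half_idx_back] materializes a new contiguous tensor, and dout.contiguous() may copy, so each should run once rather than per iteration. A small illustration with made-up shapes and index values:

```python
import torch

softmax_lse = torch.randn(8, 128)                 # (H, T), made-up sizes
half_idx_back = torch.arange(64, 128)             # placeholder index set
dout = torch.randn(128, 8, 64).transpose(0, 1)    # a non-contiguous gradient

# Hoisted once: advanced indexing copies into a fresh contiguous tensor,
# and .contiguous() copies only if the layout requires it.
softmax_lse1 = softmax_lse[:, half_idx_back]
dout = dout.contiguous()

for step in range(4):
    # Reuse the precomputed copies instead of re-indexing every iteration.
    _ = softmax_lse1.sum() + dout.sum()
```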
Expand Down Expand Up @@ -899,7 +904,7 @@ def backward(ctx, dout, _):
k_,
v_,
out_,
softmax_lse[:, half_idx_back],
softmax_lse1,
dq_,
dk_,
dv_,
1 change: 1 addition & 0 deletions tests/test_shardformer/test_layer/test_ring_attn.py
@@ -56,6 +56,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)]
dqkv = qkv.grad
local_dqkv = split_batch_zigzag(dqkv, sp_group)

assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
