remove events to simplify stream sync
Edenzzzz committed Aug 6, 2024
1 parent 07b4fb4 commit 5c4b445
Showing 1 changed file with 6 additions and 9 deletions.
colossalai/shardformer/layer/attn.py: 15 changes (6 additions, 9 deletions)
@@ -455,8 +455,7 @@ class RingAttention(torch.autograd.Function):
     TOTAL_SEQLEN: int = None
     HALF_INDICES: Tuple = None
     SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
-    CORRECTION_DONE = torch.cuda.Event()
-    ATTN_DONE = torch.cuda.Event()
+    ATTN_DONE: torch.cuda.Event = None

     @staticmethod
     def attention(
@@ -504,6 +503,8 @@ def attention(
                 Shape should be [total_q_seqlen, nHeads]
         """
         _load_flash_attn()
+        if RingAttention.ATTN_DONE is None:
+            RingAttention.ATTN_DONE = torch.cuda.Event()
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -727,17 +728,15 @@ def forward(
                     causal=False,
                     **misc_kwargs,
                 )
-                RingAttention.ATTN_DONE.record(sp_streams[i % 2])
+                RingAttention.ATTN_DONE.record()

                 block_softmax_lse[i % 2] = (
                     block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                 )  # (H, T) -> (T, H, 1)
                 assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
                 # Output and log sum exp correction
-                if i > 0:
-                    sp_streams[i % 2].wait_event(RingAttention.CORRECTION_DONE)
-
                 # Overlap output correction with next flash attn kernel
+                # In reality this always finishes before next flash attn
                 if i == 0:
                     out = block_out[0]
                     softmax_lse = block_softmax_lse[0]
@@ -747,9 +746,7 @@ def forward(
                     out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
                         out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
                     )
-
-                RingAttention.CORRECTION_DONE.record(sp_streams[i % 2])
-        torch.cuda.current_stream().wait_event(RingAttention.CORRECTION_DONE)
+        torch.cuda.current_stream().wait_stream(sp_stream)

         out = out.to(q.dtype)
         if not is_packed:
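
The net effect of the change: instead of recording CORRECTION_DONE on the compute stream each iteration and making other streams wait on that event, the main stream now calls wait_stream(sp_stream) once after the ring loop. Below is a minimal, self-contained sketch (not code from this repository; the tensors and the side_stream name are invented for illustration) contrasting the two CUDA stream-synchronization styles:

import torch

assert torch.cuda.is_available()
side_stream = torch.cuda.Stream()
x = torch.randn(1024, 1024, device="cuda")

# Style 1: event-based hand-off between streams (the pattern this commit removes).
done = torch.cuda.Event()
side_stream.wait_stream(torch.cuda.current_stream())  # side stream must see x first
with torch.cuda.stream(side_stream):
    y = x @ x
    done.record()  # records on side_stream (the current stream inside this context)
torch.cuda.current_stream().wait_event(done)  # default stream blocks until the event fires
z = y.sum()  # safe: the matmul on side_stream is ordered before this kernel

# Style 2: one coarse sync at the end (the simplification this commit keeps).
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    y2 = x @ x
torch.cuda.current_stream().wait_stream(side_stream)  # wait for all work queued on side_stream
z2 = y2.sum()
print(z.item(), z2.item())

wait_stream makes the calling stream wait for everything already queued on the other stream, so a single call after the loop can stand in for the per-iteration event handshake when that finer-grained ordering is not needed.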
