[Feature] Split cross-entropy computation in SP #5959

Merged Sep 10, 2024 (74 commits)

Commits
c7a19e2
halfway
Edenzzzz Jun 28, 2024
92b4891
fix cross-PP-stage position id length diff bug
Edenzzzz Jun 28, 2024
a9ed834
fix typo
Edenzzzz Jun 29, 2024
20f2a73
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2024
edb9043
unified cross entropy func for all shardformer models
Edenzzzz Jul 2, 2024
bad530f
remove redundant lines
Edenzzzz Jul 2, 2024
57746c0
add basic ring attn; debug cross entropy
Edenzzzz Jul 8, 2024
1607ea0
fwd bwd logic complete
Edenzzzz Jul 13, 2024
1796b80
fwd bwd logic complete; add experimental triton rescale
Edenzzzz Jul 14, 2024
500454b
precision tests passed
Edenzzzz Jul 18, 2024
c1ea3ba
precision tests passed
Edenzzzz Jul 21, 2024
4b8a412
fix typos and remove misc files
Edenzzzz Jul 22, 2024
4f864b2
update softmax_lse shape by new interface
Edenzzzz Jul 22, 2024
864dac6
change tester name
Edenzzzz Jul 22, 2024
69bf303
remove buffer clone; support packed seq layout
Edenzzzz Jul 23, 2024
ec4fab7
add varlen tests
Edenzzzz Jul 24, 2024
cd9349e
fix typo
Edenzzzz Jul 26, 2024
25d3e38
all tests passed
Edenzzzz Aug 1, 2024
1234d99
add dkv_group; fix mask
Edenzzzz Aug 1, 2024
f6a8f12
remove debug statements
Edenzzzz Aug 1, 2024
c0b7e96
adapt chatglm, command-R, qwen
Edenzzzz Aug 1, 2024
0eb6fdf
debug
Edenzzzz Aug 5, 2024
2fae794
halfway
Edenzzzz Jun 28, 2024
91bab84
fix cross-PP-stage position id length diff bug
Edenzzzz Jun 28, 2024
c5c14b6
fix typo
Edenzzzz Jun 29, 2024
a00c93b
fix typo
Edenzzzz Jun 29, 2024
ab9b784
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2024
7d99bc0
unified cross entropy func for all shardformer models
Edenzzzz Jul 2, 2024
423102a
remove redundant lines
Edenzzzz Jul 2, 2024
91fb3c1
add basic ring attn; debug cross entropy
Edenzzzz Jul 8, 2024
c050293
fwd bwd logic complete
Edenzzzz Jul 13, 2024
65b4b76
fwd bwd logic complete; add experimental triton rescale
Edenzzzz Jul 14, 2024
5824ede
precision tests passed
Edenzzzz Jul 18, 2024
0e72997
precision tests passed
Edenzzzz Jul 21, 2024
f5a1b99
fix typos and remove misc files
Edenzzzz Jul 22, 2024
52331c9
add sp_mode to benchmark; fix varlen interface
Edenzzzz Jul 22, 2024
c70d03a
update softmax_lse shape by new interface
Edenzzzz Jul 22, 2024
9c83343
add varlen tests
Edenzzzz Jul 24, 2024
cc6472a
fix typo
Edenzzzz Jul 26, 2024
ed4ad6d
all tests passed
Edenzzzz Aug 1, 2024
36f691d
add dkv_group; fix mask
Edenzzzz Aug 1, 2024
6b5d1bf
remove debug statements
Edenzzzz Aug 1, 2024
bdad28a
add comments
Edenzzzz Aug 2, 2024
89343fd
q1 index only once
Edenzzzz Aug 5, 2024
551aaec
remove events to simplify stream sync
Edenzzzz Aug 6, 2024
43c0b65
simplify forward/backward logic
Edenzzzz Aug 9, 2024
fb4e905
2d ring forward passed
Edenzzzz Aug 12, 2024
8bc062d
2d ring backward passed
Edenzzzz Aug 13, 2024
d844ded
fixes
Edenzzzz Aug 14, 2024
8c01223
fix ring attn loss
Edenzzzz Aug 14, 2024
7a7fb1f
2D ring backward + llama passed
Edenzzzz Aug 14, 2024
78ed55d
merge
Edenzzzz Aug 14, 2024
70b1f5d
update logger
Edenzzzz Aug 15, 2024
f91586f
fix typo
Edenzzzz Aug 15, 2024
2ce53f1
rebase
Edenzzzz Aug 16, 2024
5fed9da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2024
8ec9009
fix typo
Edenzzzz Aug 16, 2024
d0aeec9
remove typos
Edenzzzz Aug 16, 2024
de0afd1
fixes
Edenzzzz Aug 18, 2024
2fb7db6
support GPT
Edenzzzz Aug 19, 2024
a374633
Merge branch 'main' into split_ce
Edenzzzz Aug 20, 2024
63fd075
fix gpt2
Edenzzzz Aug 20, 2024
17002b6
gpt ring attn + TP passed
Edenzzzz Aug 22, 2024
c6067fe
trim llama forward logic
Edenzzzz Aug 22, 2024
051590d
GPT support sp + pp
Edenzzzz Aug 22, 2024
ce1184c
attempt to simplify code
Edenzzzz Aug 23, 2024
8ad3d5b
Merge branch 'main' into split_ce
Edenzzzz Aug 23, 2024
6d5fc3a
debug
Edenzzzz Aug 23, 2024
4a32c68
fix all-reduce elapsed time
Edenzzzz Aug 27, 2024
5365117
update gpt max seqlen to 32k
Edenzzzz Aug 28, 2024
d0107c3
merge with main
Edenzzzz Aug 28, 2024
177142a
fix typos
Edenzzzz Aug 28, 2024
fc798f4
fix typos
Edenzzzz Aug 29, 2024
04e1c1e
remove
Edenzzzz Sep 2, 2024
10 changes: 8 additions & 2 deletions colossalai/shardformer/layer/_operation.py
@@ -1097,13 +1097,19 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)


def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
def gather_sp_output(hidden_states, shard_config, sp_dim=1):
"""
Gather the output of the last layer for cross entropy computation
"""
sp_group = shard_config.sequence_parallel_process_group
sp_mode = shard_config.sequence_parallelism_mode
fp8_comm = shard_config.fp8_communication
if dist.get_world_size(sp_group) == 1:
return hidden_states

# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm
)
return hidden_states
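
For reference, a hypothetical call site for the new gather_sp_output signature: the caller now passes the whole ShardConfig instead of threading the process group, SP mode, and fp8 flag separately. This is only a sketch — the lm_head call and the sp_dim=0 variant are illustrative, not code from this PR:

# Hypothetical usage, assuming shard_config carries sequence_parallel_process_group,
# sequence_parallelism_mode, and fp8_communication.
hidden_states = gather_sp_output(hidden_states, shard_config)   # gather along seq dim 1
logits = self.lm_head(hidden_states)                            # logits over the full sequence
# If the sequence sits on another axis, pass it explicitly:
# hidden_states = gather_sp_output(hidden_states, shard_config, sp_dim=0)
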
13 changes: 9 additions & 4 deletions colossalai/shardformer/layer/attn.py
@@ -433,7 +433,6 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
assert (
sp_size % inner_ring_size == 0
), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

logger = get_dist_logger()
logger.info(
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
@@ -898,6 +897,7 @@ def backward(ctx, dout, _):

local_sp_rank = dist.get_rank(sp_group)
sp_size = dist.get_world_size(sp_group)

# Using separate streams (pg) for concurrent kv and dkv comm may
# cause NCCL "software caused connection abort" here...
local_kv_comm = RingComm(local_kv_group)
@@ -1119,9 +1119,14 @@ def prepare_varlen_batch(
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.

Returns:
inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
mask_info: A dictionary of mask info.
position_ids: Packed position ids of shape [..., Sq // sp_size].
torch.Tensor:
Packed input embeddings of shape [B, Sq // sp_size, ...].

Dict[str, Any]:
A dictionary containing mask info.

torch.Tensor:
Packed position ids of shape [..., Sq // sp_size].

"""
_load_varlen_helpers()
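
The divisibility assert in get_double_ring_groups pins down how sequence-parallel ranks are tiled into a 2D ring. Below is a minimal sketch of one such layout, assumed purely for illustration — the function name double_ring_peers is hypothetical and the repository's actual grouping may differ:

def double_ring_peers(sp_rank: int, sp_size: int, inner_ring_size: int):
    # Contiguous inner rings of inner_ring_size ranks keep most ring-attention traffic
    # intra-node; the inter ring links ranks that share a position across inner rings.
    assert sp_size % inner_ring_size == 0, "sp_size must be divisible by inner_ring_size"
    inner_ring = [r for r in range(sp_size) if r // inner_ring_size == sp_rank // inner_ring_size]
    inter_ring = [r for r in range(sp_size) if r % inner_ring_size == sp_rank % inner_ring_size]
    return inner_ring, inter_ring

# e.g. sp_size=8, inner_ring_size=4, sp_rank=5 -> inner ring [4, 5, 6, 7], inter ring [1, 5]
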
5 changes: 2 additions & 3 deletions colossalai/shardformer/layer/loss.py
@@ -153,7 +153,6 @@ def dist_cross_entropy(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig,
out_features: int,
vocab_size: int,
dtype: torch.dtype,
seq_dim: int = 1,
Expand Down Expand Up @@ -226,13 +225,13 @@ def dist_cross_entropy(
logits,
labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=out_features,
vocab_size=vocab_size,
dtype=dtype,
mode="sum",
)
else:
# NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
logits = logits.view(-1, vocab_size)
logits = logits.view(-1, logits.size(-1))
loss = loss_fct(logits, labels)

# Reduce loss instead of gathering logits over seq dim for savings
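
The mode="sum" argument and the closing comment are the crux of this PR: each SP rank evaluates cross-entropy only on its local sequence shard, and a scalar reduction replaces gathering [B, S, vocab] logits over the sequence dimension. A rough sketch of that arithmetic follows; sp_cross_entropy is a hypothetical name, and the real dist_cross_entropy additionally handles vocab-parallel logits, label shifting, and dtype casting:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def sp_cross_entropy(local_logits, local_labels, sp_group, ignore_index=-100):
    # Sum of token losses over this rank's sequence shard ([N_local, vocab] / [N_local]).
    loss_sum = F.cross_entropy(
        local_logits.float(), local_labels, ignore_index=ignore_index, reduction="sum"
    )
    # Global count of non-ignored tokens across the sequence-parallel group.
    num_tokens = (local_labels != ignore_index).sum()
    dist.all_reduce(num_tokens, group=sp_group)
    # Local contribution to the global mean; its gradient w.r.t. local logits is already exact.
    loss = loss_sum / num_tokens.clamp(min=1)
    # Reduce the detached scalar so every rank reports the same global mean loss.
    reduced = loss.detach().clone()
    dist.all_reduce(reduced, group=sp_group)
    return loss - loss.detach() + reduced   # value: global mean; gradient: local term only
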
39 changes: 26 additions & 13 deletions colossalai/shardformer/layer/qkv_fused_linear.py
@@ -313,19 +313,19 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:

# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None

if self.seq_parallel_mode is None:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication)
output_parallel = matmul_with_async_comm(
if self.seq_parallel_mode == "split_gather":
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel,
self.weight,
bias,
self.process_group,
self.async_communication,
True,
1,
self.overlap,
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "split_gather":
elif self.seq_parallel_mode == "ring":
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel,
@@ -335,13 +335,22 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
True,
1,
self.overlap,
True,
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
output_parallel = matmul_with_async_comm(
input_parallel,
self.weight,
bias,
self.process_group,
self.async_communication,
fp8_communication=self.fp8_communication,
)
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")

if self.gather_output:
# All-gather across the partitions.
@@ -553,7 +562,7 @@ def forward(self, input_: Tensor) -> Tensor:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
if self.seq_parallel_mode is None:
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
@@ -567,8 +576,12 @@ def forward(self, input_: Tensor) -> Tensor:
elif self.seq_parallel_mode == "ring":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, 1, self.fp8_communication
output_parallel,
self.process_group,
1,
)
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")

if not self.skip_bias_add:
if self.bias is not None:
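
The split_gather branch above leans on matmul_gather_forward_reducescatter_backward: the sequence-sharded activation is all-gathered along the sequence dimension before the GEMM, and in backward each rank's partial input gradient is reduce-scattered (summed over the group, with each rank keeping only its sequence shard). A bare-bones sketch of just that communication pattern, without the fused matmul, overlap, or fp8 paths — the class name is hypothetical and this is not the repository's implementation:

import torch
import torch.distributed as dist

class _GatherFwdReduceScatterBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, group, dim):
        ctx.group, ctx.dim = group, dim
        world = dist.get_world_size(group)
        shards = [torch.empty_like(x) for _ in range(world)]
        dist.all_gather(shards, x.contiguous(), group=group)  # rebuild the full sequence
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_out):
        world = dist.get_world_size(ctx.group)
        # Each rank holds a partial dX for the whole sequence (from its weight shard);
        # sum the partials and keep only the local shard. Assumes even divisibility.
        chunks = [c.contiguous() for c in grad_out.chunk(world, dim=ctx.dim)]
        grad_in = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad_in, chunks, group=ctx.group)
        return grad_in, None, None
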
6 changes: 6 additions & 0 deletions colossalai/shardformer/layer/utils.py
@@ -309,6 +309,9 @@ def split_batch_zigzag(
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if sp_size == 1:
return batch

if isinstance(batch, torch.Tensor):
batch = [batch]
seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
@@ -364,6 +367,9 @@ def split_varlen_zigzag(
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if sp_size == 1:
return batch

if is_2d:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

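
split_batch_zigzag and split_varlen_zigzag shard sequences for ring attention; the new early returns simply skip the split when the SP group has a single rank. The zigzag layout itself, as commonly used to load-balance causal ring attention and assumed here only for illustration, pairs a chunk from the front of the sequence with its mirror from the back so every rank gets a similar amount of attention work:

import torch

def zigzag_split(batch: torch.Tensor, sp_rank: int, sp_size: int, seq_dim: int = 1) -> torch.Tensor:
    # Cut the sequence into 2 * sp_size equal chunks; rank r keeps chunks r and
    # (2 * sp_size - 1 - r), balancing cheap early and expensive late causal positions.
    # Assumes the sequence length divides evenly into 2 * sp_size chunks.
    chunks = batch.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)

# e.g. sp_size=2: rank 0 keeps chunks [0, 3], rank 1 keeps chunks [1, 2]
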
25 changes: 14 additions & 11 deletions colossalai/shardformer/modeling/bloom.py
@@ -365,14 +365,15 @@ def bloom_for_causal_lm_forward(
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states).contiguous()

loss = dist_cross_entropy(
labels,
lm_logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.transformer.dtype,
)
loss = None
if labels is not None:
loss = dist_cross_entropy(
labels,
lm_logits,
shard_config,
self.lm_head.out_features,
self.transformer.dtype,
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -1036,9 +1037,11 @@ def forward(
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)

loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]