Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM #5854

Closed
wants to merge 38 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
2793854
Add Ulysses SP support for Qwen2
GuangyaoZhang Jun 24, 2024
dd8b5ec
Add Ulysses SP support for ChatGLM
GuangyaoZhang Jun 25, 2024
4b1ce24
Add Ulysses SP support for Command-R
GuangyaoZhang Jun 25, 2024
5c5fd30
Fix pytest typo
GuangyaoZhang Jun 25, 2024
2a25a2a
[Feature] optimize PP overlap (#5735)
Edenzzzz Jun 26, 2024
8e718a1
[gemini] fixes for benchmarking (#5847)
botbw Jun 26, 2024
5dfbcd7
[zero] use bucket during allgather (#5860)
ver217 Jun 27, 2024
d9d5e7e
[shardformer] Support the T5ForTokenClassification model (#5816)
GuangyaoZhang Jun 27, 2024
3c7cda0
[Inference]Lazy Init Support (#5785)
LRY89757 Jun 27, 2024
eaea88c
[release] update version (#5864)
ver217 Jun 28, 2024
773d9f9
[shardformer]delete xformers (#5859)
flybird11111 Jun 28, 2024
416580b
[MoE/ZeRO] Moe refactor with zero refactor (#5821)
Hz188 Jun 28, 2024
3dc0d1d
ChatGLM, Qwen2, Command-R Support SP+PP together
GuangyaoZhang Jun 28, 2024
f9d544b
remove unnecessary pytest
GuangyaoZhang Jun 30, 2024
8ab46b4
[Shardformer] change qwen2 modeling into gradient checkpointing styl…
CjhHa1 Jul 1, 2024
936d0b0
[doc] Update llama + sp compatibility; fix dist optim table
Edenzzzz Jul 1, 2024
7c2f79f
[pre-commit.ci] pre-commit autoupdate (#5572)
pre-commit-ci[bot] Jul 1, 2024
ea94c07
[hotfix] fix the bug that large tensor exceed the maximum capacity of…
Hz188 Jul 2, 2024
133bbd5
revert some exchange to avoid misunderstanding caused by git diff
GuangyaoZhang Jul 3, 2024
eb24fcd
[Hotfix] Fix OPT gradient checkpointing forward
Edenzzzz Jul 3, 2024
6cd4c32
[shardformer] fix the moe (#5883)
wangbluo Jul 3, 2024
7afbc81
[quant] fix bitsandbytes version check (#5882)
ver217 Jul 4, 2024
7997683
[pre-commit.ci] pre-commit autoupdate (#5878)
pre-commit-ci[bot] Jul 4, 2024
3420921
[shardformer] DeepseekMoE support (#5871)
Hz188 Jul 5, 2024
8ec24b6
[Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Edenzzzz Jul 5, 2024
cba2052
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
LRY89757 Jul 8, 2024
392933a
ChatGLM sp with pp redundance removal
GuangyaoZhang Jul 8, 2024
66abf1c
[HotFix] CI,import,requirements-test for #5838 (#5892)
LRY89757 Jul 8, 2024
b554515
Add Ulysses SP support for Qwen2
GuangyaoZhang Jun 24, 2024
f5aa99b
Add Ulysses SP support for ChatGLM
GuangyaoZhang Jun 25, 2024
8cbb469
Add Ulysses SP support for Command-R
GuangyaoZhang Jun 25, 2024
6b5cf33
Fix pytest typo
GuangyaoZhang Jun 25, 2024
9861cd2
ChatGLM, Qwen2, Command-R Support SP+PP together
GuangyaoZhang Jun 28, 2024
2218792
remove unnecessary pytest
GuangyaoZhang Jun 30, 2024
7334a5b
revert some exchange to avoid misunderstanding caused by git diff
GuangyaoZhang Jul 3, 2024
0ae41f5
ChatGLM sp with pp redundance removal
GuangyaoZhang Jul 8, 2024
82164a7
Merge branch 'sp' of github.com:GuangyaoZhang/ColossalAI into sp
GuangyaoZhang Jul 9, 2024
64359a6
Empty Commit to trigger build on PR
GuangyaoZhang Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def backward(ctx, grad_output):
if use_bias:
bias.view(bias.shape)

total_input = input
total_input = input.contiguous()
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
Expand Down
221 changes: 209 additions & 12 deletions colossalai/shardformer/modeling/chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)


def get_flash_core_attention_forward():
Expand Down Expand Up @@ -203,6 +207,13 @@ def chatglm_model_forward(
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
Expand Down Expand Up @@ -235,6 +246,13 @@ def chatglm_model_forward(
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
Expand Down Expand Up @@ -329,7 +347,9 @@ def chatglm_for_conditional_generation_forward(
return transformer_outputs


def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
logger = logging.get_logger(__name__)

def forward(
self,
input_ids,
Expand Down Expand Up @@ -381,13 +401,27 @@ def forward(
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

if sp_mode in ["all_to_all"] and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..."
)
use_cache = False
# Run encoder.
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
inputs_embeds = split_forward_gather_backward(
inputs_embeds,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
if sp_mode in ["split_gather"]:
inputs_embeds = split_forward_gather_backward(
inputs_embeds,
dim=0,
process_group=sp_group,
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(
inputs_embeds,
dim=0,
process_group=sp_group,
grad_scale=1 / sp_size,
)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
full_attention_mask,
Expand All @@ -397,11 +431,19 @@ def forward(
output_hidden_states=output_hidden_states,
)

hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
if sp_mode in ["split_gather"]:
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=sp_group,
grad_scale=sp_size,
)

if not return_dict:
return tuple(
Expand All @@ -423,3 +465,158 @@ def forward(
)

return forward


def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
    """Build a replacement ``SelfAttention.forward`` for ChatGLM2 that supports
    Ulysses-style sequence parallelism.

    Args:
        shard_config: ShardFormer configuration for this model.
        sp_mode: Sequence-parallel mode; ``"all_to_all"`` and ``"split_gather"``
            are the only values accepted (asserted inside ``forward``), or
            ``None`` to disable the checks.
        sp_size: Number of ranks in the sequence-parallel group (required when
            ``sp_mode`` is set).
        sp_group: The sequence-parallel process group (required when
            ``sp_mode`` is set).

    Returns:
        A ``forward`` function with the same signature as the original
        ChatGLM2 ``SelfAttention.forward``, returning ``(output, kv_cache)``.
    """
    # Imported lazily so the module can be loaded without the vendored
    # ChatGLM2-6B modeling file being importable at import time.
    from .chatglm2_6b.modeling_chatglm import apply_rotary_pos_emb, split_tensor_along_last_dim

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        # Validate the sequence-parallel configuration captured by the closure.
        if sp_mode is not None:
            assert sp_mode in ["all_to_all", "split_gather"], "Invalid sp_mode"
            assert (sp_size is not None) and (
                sp_group is not None
            ), "Must specify sp_size and sp_group for sequence parallel"

        # Fused QKV projection; layout follows ChatGLM2 convention
        # (sequence-first): presumably [sq, b, hidden] -> [sq, b, qkv_hidden].
        # TODO(review): confirm sequence-first layout against the caller.
        mixed_x_layer = self.query_key_value(hidden_states)
        if self.multi_query_attention:
            # Grouped-query attention: Q has the full head count, while K and V
            # only have `num_multi_query_groups_per_partition` heads.
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            # Unflatten the last dim into (heads, head_dim).
            query_layer = query_layer.view(
                query_layer.size()[:-1]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
        else:
            # Standard multi-head attention: Q, K, V all share the head count,
            # so split the fused projection evenly along the last dimension.
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # sp: all-to-all communication when introducing sequence parallel.
        # Each rank trades its local sequence shard for a head shard: flatten
        # (heads, head_dim) into one dim, all-to-all with gather_dim=0 so the
        # full sequence is assembled while heads are scattered across ranks.
        if sp_mode == "all_to_all":
            sq, bs, _, _ = value_layer.size()

            query_layer = query_layer.reshape(sq, bs, -1)
            key_layer = key_layer.reshape(sq, bs, -1)
            value_layer = value_layer.reshape(sq, bs, -1)

            query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
            key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
            value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)

            # After the exchange each rank holds the full sequence
            # (sq * sp_size) but only 1/sp_size of the attention heads.
            query_layer = query_layer.view(
                sq * sp_size,
                bs,
                self.num_attention_heads_per_partition // sp_size,
                self.hidden_size_per_attention_head,
            ).contiguous()

            # NOTE(review): K/V entered this branch with
            # num_multi_query_groups_per_partition heads (in the GQA case),
            # yet are re-viewed here with num_attention_heads_per_partition
            # // sp_size heads — this only lines up if the flattened hidden
            # sizes match; confirm for the GQA configuration.
            key_layer = key_layer.view(
                sq * sp_size,
                bs,
                self.num_attention_heads_per_partition // sp_size,
                self.hidden_size_per_attention_head,
            ).contiguous()

            value_layer = value_layer.view(
                sq * sp_size,
                bs,
                self.num_attention_heads_per_partition // sp_size,
                self.hidden_size_per_attention_head,
            ).contiguous()

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference: prepend cached K/V along the
        # sequence dimension (dim 0, sequence-first layout).
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            # GQA: replicate each K/V group so every query head has a matching
            # K/V head before the dense core-attention computation.
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1,
                -1,
                -1,
                self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
                -1,
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1,
                -1,
                -1,
                self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
                -1,
            )
            # NOTE(review): the key view above targets
            # num_attention_heads_per_partition heads while this value view
            # targets num_attention_heads_per_partition // sp_size — the
            # asymmetry looks unintended (sp_mode == "all_to_all" leaves both
            # with heads // sp_size). Confirm which is correct before relying
            # on the GQA + all_to_all path.
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2]
                + (
                    self.num_attention_heads_per_partition // sp_size,
                    self.hidden_size_per_attention_head,
                )
            )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
        # Reverse the earlier exchange: scatter the sequence back (dim 0) and
        # gather the head/hidden dim (dim 2) so each rank again holds its
        # local sequence shard with the full hidden size.
        if sp_mode == "all_to_all":
            context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)

        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)

        return output, kv_cache

    return forward
30 changes: 30 additions & 0 deletions colossalai/shardformer/modeling/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,21 @@ def command_model_forward(
)
use_cache = False

if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
Expand Down Expand Up @@ -191,6 +206,21 @@ def command_model_forward(
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)

if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand Down
Loading
Loading