fix gpt2
Edenzzzz committed Aug 20, 2024
1 parent a374633 commit 63fd075
Showing 1 changed file with 7 additions and 5 deletions.
colossalai/shardformer/modeling/gpt2.py: 12 changes (7 additions & 5 deletions)
@@ -925,11 +925,13 @@ def forward(
 
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
-        hidden_states = split_forward_gather_backward(
-            hidden_states,
-            dim=1,
-            process_group=shard_config.sequence_parallel_process_group,
-        )
+        if shard_config and shard_config.enable_sequence_parallelism:
+            if shard_config.sequence_parallelism_mode == "split_gather":
+                hidden_states = split_forward_gather_backward(
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.sequence_parallel_process_group,
+                )
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel
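The guard matters because the old code sharded hidden_states along the sequence dimension unconditionally, even when sequence parallelism was disabled or a mode other than "split_gather" was selected. As a rough illustration of what the forward-pass split in the "split_gather" path does, here is a minimal plain-PyTorch sketch (not the ColossalAI implementation; tp_size and tp_rank are hypothetical stand-ins for the sequence-parallel process group):

import torch

def split_along_seq_dim(hidden_states: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Keep only this rank's slice of the sequence dimension:
    # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len // tp_size, hidden_size].
    # In ColossalAI, split_forward_gather_backward additionally gathers the
    # gradient along the same dimension during the backward pass.
    return hidden_states.chunk(tp_size, dim=1)[tp_rank]

# Hypothetical shapes: batch_size=2, seq_len=8, hidden_size=4, 4-way sequence parallelism.
x = torch.randn(2, 8, 4)
local = split_along_seq_dim(x, tp_size=4, tp_rank=0)
print(local.shape)  # torch.Size([2, 2, 4])

With the fix, a run without sequence parallelism (or with another mode such as "ring") keeps the full [batch_size, seq_len, hidden_size] tensor and avoids the shape mismatch the unconditional split would introduce.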
