Commit 6d5fc3a: debug

Edenzzzz committed Aug 23, 2024
1 parent 8ad3d5b commit 6d5fc3a
Showing 4 changed files with 7 additions and 10 deletions.
4 changes: 2 additions & 2 deletions colossalai/shardformer/modeling/gpt2.py
@@ -304,8 +304,8 @@ def gpt2_model_forward(
         hidden_states = gather_sp_output(
             hidden_states,
             sp_dim=1,
-            sp_group=shard_config.tensor_parallel_process_group,
-            sp_mode=shard_config.sequence_parallelism_mode,
+            sp_group=sp_group,
+            sp_mode=sp_mode,
         )
 
         # gather_sp_output could've changed seq length.
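
Note: gather_sp_output reassembles the sequence dimension that sequence parallelism split across ranks; this hunk passes the locally resolved sp_group and sp_mode instead of re-reading them from shard_config. A minimal sketch of the gathering idea, assuming a plain all-gather along the sequence dimension (the actual ColossalAI implementation may differ):

import torch
import torch.distributed as dist

def gather_sp_output_sketch(hidden_states: torch.Tensor, sp_group, sp_mode, sp_dim: int = 1) -> torch.Tensor:
    # No sequence parallelism configured: nothing to gather.
    if sp_mode is None or sp_group is None:
        return hidden_states
    world_size = dist.get_world_size(group=sp_group)
    if world_size == 1:
        return hidden_states
    # All-gather the per-rank sequence shards and stitch them back together along sp_dim.
    shards = [torch.empty_like(hidden_states) for _ in range(world_size)]
    dist.all_gather(shards, hidden_states.contiguous(), group=sp_group)
    return torch.cat(shards, dim=sp_dim)
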
3 changes: 1 addition & 2 deletions colossalai/shardformer/modeling/llama.py
@@ -219,10 +219,9 @@ def llama_model_forward(
         if output_attentions:
             all_self_attns += (layer_outputs[1],)
 
-    gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
     if disable_pp or stage_manager.is_last_stage():
         hidden_states = self.norm(hidden_states)
-        if gather_output:
+        if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode):  # noqa
             hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
 
     # add hidden states from the last decoder layer
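
Note: this hunk folds the gather condition into the last-stage branch, so it is evaluated only where the final hidden states are produced. A hypothetical predicate spelling that condition out (booleans only, so no ColossalAI internals are assumed):

def should_gather_sp_output(parallel_output: bool, force_sp_gather: bool, sp_shares_tp_group: bool) -> bool:
    # Gather the full sequence when the caller wants unsharded output,
    # when gathering is explicitly forced, or when the SP mode shares the TP group
    # (what is_share_sp_tp(sp_mode) appears to check in the diff above).
    return (not parallel_output) or force_sp_gather or sp_shares_tp_group

On the last pipeline stage (or with pipeline parallelism disabled), self.norm runs first and gather_sp_output is called only if this predicate holds.
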
8 changes: 3 additions & 5 deletions colossalai/shardformer/policies/llama.py
@@ -94,11 +94,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         if self.pipeline_stage_manager is None:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_model_forward(
-                        self.shard_config,
-                        sp_mode=sp_mode,
-                        sp_size=sp_size,
-                        sp_group=sp_group,
+                    "forward": partial(
+                        LlamaPipelineForwards.llama_model_forward,
+                        shard_config=self.shard_config,
                     ),
                 },
                 policy=policy,
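
Note: the policy now registers the pipeline forward by binding shard_config with functools.partial rather than constructing a flash-attention-specific forward. A toy illustration of the binding pattern (the names below are placeholders, not the real ColossalAI policy machinery):

from functools import partial

class ToyPipelineForwards:
    @staticmethod
    def model_forward(module, input_ids, shard_config=None):
        # Stand-in for a pipeline-aware forward that needs shard_config at call time.
        return f"forward({input_ids!r}) with shard_config={shard_config!r}"

# partial pre-binds shard_config, so the bound callable keeps the (module, input_ids)
# call shape expected of a "forward" replacement in the policy description.
replacement_forward = partial(ToyPipelineForwards.model_forward, shard_config={"pipeline_stage_manager": None})
print(replacement_forward(object(), "input_ids"))
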
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -187,7 +187,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     {
         "tp_size": 2,
         "pp_size": 2,
-        "num_microbatches": 4,
+        "num_microbatches": 2,
         "enable_all_optimization": True,
         "use_lazy_init": True,
         "precision": "fp16",
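
Note: the test config halves num_microbatches from 4 to 2. Pipeline schedules typically split the batch into num_microbatches chunks, so the test batch size must stay divisible by it. A rough check with an assumed batch size (the actual batch size is defined elsewhere in the test suite):

# Hypothetical numbers; the real test batch size is not shown in this diff.
global_batch_size = 4
num_microbatches = 2
assert global_batch_size % num_microbatches == 0
microbatch_size = global_batch_size // num_microbatches  # 2 samples per microbatch
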

0 comments on commit 6d5fc3a