Skip to content

Commit

Permalink
tests passed
Browse files Browse the repository at this point in the history
  • Loading branch information
Edenzzzz committed Apr 30, 2024
1 parent 85ba7ab commit ff819a5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
1 change: 0 additions & 1 deletion colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,6 @@ def configure(
shard_config=self.shard_config,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group, # TODO: add ep group. Modify shard_config to assign pg in policy
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
Expand Down
8 changes: 4 additions & 4 deletions colossalai/shardformer/policies/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.model

layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
Expand All @@ -131,10 +131,10 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
Expand Down

0 comments on commit ff819a5

Please sign in to comment.