Skip to content

Commit

Permalink
[feature] merge feature/dist-optim and update;
Browse files Browse the repository at this point in the history
  • Loading branch information
duanjunwen committed Apr 4, 2024
1 parent f1ae04d commit 34e20c2
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
99 changes: 99 additions & 0 deletions examples/language/grok-1/grok1_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class Grok1Policy(Policy):
    """Base Shardformer policy for Grok-1 models.

    Describes which decoder attributes and sub-modules are swapped out when
    tensor parallelism is enabled in ``self.shard_config``.
    """

    def config_sanity_check(self):
        """Perform no configuration checks."""
        pass

    def preprocess(self) -> nn.Module:
        """Validate the vocabulary size and hand back the model unchanged.

        When tensor parallelism is enabled, the vocabulary size must be an
        exact multiple of the tensor-parallel size.

        Returns:
            ``self.model``, unmodified.

        Raises:
            AssertionError: if the vocabulary size is not divisible by the
                tensor-parallel size while tensor parallelism is enabled.
        """
        if self.shard_config.enable_tensor_parallelism:
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            remainder = vocab_size % world_size
            assert remainder == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the replacement policy for the decoder layers and the embedding.

        Returns:
            An empty dict when tensor parallelism is disabled. Otherwise a
            dict holding a ``"DecoderLayer"`` entry (attention attributes
            divided by the tensor-parallel size, plus attention and
            per-expert linear replacements) and the ``embed_tokens``
            replacement registered under ``"Grok1Model"``.
        """
        policy = {}
        if not self.shard_config.enable_tensor_parallelism:
            return policy

        model_config = self.model.config
        tp_size = self.shard_config.tensor_parallel_size

        # Per-rank attention sizes: each value is floor-divided by the TP size.
        attribute_overrides = {
            "attn.hidden_size": model_config.hidden_size // tp_size,
            "attn.num_heads": model_config.num_attention_heads // tp_size,
            "attn.num_key_value_heads": model_config.num_key_value_heads // tp_size,
        }

        # (suffix, replacement class) pairs for the attention projections.
        attention_targets = (
            ("attn.q_proj", Linear1D_Col),
            ("attn.k_proj", Linear1D_Col),
            ("attn.v_proj", Linear1D_Col),
            ("attn.o_proj", Linear1D_Row),
        )
        # (layer name, replacement class) pairs repeated for every expert.
        expert_targets = (
            ("linear", Linear1D_Col),
            ("linear_v", Linear1D_Col),
            ("linear_1", Linear1D_Row),
        )

        submodule_overrides = [
            SubModuleReplacementDescription(suffix=suffix, target_module=target)
            for suffix, target in attention_targets
        ]
        # Experts are addressed by index, in order, after the attention entries.
        submodule_overrides += [
            SubModuleReplacementDescription(
                suffix=f"moe_block.experts[{expert_idx}].{layer_name}",
                target_module=target,
            )
            for expert_idx in range(model_config.num_experts)
            for layer_name, target in expert_targets
        ]

        policy["DecoderLayer"] = ModulePolicyDescription(
            attribute_replacement=attribute_overrides,
            sub_module_replacement=submodule_overrides,
        )

        embedding_description = SubModuleReplacementDescription(
            suffix="embed_tokens",
            target_module=VocabParallelEmbedding1D,
        )
        self.append_or_create_submodule_replacement(
            description=embedding_description,
            policy=policy,
            target_key="Grok1Model",
        )
        return policy

    def postprocess(self):
        """Return ``self.model`` without any further processing."""
        return self.model


class Grok1ModelPolicy(Grok1Policy):
    """Policy for the bare Grok-1 model; inherits everything from ``Grok1Policy``."""


class Grok1ForCausalLMPolicy(Grok1Policy):
    """Policy for the Grok-1 causal-LM model: the base policy plus the ``lm_head``."""

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Extend the base policy with a replacement for the ``lm_head``.

        The head replacement is only registered when tensor parallelism is
        enabled. The base policy registers nothing when it is disabled, so
        replacing the head unconditionally would leave ``lm_head`` as the
        only swapped module in a configuration that asked for no tensor
        parallelism.

        Returns:
            The base policy dict, with an ``lm_head`` -> ``Linear1D_Col``
            replacement added under ``"Grok1ModelForCausalLM"`` when tensor
            parallelism is enabled. ``gather_output`` is the negation of
            ``shard_config.parallel_output``.
        """
        policy = super().module_policy()
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="lm_head",
                    target_module=Linear1D_Col,
                    kwargs={"gather_output": not self.shard_config.parallel_output},
                ),
                policy=policy,
                target_key="Grok1ModelForCausalLM",
            )
        return policy
1 change: 1 addition & 0 deletions examples/language/grok-1/test_ci.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Install the Python packages listed in requirements.txt.
pip install -r requirements.txt

0 comments on commit 34e20c2

Please sign in to comment.