[shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7997683 · commit 3420921 · showing 7 changed files with 748 additions and 19 deletions.
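As a quick orientation before the diffs: the feature added here is expert parallelism (EP) for DeepseekMoE models. Below is a minimal, hypothetical driver sketch showing how EP might be enabled end to end via `MoeHybridParallelPlugin`, mirroring the unit test added in this commit; the `launch_from_torch`/`Booster` wiring and the `.cuda()` placement are standard ColossalAI usage, not part of this commit's diff.

```python
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# hypothetical driver; mirrors the unit test shipped with this commit
colossalai.launch_from_torch()  # assumes the env vars set by torchrun
plugin = MoeHybridParallelPlugin(
    precision="bf16",
    tp_size=1,
    pp_size=1,
    ep_size=dist.get_world_size(),  # shard experts across all ranks
)
booster = Booster(plugin=plugin)

config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).cuda()
# boost() applies the shard config (which carries ep_group), so each
# DeepseekDecoderLayer's mlp is replaced by EPDeepseekMoE per the policy below
model, *_ = booster.boost(model)
```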
@@ -0,0 +1,212 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]


class DeepseekPolicy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self):
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
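                # pad so the vocab divides evenly across TP ranks,
                # e.g. vocab_size=102400, world_size=3 -> new_vocab_size=102402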
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            raise NotImplementedError(
                "Deepseek doesn't support sequence parallelism now; the sequence parallelism flag will be ignored."
            )

        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")

        if getattr(self.shard_config, "ep_group", None) is not None:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EPDeepseekMoE,
                        kwargs={"ep_group": self.shard_config.ep_group},
                    )
                ],
                policy=policy,
                target_key="DeepseekDecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekDecoderLayer",
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekModel",
            )

        if self.shard_config.enable_flash_attention:
            warnings.warn(
                "Flash attention is already built into the Deepseek modeling file; setting enable_flash_attention = False."
            )
            self.shard_config.enable_flash_attention = False

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """Under a pipeline-parallel setting, replace the original Hugging Face forward method
        with a customized forward method and record this change in the policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "DeepseekModel":
                module = self.model
            else:
                module = self.model.model

            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

        return

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "DeepseekModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers


class DeepseekModelPolicy(DeepseekPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls="DeepseekModel",
                new_forward=DeepseekPipelineForwards.deepseek_model_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in deepseek model"""
        return []


class DeepseekForCausalLMPolicy(DeepseekPolicy):
    def module_policy(self):
        policy = super().module_policy()
        # TODO: assign pg mesh from plugin to all modules
        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                "DeepseekForCausalLM": ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True),
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls="DeepseekForCausalLM",
                new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        deepseek_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if id(deepseek_model.embed_tokens.weight) == id(self.model.lm_head.weight):
                # tie weights: embed_tokens (first stage) shares its weight with lm_head (last stage)
                return [
                    {
                        0: deepseek_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []
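For reference, here is a rough sketch of applying this policy directly through ShardFormer instead of a booster plugin. The policy's import path and the hand-attached `ep_group` are assumptions (in normal use, `MoeHybridParallelPlugin` attaches `ep_group` to the `ShardConfig` for you, and the policy reads it via `getattr(..., "ep_group", None)` as shown above).

```python
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer
# assumed module path for the policy file shown above
from colossalai.shardformer.policies.deepseek import DeepseekForCausalLMPolicy

# assumes torch.distributed is already initialized
config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

shard_config = ShardConfig(enable_tensor_parallelism=False)
shard_config.ep_group = dist.group.WORLD  # hypothetical: use all ranks for expert parallelism

shard_former = ShardFormer(shard_config=shard_config)
model, shared_params = shard_former.optimize(model, policy=DeepseekForCausalLMPolicy())
```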
@@ -0,0 +1,72 @@
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_deepseek_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    plugin = MoeHybridParallelPlugin(
        precision="bf16",
        tp_size=1,
        pp_size=1,
        ep_size=dist.get_world_size(),
    )

    config = AutoConfig.from_pretrained(
        "deepseek-ai/deepseek-moe-16b-base",
        num_hidden_layers=1,
        n_routed_experts=n_experts,
        num_experts_per_tok=top_k,
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        first_k_dense_replace=0,
        num_attention_heads=2,
        trust_remote_code=True,
    )
    torch.manual_seed(0)
    # get the moe layer in auto model
    orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()
    x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
    orig_output = orig_model(x)
    model = deepcopy(orig_model)
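    # from_native_module converts the dense DeepseekMoE block into its
    # expert-parallel counterpart, distributing the experts over plugin.ep_group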
    model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
    ep_output = model(x)
    assert_close(orig_output, ep_output)
    orig_loss = orig_output.mean()
    orig_loss.backward()
    ep_loss = ep_output.mean()
    ep_loss.backward()
    assert_close(orig_loss, ep_loss)
    name_to_p = {n: p for n, p in orig_model.named_parameters()}
    for n, ep_p in model.named_parameters():
        p = name_to_p[n]
        if ep_p.grad is not None:
            assert_close(p.grad, ep_p.grad)


def run_dist(rank: int, world_size: int, port: int):
    colossalai.launch(rank, world_size, "localhost", port)
    check_deepseek_moe_layer()


# @pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("world_size", [2])
def test_deepseek_moe_layer(world_size: int):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_deepseek_moe_layer(2)