[shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hz188 and pre-commit-ci[bot] authored Jul 5, 2024
1 parent 7997683 commit 3420921
Showing 7 changed files with 748 additions and 19 deletions.
2 changes: 1 addition & 1 deletion colossalai/cluster/process_group_mesh.py
@@ -147,7 +147,7 @@ def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) ->
ProcessGroup: The process group with the given ranks.
"""
ranks_in_group = sorted(ranks_in_group)
if tuple(ranks_in_group) not in self._group_to_ranks:
if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group
self._group_to_ranks[group] = tuple(ranks_in_group)
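The one-line change above matters because the lookup and the insertion must share a key space: the old check consulted _group_to_ranks, which is keyed by process-group handles, so a rank tuple could never hit it and dist.new_group ran on every call. A minimal standalone sketch of the intended caching pattern (a toy stand-in, not the actual ProcessGroupMesh class):

from typing import Dict, Tuple


class GroupCache:
    """Toy model of the rank-tuple -> group cache that the fix repairs."""

    def __init__(self) -> None:
        self._ranks_to_group: Dict[Tuple[int, ...], object] = {}  # rank tuple -> group handle
        self._group_to_ranks: Dict[object, Tuple[int, ...]] = {}  # group handle -> rank tuple

    def get_group(self, ranks, make_group):
        key = tuple(sorted(ranks))
        # The fixed code checks _ranks_to_group (keyed by rank tuples); checking
        # _group_to_ranks (keyed by group handles) always missed and re-created the group.
        if key not in self._ranks_to_group:
            group = make_group(list(key))
            self._ranks_to_group[key] = group
            self._group_to_ranks[group] = key
        return self._ranks_to_group[key]


if __name__ == "__main__":
    created = []
    cache = GroupCache()
    g1 = cache.get_group([1, 0], lambda r: created.append(r) or object())
    g2 = cache.get_group([0, 1], lambda r: created.append(r) or object())
    assert g1 is g2 and len(created) == 1  # the second call is a cache hit
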
429 changes: 429 additions & 0 deletions colossalai/shardformer/modeling/deepseek.py

Large diffs are not rendered by default.
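The new modeling file (diff not rendered above) provides EPDeepseekMoE and DeepseekPipelineForwards. Judging from the unit test at the end of this commit, EPDeepseekMoE.from_native_module wraps the dense MoE block so that each rank of ep_group keeps only its own slice of the routed experts. A rough, hypothetical sketch of that partitioning (the helper name and the even-split assumption are illustrative, not code from this file):

def local_expert_indices(n_routed_experts: int, ep_size: int, ep_rank: int) -> range:
    """Return the contiguous slice of routed experts assumed to be held by one EP rank."""
    assert n_routed_experts % ep_size == 0, "experts are assumed to divide evenly across EP ranks"
    per_rank = n_routed_experts // ep_size
    return range(ep_rank * per_rank, (ep_rank + 1) * per_rank)


# e.g. 4 routed experts on a 2-way EP group (the unit test's setup):
assert list(local_expert_indices(4, 2, 0)) == [0, 1]  # rank 0 keeps experts 0-1
assert list(local_expert_indices(4, 2, 1)) == [2, 3]  # rank 1 keeps experts 2-3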

8 changes: 7 additions & 1 deletion colossalai/shardformer/policies/auto_policy.py
@@ -160,6 +160,13 @@ class PolicyLocation:
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Deepseek
"transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
file_name="deepseek", class_name="DeepseekModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"
@@ -252,7 +259,6 @@ def get_autopolicy(model: nn.Module) -> Policy:
"""
full_name = _fullname(model)
policy_location = _POLICY_LIST.get(full_name, None)

if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
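The registry keys use the transformers_modules prefix because Deepseek is loaded with trust_remote_code=True, so its classes live in a dynamically generated transformers_modules package instead of transformers.models.*. A simplified, assumption-laden illustration of how such a lookup key relates to a model instance (the real _fullname helper is not shown in this diff and additionally has to normalize the dynamic module path):

import torch.nn as nn


def qualified_name(model: nn.Module) -> str:
    """Hypothetical helper: build a 'module.ClassName' key for the policy registry."""
    klass = model.__class__
    # For a remote-code model this is roughly
    # "transformers_modules.<repo>.modeling_deepseek.DeepseekForCausalLM";
    # the repo segment still has to be stripped to match the keys registered above.
    return f"{klass.__module__}.{klass.__qualname__}"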
212 changes: 212 additions & 0 deletions colossalai/shardformer/policies/deepseek.py
@@ -0,0 +1,212 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]


class DeepseekPolicy(Policy):
def config_sanity_check(self):
pass

def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size

if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)

return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}

if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)

if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")

if getattr(self.shard_config, "ep_group", None) is not None:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=EPDeepseekMoE,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key="DeepseekDecoderLayer",
)

# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key="DeepseekDecoderLayer",
)

self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key="DeepseekModel",
)

if self.shard_config.enable_flash_attention:
warnings.warn(
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
)
self.shard_config.enable_flash_attention = False

return policy

def postprocess(self):
return self.model

def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "DeepseekModel":
module = self.model
else:
module = self.model.model

layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)

return

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None

if self.model.__class__.__name__ == "DeepseekModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)

return held_layers


class DeepseekModelPolicy(DeepseekPolicy):
def __init__(self) -> None:
super().__init__()

def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls="DeepseekModel",
new_forward=DeepseekPipelineForwards.deepseek_model_forward,
policy=policy,
)
return policy

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []


class DeepseekForCausalLMPolicy(DeepseekPolicy):
def module_policy(self):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
"DeepseekForCausalLM": ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
]
)
}
policy.update(new_item)

if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls="DeepseekForCausalLM",
new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward,
policy=policy,
)

return policy

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
deepseek_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(deepseek_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: deepseek_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
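
Given the raises above, tensor and sequence parallelism are rejected for Deepseek, so the policy currently only shards the routed experts. A hedged configuration sketch matching the unit test below (to be built inside a process that has already called colossalai.launch):

import torch.distributed as dist

from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# DeepseekPolicy raises if tensor or sequence parallelism is enabled, so keep tp_size=1
# and let ep_size spread the routed experts across the process group.
plugin = MoeHybridParallelPlugin(
    precision="bf16",
    tp_size=1,
    pp_size=1,
    ep_size=dist.get_world_size(),
)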
6 changes: 3 additions & 3 deletions colossalai/shardformer/policies/mixtral.py
@@ -192,16 +192,16 @@ def get_held_layers(self) -> List[Module]:
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
mixtral_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
id(mixtral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: llama_model.embed_tokens.weight,
0: mixtral_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
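The rename above (llama_model to mixtral_model) is purely cosmetic; the shared-parameter check in both policies relies on the embedding and lm_head literally sharing one parameter object. A tiny self-contained illustration of that tied-weight condition (assumed toy shapes, not Mixtral code):

import torch.nn as nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # tie the weights, as HF models do for tied embeddings
assert id(embed.weight) == id(lm_head.weight)  # the check used in get_shared_params
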
72 changes: 72 additions & 0 deletions tests/test_moe/test_deepseek_layer.py
@@ -0,0 +1,72 @@
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_deepseek_moe_layer():
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size(),
)

config = AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
num_hidden_layers=1,
n_routed_experts=n_experts,
num_experts_per_tok=top_k,
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
first_k_dense_replace=0,
num_attention_heads=2,
trust_remote_code=True,
)
torch.manual_seed(0)
# get the moe layer in auto model
orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output = orig_model(x)
model = deepcopy(orig_model)
model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
ep_output = model(x)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
orig_loss.backward()
ep_loss = ep_output.mean()
ep_loss.backward()
assert_close(orig_loss, ep_loss)
name_to_p = {n: p for n, p in orig_model.named_parameters()}
for n, ep_p in model.named_parameters():
p = name_to_p[n]
if ep_p.grad is not None:
assert_close(p.grad, ep_p.grad)


def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_deepseek_moe_layer()


# @pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("world_size", [2])
def test_deepseek_moe_layer(world_size: int):
spawn(run_dist, world_size)


if __name__ == "__main__":
test_deepseek_moe_layer(2)