Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/refactor process group #358

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions configs/57B_qwen2_MoE.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
expert parallel (dict):
1. size: int
* if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
Expand All @@ -201,15 +200,14 @@
expert weight parallel (dict):
1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
expert=dict(size=-1, no_tp=False),
expert_weight=dict(size=1, overlap=True, memory_pool=True),
expert_weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
6 changes: 2 additions & 4 deletions configs/8x22B_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
expert parallel (dict):
1. size: int
* if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
Expand All @@ -202,15 +201,14 @@
expert weight parallel (dict):
1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
expert=dict(size=-1, no_tp=False),
expert_weight=dict(size=1, overlap=True, memory_pool=True),
expert_weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
6 changes: 2 additions & 4 deletions configs/8x7B_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
expert parallel (dict):
1. size: int
* if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
Expand All @@ -202,15 +201,14 @@
expert weight parallel (dict):
1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
expert=dict(size=-1, no_tp=False),
expert_weight=dict(size=1, overlap=True, memory_pool=True),
expert_weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
110 changes: 58 additions & 52 deletions internlm/core/context/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@
from internlm.utils.timeout import LLM_NCCL_TIMEOUT
from internlm.utils.utils import TensorParallelMode

from . import process_group_initializer as pgroup_initializer
from .process_group_initializer import ParallelMode
from .process_group_initializer import (
GroupConfig,
ParallelMode,
create_parallel_process_groups,
create_single_process_group,
generate_2d_attn_process_group,
generate_parallel_group_configs,
)
from .random import add_seed, get_seeds, set_mode

# for layernorm
Expand Down Expand Up @@ -633,60 +639,60 @@ def init_parallel_groups(self):

self.check_sanity()

initializer_args = [
rank,
world_size,
self.weight_parallel_size,
self.weight_data_parallel_size,
self.sequence_parallel_size,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size,
self.zero1_parallel_size,
self.nettest_parallel_size,
self.expert_parallel_size,
self.expert_tensor_parallel_size,
self.expert_weight_parallel_size,
self.expert_data_parallel_size,
parallel_config.sequence_2D,
]

# run initialization of different process groups
initializers = []
if "gqa" in parallel_config and parallel_config["gqa"] is True:
initializers.append(pgroup_initializer.Initializer_GQA(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Weight(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args))
parallel_sizes = {
ParallelMode.TENSOR: self.tensor_parallel_size,
ParallelMode.SEQUENCE: self.sequence_parallel_size,
ParallelMode.PIPELINE: self.pipeline_parallel_size,
ParallelMode.DATA: self.data_parallel_size,
ParallelMode.ZERO1: self.zero1_parallel_size,
ParallelMode.WEIGHT: self.weight_parallel_size,
ParallelMode.WEIGHT_DATA: self.weight_data_parallel_size,
ParallelMode.NETTEST: self.nettest_parallel_size,
ParallelMode.EXPERT: self.expert_parallel_size,
ParallelMode.EXPERT_WEIGHT: self.expert_weight_parallel_size,
ParallelMode.EXPERT_TENSOR: self.expert_tensor_parallel_size,
ParallelMode.EXPERT_DATA: self.expert_data_parallel_size,
}

# process groups for parallelism.
enable_moe = self.config.model.get("num_experts", 1) > 1
tp_mode = "mtp" if isinstance(parallel_config.tensor, int) else parallel_config.tensor.get("mode", "mtp")
is_fsdp = False if isinstance(parallel_config.zero1, int) else parallel_config.zero1.get("fsdp", False)
parallel_strategy = "fsdp" if is_fsdp else tp_mode
group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe)
group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False)

# process group for extra gqa tensor parallel.
if (
isinstance(parallel_config["tensor"], dict)
and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name
"num_kv_attention_heads" in self.config.model
and self.config.model.num_kv_attention_heads < self.tensor_parallel_size
):
initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
else:
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
if isinstance(parallel_config["zero1"], dict) and parallel_config["zero1"].get("fsdp", False):
initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
if self.config.model.get("num_experts", 1) > 1:
if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp":
initializers.append(pgroup_initializer.Initializer_Expert_Weight_Data(*initializer_args))
else:
initializers.append(pgroup_initializer.Initializer_Expert_Data(*initializer_args))
group_results.append(
create_single_process_group(
world_size,
rank,
GroupConfig(ParallelMode.GQA, self.tensor_parallel_size // self.num_kv_attention_heads),
)
)

# process group for network test.
group_results.append(
create_single_process_group(
world_size,
rank,
GroupConfig(ParallelMode.NETTEST, self.nettest_parallel_size, allow_partial_group=True),
)
)

# process group for isp 2D attn.
if parallel_config.sequence_2D.get("enable", False) is True:
initializers.append(pgroup_initializer.Initializer_2D_SEQUENCE_PARALLEL(*initializer_args))
group_results.extend(
generate_2d_attn_process_group(world_size, rank, parallel_config.sequence_2D, parallel_sizes)
)

for initializer in initializers:
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):
for args in parallel_setting:
self._register_dist(*args)
else:
self._register_dist(*parallel_setting)
# register process groups
for result in group_results:
self._register_dist(*result)

def is_initialized(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether `parallel_mode` is initialized
Expand Down
Loading
Loading