diff --git a/configs/57B_qwen2_MoE.py b/configs/57B_qwen2_MoE.py
index 0fd67603..abfb0a5b 100644
--- a/configs/57B_qwen2_MoE.py
+++ b/configs/57B_qwen2_MoE.py
@@ -190,7 +190,6 @@ weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 expert parallel (dict):
     1. size: int
         * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
@@ -201,15 +200,14 @@ expert weight parallel (dict):
     1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True, memory_pool=True),
+    weight=dict(size=1, overlap=True),
     expert=dict(size=-1, no_tp=False),
-    expert_weight=dict(size=1, overlap=True, memory_pool=True),
+    expert_weight=dict(size=1, overlap=True),
 )
 
 cudnn_deterministic = False
diff --git a/configs/8x22B_mixtral.py b/configs/8x22B_mixtral.py
index 56206bd4..debd423b 100644
--- a/configs/8x22B_mixtral.py
+++ b/configs/8x22B_mixtral.py
@@ -191,7 +191,6 @@ weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 expert parallel (dict):
     1. size: int
         * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
@@ -202,15 +201,14 @@ expert weight parallel (dict):
     1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True, memory_pool=True),
+    weight=dict(size=1, overlap=True),
     expert=dict(size=-1, no_tp=False),
-    expert_weight=dict(size=1, overlap=True, memory_pool=True),
+    expert_weight=dict(size=1, overlap=True),
 )
 
 cudnn_deterministic = False
diff --git a/configs/8x7B_mixtral.py b/configs/8x7B_mixtral.py
index f589c967..322342ea 100644
--- a/configs/8x7B_mixtral.py
+++ b/configs/8x7B_mixtral.py
@@ -191,7 +191,6 @@ weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 expert parallel (dict):
     1. size: int
         * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
@@ -202,15 +201,14 @@ expert weight parallel (dict):
     1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
""" parallel = dict( zero1=dict(size=-1, fsdp=False), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True, memory_pool=True), + weight=dict(size=1, overlap=True), expert=dict(size=-1, no_tp=False), - expert_weight=dict(size=1, overlap=True, memory_pool=True), + expert_weight=dict(size=1, overlap=True), ) cudnn_deterministic = False diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 989b1c00..8d74e3c6 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -21,8 +21,14 @@ from internlm.utils.timeout import LLM_NCCL_TIMEOUT from internlm.utils.utils import TensorParallelMode -from . import process_group_initializer as pgroup_initializer -from .process_group_initializer import ParallelMode +from .process_group_initializer import ( + GroupConfig, + ParallelMode, + create_parallel_process_groups, + create_single_process_group, + generate_2d_attn_process_group, + generate_parallel_group_configs, +) from .random import add_seed, get_seeds, set_mode # for layernorm @@ -633,60 +639,60 @@ def init_parallel_groups(self): self.check_sanity() - initializer_args = [ - rank, - world_size, - self.weight_parallel_size, - self.weight_data_parallel_size, - self.sequence_parallel_size, - self.data_parallel_size, - self.pipeline_parallel_size, - self.tensor_parallel_size, - self.zero1_parallel_size, - self.nettest_parallel_size, - self.expert_parallel_size, - self.expert_tensor_parallel_size, - self.expert_weight_parallel_size, - self.expert_data_parallel_size, - parallel_config.sequence_2D, - ] - - # run initialization of different process groups - initializers = [] - if "gqa" in parallel_config and parallel_config["gqa"] is True: - initializers.append(pgroup_initializer.Initializer_GQA(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Weight(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Data(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args)) + parallel_sizes = { + ParallelMode.TENSOR: self.tensor_parallel_size, + ParallelMode.SEQUENCE: self.sequence_parallel_size, + ParallelMode.PIPELINE: self.pipeline_parallel_size, + ParallelMode.DATA: self.data_parallel_size, + ParallelMode.ZERO1: self.zero1_parallel_size, + ParallelMode.WEIGHT: self.weight_parallel_size, + ParallelMode.WEIGHT_DATA: self.weight_data_parallel_size, + ParallelMode.NETTEST: self.nettest_parallel_size, + ParallelMode.EXPERT: self.expert_parallel_size, + ParallelMode.EXPERT_WEIGHT: self.expert_weight_parallel_size, + ParallelMode.EXPERT_TENSOR: self.expert_tensor_parallel_size, + ParallelMode.EXPERT_DATA: self.expert_data_parallel_size, + } + + # process groups for parallelism. 
+        enable_moe = self.config.model.get("num_experts", 1) > 1
+        tp_mode = "mtp" if isinstance(parallel_config.tensor, int) else parallel_config.tensor.get("mode", "mtp")
+        is_fsdp = False if isinstance(parallel_config.zero1, int) else parallel_config.zero1.get("fsdp", False)
+        parallel_strategy = "fsdp" if is_fsdp else tp_mode
+        group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe)
+        group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False)
+
+        # process group for extra gqa tensor parallel.
         if (
-            isinstance(parallel_config["tensor"], dict)
-            and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name
+            "num_kv_attention_heads" in self.config.model
+            and self.config.model.num_kv_attention_heads < self.tensor_parallel_size
         ):
-            initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
-        else:
-            initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
-        if isinstance(parallel_config["zero1"], dict) and parallel_config["zero1"].get("fsdp", False):
-            initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
-        if self.pipeline_parallel_size > 1:
-            initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
-        if self.config.model.get("num_experts", 1) > 1:
-            if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp":
-                initializers.append(pgroup_initializer.Initializer_Expert_Weight_Data(*initializer_args))
-            else:
-                initializers.append(pgroup_initializer.Initializer_Expert_Data(*initializer_args))
+            group_results.append(
+                create_single_process_group(
+                    world_size,
+                    rank,
+                    GroupConfig(ParallelMode.GQA, self.tensor_parallel_size // self.num_kv_attention_heads),
+                )
+            )
+
+        # process group for network test.
+        group_results.append(
+            create_single_process_group(
+                world_size,
+                rank,
+                GroupConfig(ParallelMode.NETTEST, self.nettest_parallel_size, allow_partial_group=True),
+            )
+        )
+
+        # process group for isp 2D attn.
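+        # 2D sequence parallel (head_size x context_size) builds its HEAD/CONTEXT and
+        # window process groups separately from the strategy-driven configs above.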
         if parallel_config.sequence_2D.get("enable", False) is True:
-            initializers.append(pgroup_initializer.Initializer_2D_SEQUENCE_PARALLEL(*initializer_args))
+            group_results.extend(
+                generate_2d_attn_process_group(world_size, rank, parallel_config.sequence_2D, parallel_sizes)
+            )
 
-        for initializer in initializers:
-            parallel_setting = initializer.init_dist_group()
-            if isinstance(parallel_setting, list):
-                for args in parallel_setting:
-                    self._register_dist(*args)
-            else:
-                self._register_dist(*parallel_setting)
+        # register process groups
+        for result in group_results:
+            self._register_dist(*result)
 
     def is_initialized(self, parallel_mode: ParallelMode):
         """Returns a boolean value indicating whether `parallel_mode` is initialized
diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index fbc3e07a..5313ad92 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -6,11 +6,16 @@
 import math
 from abc import ABC, abstractmethod
 from enum import Enum
+from functools import reduce
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch.distributed as dist
 
+from internlm.utils.logger import get_logger
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
 
+logger = get_logger(__file__)
+
 
 # parallel modes
 class ParallelMode(Enum):
@@ -81,6 +86,349 @@ class ParallelMode(Enum):
 
     DKV_INTRA_WINDOW = "dkv_intra_window"
 
 
+class GroupConfig:
+    """Config used to initialize a process group."""
+
+    def __init__(
+        self,
+        mode: ParallelMode,
+        size: int,
+        anonymous: bool = False,
+        allow_partial_group: bool = False,
+        subgroups: Optional[List["GroupConfig"]] = None,
+    ) -> None:
+        self.mode = mode
+        self.size = size
+        self.anonymous = anonymous
+        self.allow_partial_group = allow_partial_group
+        self.subgroups = subgroups if subgroups is not None else []
+
+        self._early_subgroup_checking()
+
+    def _early_subgroup_checking(self) -> None:
+        if len(self.subgroups) == 0:
+            return
+
+        group_target_size = reduce(lambda x, y: x * y, [_g.size for _g in self.subgroups])
+        assert group_target_size <= self.size, "total subgroup size should not exceed the parent group size"
+
+
+def init_cpu_group(group, ranks, use_cpu: bool = False):
+    if use_cpu:
+        cpu_group = (
+            dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) if dist.get_backend() != "gloo" else group
+        )
+    else:
+        cpu_group = None
+
+    return cpu_group
+
+
+def get_group_ranks(
+    global_ranks_or_sizes: Union[int, List[int]],
+    cur_group_size: int,
+    pre_group_size: int,
+    allow_partial_group: bool = False,
+):
+    group_ranks = []
+
+    if isinstance(global_ranks_or_sizes, list):
+        global_size = len(global_ranks_or_sizes)
+        global_ranks = global_ranks_or_sizes
+    else:
+        global_size = global_ranks_or_sizes
+        global_ranks = None
+
+    real_global_size = global_size
+
+    if allow_partial_group:
+        global_size = math.ceil(global_size / cur_group_size) * cur_group_size
+
+    assert global_size % cur_group_size == 0, "global size must be divisible by the group size"
+
+    def _get_local_starts():
+        for i in range(0, global_size, cur_group_size * pre_group_size):
+            for j in range(pre_group_size):
+                yield 0 + i + j
+
+    for start in _get_local_starts():
+        ranks = [
+            start + i * pre_group_size for i in range(cur_group_size) if start + i * pre_group_size < real_global_size
+        ]
+        if global_ranks is not None:
+            ranks = [global_ranks[_idx] for _idx in ranks]
+
+        group_ranks.append(ranks)
+
+    assert len(group_ranks) == global_size // cur_group_size, f"{group_ranks}, {global_size}, {cur_group_size}"
+
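+    # pre_group_size is the rank stride inside a group, e.g.
+    # get_group_ranks(8, cur_group_size=2, pre_group_size=2) -> [[0, 2], [1, 3], [4, 6], [5, 7]].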
+    return group_ranks
+
+
+def _create_parallel_process_groups(
+    global_ranks_or_sizes: int,
+    self_rank: int,
+    pre_group_size: int,
+    group_configs: List[GroupConfig],
+    with_cpu_group: bool = False,
+):
+    group_results = []
+
+    for group in group_configs:
+        if group.anonymous is True:
+            pre_group_size = pre_group_size * group.size
+            continue
+
+        group_ranks, accelerator_group = None, None
+        all_group_ranks = get_group_ranks(global_ranks_or_sizes, group.size, pre_group_size, group.allow_partial_group)
+
+        for idx, ranks in enumerate(all_group_ranks):
+            _pg = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
+            if self_rank in ranks:
+                group_ranks, accelerator_group = all_group_ranks[idx], _pg
+            else:
+                dist.destroy_process_group(_pg)
+
+        if group_ranks is None:
+            pre_group_size = pre_group_size * group.size
+            continue
+
+        cpu_group = init_cpu_group(accelerator_group, group_ranks, with_cpu_group)
+
+        group_results.append(
+            (group_ranks.index(self_rank), len(group_ranks), accelerator_group, cpu_group, group_ranks, group.mode)
+        )
+
+        if len(group.subgroups) > 0:
+            subgroup_results = _create_parallel_process_groups(
+                global_ranks_or_sizes, self_rank, pre_group_size, group.subgroups, with_cpu_group
+            )
+            group_results.extend(subgroup_results)
+
+        pre_group_size = pre_group_size * group.size
+
+    return group_results
+
+
+def create_parallel_process_groups(
+    world_size: int, self_rank: int, group_configs: List[Tuple[str, List[GroupConfig]]], with_cpu_group: bool = False
+):
+    group_results = []
+    already_allocated_group = {}
+
+    def _checker(order: str, result: Tuple[Any]) -> bool:
+        parallel_mode = result[-1]
+
+        if parallel_mode not in already_allocated_group:
+            already_allocated_group[parallel_mode] = (order, result)
+            return True
+        else:
+            # check that both configs assign the same ranks to this parallel mode.
+            ranks_in_group_idx = -2
+            pre_order, pre_allocate_result = already_allocated_group[parallel_mode]
+
+            error_msg = (
+                f"The ranks allocated for {parallel_mode} are inconsistent in config {pre_order} and {order}: "
+                + f"{pre_allocate_result[ranks_in_group_idx]} != {result[ranks_in_group_idx]}"
+            )
+            assert pre_allocate_result[ranks_in_group_idx] == result[ranks_in_group_idx], error_msg
+
+            # release the duplicated process group.
+            dist.destroy_process_group(result[2])  # accelerator_group
+            if with_cpu_group:
+                dist.destroy_process_group(result[3])  # cpu_group
+
+            return False
+
+    for order, group_config in group_configs:
+        pre_group_size = 1
+
+        results = _create_parallel_process_groups(
+            world_size,
+            self_rank,
+            pre_group_size,
+            group_config,
+            with_cpu_group,
+        )
+
+        for result in results:
+            if _checker(order, result) is True:
+                group_results.append(result)
+
+    return group_results
+
+
+def create_single_process_group(
+    world_size: int, self_rank: int, config: GroupConfig, with_cpu_group: bool = False, pre_anonymous_size: int = 1
+):
+    pre_group_size = pre_anonymous_size
+
+    return _create_parallel_process_groups(
+        world_size,
+        self_rank,
+        pre_group_size,
+        [config],
+        with_cpu_group,
+    )[0]
+
+
+MTP_GROUP_ORDER = [ParallelMode.TENSOR, ParallelMode.DATA, ParallelMode.PIPELINE]
+MTP_MOE_GROUP_ORDER = [ParallelMode.EXPERT_TENSOR, ParallelMode.EXPERT, ParallelMode.EXPERT_DATA, ParallelMode.PIPELINE]
+ISP_SP_GROUP_ORDER = [ParallelMode.TENSOR, ParallelMode.DATA, ParallelMode.PIPELINE]
+ISP_WP_GROUP_ORDER = [ParallelMode.WEIGHT, ParallelMode.WEIGHT_DATA, ParallelMode.PIPELINE]
+ISP_MOE_GROUP_ORDER = [ParallelMode.EXPERT_WEIGHT, ParallelMode.EXPERT, ParallelMode.EXPERT_DATA, ParallelMode.PIPELINE]
+FSDP_ORDER = [ParallelMode.DATA]  # TODO: should we support moe for fsdp?
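+
+# Each *_GROUP_ORDER lists parallel modes from the innermost dimension (consecutive ranks)
+# outwards: the accumulated product of the preceding sizes becomes the rank stride of the
+# next mode when _create_parallel_process_groups expands the configs.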
+
+SUBGROUP_SPEC = {
+    "mtp": {
+        ParallelMode.DATA: [ParallelMode.ZERO1],
+    },
+    "isp": {
+        ParallelMode.WEIGHT_DATA: [ParallelMode.ZERO1],
+    },  # TODO: WEIGHT_ZERO1
+    "fsdp": {
+        ParallelMode.DATA: [ParallelMode.ZERO3_DP, ParallelMode.ZERO1],
+    },
+}
+
+
+def generate_parallel_group_configs(
+    parallel_strategy: str, parallel_sizes: Dict[ParallelMode, int], enable_moe: bool = False
+) -> List[Tuple[str, List[GroupConfig]]]:
+
+    group_configs = []
+    subgroup_spec = SUBGROUP_SPEC.get(parallel_strategy, SUBGROUP_SPEC["mtp"])
+
+    def _recurse_generater(order: List[ParallelMode]):
+        config = []
+
+        for mode in order:
+            # disable pp process group for compatibility when pp size is 1.
+            anonymous = mode is ParallelMode.PIPELINE and parallel_sizes[mode] == 1
+
+            if mode not in subgroup_spec:
+                config.append(GroupConfig(mode, parallel_sizes[mode], anonymous))
+            else:
+                config.append(
+                    GroupConfig(
+                        mode, parallel_sizes[mode], anonymous, subgroups=_recurse_generater(subgroup_spec[mode])
+                    )
+                )
+
+        return config
+
+    if parallel_strategy == "isp":
+        # sp configs
+        group_configs.append(("isp-sp", _recurse_generater(ISP_SP_GROUP_ORDER)))
+        # wp configs
+        group_configs.append(("isp-wp", _recurse_generater(ISP_WP_GROUP_ORDER)))
+        if enable_moe:
+            group_configs.append(("isp-moe", _recurse_generater(ISP_MOE_GROUP_ORDER)))
+    elif parallel_strategy == "fsdp":
+        group_configs.append(("fsdp", _recurse_generater(FSDP_ORDER)))
+    else:  # 3d parallel: mtp, msp, fsp
+        group_configs.append(("3d", _recurse_generater(MTP_GROUP_ORDER)))
+        if enable_moe:
+            group_configs.append(("3d-moe", _recurse_generater(MTP_MOE_GROUP_ORDER)))
+
+    return group_configs
+
+
+def generate_2d_attn_process_group(
+    world_size: int,
+    self_rank: int,
+    config: Dict[str, Any],
+    parallel_sizes: Dict[ParallelMode, int],
+    with_cpu_group: bool = False,
+):
+
+    assert config.context_size * config.head_size == parallel_sizes[ParallelMode.SEQUENCE]
+    assert world_size % parallel_sizes[ParallelMode.SEQUENCE] == 0
+
+    if config.window_size >= 8 or config.window_size == config.context_size:
+        logger.warning("interleaved is forced False when window size >= 8 or equals context size.")
+        config.interleaved = False
+
+    if config.device_placement_strategy.head_first and config.head_size > 1:
+        logger.warning("interleaved is forced False when head_first is True and head size > 1.")
+        config.interleaved = False
+
+    group_results = []
+    sp_pre_group_size = 1
+    for parallel_mode in ISP_SP_GROUP_ORDER:
+        if parallel_mode is ParallelMode.TENSOR:  # assert sp is tp.
+            break
+        else:
+            sp_pre_group_size *= parallel_sizes[parallel_mode]
+
+    # head and context process groups.
+    if config.device_placement_strategy.head_first:
+        group_configs = [
+            GroupConfig(ParallelMode.HEAD, config.head_size),
+            GroupConfig(ParallelMode.CONTEXT, config.context_size),
+        ]
+        context_results_index = 1
+    else:
+        group_configs = [
+            GroupConfig(ParallelMode.CONTEXT, config.context_size),
+            GroupConfig(ParallelMode.HEAD, config.head_size),
+        ]
+        context_results_index = 0
+
+    group_results.extend(
+        _create_parallel_process_groups(world_size, self_rank, sp_pre_group_size, group_configs, with_cpu_group)
+    )
+
+    # window process groups.
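+    # Each CONTEXT group is split into window_num rings of window_size ranks: INTRA_WINDOW
+    # groups hold the ranks of one window, INTER_WINDOW groups hold the ranks at the same
+    # offset across windows, and the DKV_* groups repeat the same layout for dkv traffic.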
+    window_num = config.context_size // config.window_size
+    cp_pre_group_size = 1 if context_results_index == 0 else config.head_size
+    every_context_ranks = get_group_ranks(world_size, config.context_size, cp_pre_group_size)
+
+    def _gen_window_process_groups(context_ranks: List[int]):
+        if not config.device_placement_strategy.interleaved:
+            window_ranks = context_ranks
+        else:
+            _indexes = [
+                j * 2 + i * config.window_size if i % 2 == 0 else j * 2 + 1 + (i - 1) * config.window_size
+                for i in range(window_num)
+                for j in range(config.window_size)
+            ]
+            window_ranks = [context_ranks[_i] for _i in _indexes]
+
+        group_results.extend(
+            _create_parallel_process_groups(
+                window_ranks,
+                self_rank,
+                1,
+                [
+                    GroupConfig(ParallelMode.INTRA_WINDOW, config.window_size),
+                    GroupConfig(ParallelMode.INTER_WINDOW, window_num),
+                ],
+                with_cpu_group,
+            )
+        )
+        group_results.extend(
+            _create_parallel_process_groups(
+                window_ranks,
+                self_rank,
+                1,
+                [
+                    GroupConfig(ParallelMode.DKV_INTRA_WINDOW, config.window_size),
+                    GroupConfig(ParallelMode.DKV_INTER_WINDOW, window_num),
+                ],
+                with_cpu_group,
+            )
+        )
+
+    for context_ranks in every_context_ranks:
+        _gen_window_process_groups(context_ranks)
+
+    return group_results
+
+
 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -1124,11 +1472,10 @@ class Initializer_GQA(ProcessGroupInitializer):
     """
 
     def __init__(self, *args, **kwargs):
+        self.num_attention_heads = kwargs.pop("num_attention_heads")
+        self.num_kv_attention_heads = kwargs.pop("num_kv_attention_heads")
         super().__init__(*args, **kwargs)
-        # TODO: should adapt to general case
-        self.num_kv_attention_heads = 8
-        self.NUM_ATTENTION_HEAD = 32
-        self.kv_head_repeats_num = self.NUM_ATTENTION_HEAD // self.num_kv_attention_heads
+        self.kv_head_repeats_num = self.tensor_parallel_size // self.num_kv_attention_heads
         self.num_kv_group_per_tp = self.num_kv_attention_heads
         self.num_kv_groups = self.num_kv_group_per_tp * self.data_parallel_size
 
@@ -1159,7 +1506,6 @@ def init_dist_group(self, use_cpu: bool = False):
         group_world_size = None
         mode = ParallelMode.GQA
 
-        # TODO: consider PP
         for i in range(self.data_parallel_size):
            for j in range(self.num_kv_group_per_tp):
                ranks = [