diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py
index 64da9539..e99bbfc7 100644
--- a/internlm/data/build_dataloader.py
+++ b/internlm/data/build_dataloader.py
@@ -2,12 +2,13 @@
 import subprocess
 from functools import partial
 
+import torch
 import torch.distributed as dist
 from torch.utils.data import ConcatDataset, DataLoader
 
+from internlm.accelerator.abstract_accelerator import get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.data.megatron.batch_sampler import MegatronBatchSampler
 from internlm.data.megatron.collaters import megatron_collate_fn
 from internlm.data.megatron.dataset import build_megatron_dataset
 from internlm.data.mocked.batch_sampler import MockedSequentialBatchSampler
@@ -41,8 +42,8 @@ from internlm.utils.logger import get_logger
 from internlm.utils.utils import DataType
 
-# global llm logger
 logger = get_logger(__file__)
+internlm_accelerator = get_accelerator()
 
 
 def get_tokenized_train_loader_items(data_cfg):
@@ -156,10 +157,14 @@ def get_streaming_train_loader_items(data_cfg):
 
 
 def get_megatron_train_loader_items(data_cfg):
+    assert data_cfg.get(
+        "pack_sample_into_one", False
+    ), "megatron dataloader currently only supports pack_sample_into_one=True"
     try:
         from internlm.data.megatron import helpers  # noqa  # pylint: disable=W0611
     except ImportError:
-        if gpc.is_rank_for_log():
+        # Compile the dynamic library on demand
+        if gpc.get_global_rank() % internlm_accelerator.device_count() == 0:
             subprocess.run(  # noqa  # pylint: disable=W1510
                 [
                     "g++",
@@ -173,23 +178,28 @@ def get_megatron_train_loader_items(data_cfg):
                     "internlm/data/megatron/helpers.cpp",
                     "-o",
                     "internlm/data/megatron/helpers.so",
-                ]
+                ],
             )
+        torch.distributed.barrier()
+
+    # NOTICE: Currently we only support a single megatron dataset, i.e., a single .bin and .idx pair.
+    # The megatron dataset (.bin and .idx) should be generated by Megatron-LM tools/preprocess_data.py
+    # https://github.com/NVIDIA/Megatron-LM/blob/main/tools/preprocess_data.py
     train_ds = build_megatron_dataset(
         data_prefix=data_cfg.train_folder,
-        data_impl=data_cfg.get("data_impl", "infer"),
-        splits_string="1.0, 0.0, 0.0",
-        train_valid_test_num_samples=[9600000, 0, 0],
         seq_len=data_cfg.seq_len,
         seed=data_cfg.get("seed", 1024),
-        skip_warmup=True,
     )
 
-    train_sampler = MegatronBatchSampler(
-        total_samples=len(train_ds),
-        consumed_samples=0,
+    train_sampler = StaticBatchSampler(
+        train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
         batch_size=data_cfg.micro_num * data_cfg.micro_bsz,
+        rampup_batch_size=data_cfg.rampup_batch_size,
+        micro_bsz=data_cfg.micro_bsz,
+        seed=data_cfg.get("seed", 1024),
         drop_last=True,
+        data_rank=gpc.get_local_rank(ParallelMode.DATA),
+        data_world_size=gpc.get_world_size(ParallelMode.DATA),
     )
 
     train_collate_fn = partial(
@@ -203,14 +213,18 @@ def get_mock_train_loader_items(data_cfg):
     assert data_cfg.get(
         "pack_sample_into_one", False
    ), "mocked dataloader currently only supports pack_sample_into_one=True"
+
     train_ds = MockedDataset(
         train_folder=data_cfg.train_folder,
         micro_bsz=data_cfg.micro_bsz,
         micro_num=data_cfg.micro_num,
         seq_len=data_cfg.seq_len,
     )
+
     train_sampler = MockedSequentialBatchSampler(train_ds, data_cfg.micro_num)
+
     train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.seq_len * data_cfg.micro_bsz)
+
     return train_ds, train_sampler, train_collate_fn
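For reference, a data config that selects this new megatron path might look roughly like the sketch below. The "type" selector and the exact field layout are assumptions inferred from the DataType import and the data_cfg accesses in the hunk above; only pack_sample_into_one=True is actually required by the new assert.

    # Hypothetical InternEvo data config for the megatron dataloader path (illustrative values).
    data = dict(
        type="megatron",                    # assumed selector, cf. the DataType import
        train_folder="/path/to/corpus_text_document",  # prefix of the Megatron .bin/.idx pair
        pack_sample_into_one=True,          # required by the new assert
        seq_len=4096,
        micro_num=4,
        micro_bsz=2,
        rampup_batch_size=None,             # forwarded to StaticBatchSampler
        seed=1024,
        total_steps=20000,                  # used to size num_samples in build_megatron_dataset
    )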
diff --git a/internlm/data/megatron/__init__.py b/internlm/data/megatron/__init__.py
index 5e447596..5405f6f8 100644
--- a/internlm/data/megatron/__init__.py
+++ b/internlm/data/megatron/__init__.py
@@ -1,9 +1,7 @@
-from .batch_sampler import MegatronBatchSampler
 from .collaters import megatron_collate_fn
 from .dataset import build_megatron_dataset
 
 __all__ = [
-    "MegatronBatchSampler",
     "build_megatron_dataset",
     "megatron_collate_fn",
 ]
diff --git a/internlm/data/megatron/batch_sampler.py b/internlm/data/megatron/batch_sampler.py
deleted file mode 100644
index 049cfcf7..00000000
--- a/internlm/data/megatron/batch_sampler.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import copy
-import math
-
-from internlm.core.context import ParallelMode
-from internlm.core.context import global_context as gpc
-
-
-class MegatronBatchSampler:
-    """
-    MegatronBatchSampler
-    """
-
-    def __init__(self, total_samples, consumed_samples, batch_size, drop_last=True):
-        # Keep a copy of input params for later use.
-        self.total_samples = total_samples
-        self.consumed_samples = consumed_samples
-        self.batch_size = batch_size
-        self.drop_last = drop_last
-
-        self.dp_rank = gpc.get_local_rank(ParallelMode.DATA)
-        self.dp_size = gpc.get_world_size(ParallelMode.DATA)
-
-        # Sanity checks.
-        assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
-        assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format(
-            self.consumed_samples, self.total_samples
-        )
-        assert self.batch_size > 0
-        assert self.dp_size > 0
-        assert self.dp_rank < self.dp_size, "dp_rank should be smaller than dp_size: {}, " "{}".format(
-            self.dp_rank, self.dp_size
-        )
-
-    def __len__(self):
-        if self.drop_last and self.total_samples % self.dp_size != 0:
-            return math.ceil(self.total_samples - self.dp_size) / self.dp_size
-        else:
-            return math.ceil(self.total_samples / self.dp_size)
-
-    def get_start_end_idx(self):
-        start_idx = self.dp_rank * self.batch_size
-        end_idx = start_idx + self.batch_size
-        return start_idx, end_idx
-
-    def __iter__(self):
-        batch = []
-        # Last batch will be dropped if drop_last is not set False
-        for idx in range(self.consumed_samples, self.total_samples):
-            batch.append(idx)
-            if len(batch) == self.batch_size * self.dp_size:
-                start_idx, end_idx = self.get_start_end_idx()
-                yield batch[start_idx:end_idx]
-                batch = []
-
-        # Check the last partial batch and see drop_last is set
-        if len(batch) > 0 and not self.drop_last:
-            start_idx, end_idx = self.get_start_end_idx()
-            yield batch[start_idx:end_idx]
-
-    # TODO: implement copy method that compatible with InternEvo trainstate
-    def copy(self):
-        return copy.deepcopy(self)
diff --git a/internlm/data/megatron/collaters.py b/internlm/data/megatron/collaters.py
index 252bc289..c6ffc80e 100644
--- a/internlm/data/megatron/collaters.py
+++ b/internlm/data/megatron/collaters.py
@@ -2,48 +2,36 @@
 
 
 def megatron_collate_fn(batch, micro_num, micro_bsz, seq_len):
-
-    input_ids_result = [[] for _ in range(micro_num)]
-    labels_result = [[] for _ in range(micro_num)]
-    cu_seqlens = []
+    input_ids_list = [[] for _ in range(micro_num)]
+    labels_list = [[] for _ in range(micro_num)]
     cu_seqlens_list = []
-    indexes = []
     indexes_list = []
 
-    for i, item in enumerate(batch):
-        assert i < micro_num * micro_bsz
-        seq_len_list = item["text"]
-        assert len(seq_len_list) == seq_len + 1
-
-        micro_bsz_index = i % micro_bsz
-        micro_num_index = i // micro_bsz
-
-        input_ids_result[micro_num_index].append(seq_len_list[:-1])
-        labels_result[micro_num_index].append(seq_len_list[1:])
-
-        cu_seqlens.append(seq_len * micro_bsz_index)
-        indexes = indexes + list(range(seq_len))
+    assert len(batch) == micro_bsz * micro_num
+    for idx, b in enumerate(batch):
+        tokens = b["text"]
+        # Megatron preprocessed samples carry (seq_len + 1) tokens,
+        # so the first seq_len tokens are the input and the last seq_len tokens are the shifted labels.
+        assert len(tokens) == seq_len + 1
+        micro_bsz_index = idx % micro_bsz
+        micro_num_index = idx // micro_bsz
+        input_ids_list[micro_num_index].append(tokens[:-1])
+        labels_list[micro_num_index].append(tokens[1:])
         if micro_bsz_index == micro_bsz - 1:
-            input_ids_result[micro_num_index] = torch.cat(
-                [torch.from_numpy(arr).long() for arr in input_ids_result[micro_num_index]], dim=0
+            # Megatron samples are numpy arrays, so convert them to tensors and concatenate within the micro batch.
+            input_ids_list[micro_num_index] = torch.cat(
+                [torch.from_numpy(arr) for arr in input_ids_list[micro_num_index]], dim=0
             )
-            labels_result[micro_num_index] = torch.cat(
-                [torch.from_numpy(arr).long() for arr in labels_result[micro_num_index]], dim=0
+            labels_list[micro_num_index] = torch.cat(
+                [torch.from_numpy(arr) for arr in labels_list[micro_num_index]], dim=0
             )
-            cu_seqlens.append(seq_len * micro_bsz)
-            cu_seqlens_list.append(torch.IntTensor(cu_seqlens))
-            cu_seqlens = []
-            indexes_list.append(torch.IntTensor(indexes))
-            indexes = []
-
-    input_ids = torch.stack(input_ids_result)
-    labels = torch.stack(labels_result)
-    indexes = torch.stack(indexes_list)
+            cu_seqlens_list.append(torch.IntTensor([i * seq_len for i in range(micro_bsz + 1)]))
+            indexes_list.append(torch.IntTensor(list(range(seq_len)) * micro_bsz))
 
     return {
-        "input_ids": input_ids,
+        "input_ids": torch.stack(input_ids_list),
         "cu_seqlens": cu_seqlens_list,
-        "indexes": indexes,
+        "indexes": torch.stack(indexes_list),
         "type_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.int64),
-    }, labels
+    }, torch.stack(labels_list)
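Assuming the patched internlm package is importable, a small sanity sketch of the batch layout produced by the rewritten megatron_collate_fn; the toy sizes are made up for illustration.

    import numpy as np

    from internlm.data.megatron.collaters import megatron_collate_fn

    micro_num, micro_bsz, seq_len = 2, 2, 8
    # Each megatron sample carries seq_len + 1 tokens so inputs and labels can be shifted by one.
    batch = [{"text": np.arange(seq_len + 1, dtype=np.int64)} for _ in range(micro_num * micro_bsz)]

    data, labels = megatron_collate_fn(batch, micro_num=micro_num, micro_bsz=micro_bsz, seq_len=seq_len)
    assert data["input_ids"].shape == (micro_num, micro_bsz * seq_len)
    assert labels.shape == (micro_num, micro_bsz * seq_len)
    # cu_seqlens marks per-sample boundaries inside each packed micro batch: [0, 8, 16]
    assert data["cu_seqlens"][0].tolist() == [i * seq_len for i in range(micro_bsz + 1)]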
diff --git a/internlm/data/megatron/dataset.py b/internlm/data/megatron/dataset.py
index 7dba0294..88f4697b 100644
--- a/internlm/data/megatron/dataset.py
+++ b/internlm/data/megatron/dataset.py
@@ -1,5 +1,6 @@
 # adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/gpt_dataset.py
 # adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/indexed_dataset.py
+
 import hashlib
 import os
 import struct
@@ -764,82 +765,25 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
     return indexed_dataset
 
 
-def get_train_valid_test_split_(splits_string, size):
-    """Get dataset splits from comma or '/' separated string list."""
-
-    splits = []
-    if splits_string.find(",") != -1:
-        splits = [float(s) for s in splits_string.split(",")]
-    elif splits_string.find("/") != -1:
-        splits = [float(s) for s in splits_string.split("/")]
-    else:
-        splits = [float(splits_string)]
-    while len(splits) < 3:
-        splits.append(0.0)
-    splits = splits[:3]
-    splits_sum = sum(splits)
-    assert splits_sum > 0.0
-    splits = [split / splits_sum for split in splits]
-    splits_index = [0]
-    for index, split in enumerate(splits):
-        splits_index.append(splits_index[index] + int(round(split * float(size))))
-    diff = splits_index[-1] - size
-    for index in range(1, len(splits_index)):
-        splits_index[index] -= diff
-    assert len(splits_index) == 4
-    assert splits_index[-1] == size
-    return splits_index
-
-
 def build_megatron_dataset(
     data_prefix,
-    data_impl,
-    splits_string,
-    train_valid_test_num_samples,
     seq_len,
     seed,
-    skip_warmup,
-    return_doc_ids=False,
-    *,
-    data_cache_path=None,
 ):
-    # Indexed dataset.
-    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
-
-    total_num_of_documents = indexed_dataset.sizes.shape[0]
-    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
-
-    # Print stats about the splits.
-    print_rank_0(" > dataset split:")
-
-    def print_split_stats(index, name):
-        print_rank_0("    {}:".format(name))
-        print_rank_0(
-            "     document indices in [{}, {}) total of {} "
-            "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
-        )
-
-    print_split_stats(0, "train")
-
-    def build_dataset(index, name):
-        dataset = None
-        if splits[index + 1] > splits[index]:
-            documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
-            dataset = GPTDataset(
-                name,
-                data_prefix,
-                documents,
-                indexed_dataset,
-                splits_string,
-                train_valid_test_num_samples[index],
-                seq_len,
-                seed,
-                return_doc_ids,
-                data_cache_path=data_cache_path,
-            )
-        return dataset
-
-    train_dataset = build_dataset(0, "train")
-
-    return train_dataset
+    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl="infer", skip_warmup=True)
+
+    # GPT dataset.
+    return GPTDataset(
+        name="train",
+        data_prefix=data_prefix,
+        documents=np.arange(start=0, stop=indexed_dataset.sizes.shape[0], step=1, dtype=np.int32),
+        indexed_dataset=indexed_dataset,
+        splits_string="1.0, 0.0, 0.0",  # proportion of dataset for train/valid/test, we set 1.0 for train only
+        num_samples=gpc.config.data.micro_bsz
+        * gpc.config.data.micro_num
+        * gpc.get_world_size(ParallelMode.DATA)
+        * gpc.config.data.total_steps,  # total number of train samples
+        seq_length=seq_len,
+        seed=seed,
+    )
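The num_samples passed to GPTDataset above is sized so that the index covers one global batch (micro_bsz * micro_num samples per data-parallel rank) for every training step; a worked example with illustrative numbers, none of which come from this patch:

    # Rough sanity check of the num_samples sizing (illustrative values only).
    micro_bsz, micro_num, dp_world_size, total_steps = 2, 4, 8, 20000
    samples_per_step = micro_bsz * micro_num * dp_world_size  # one global batch across DP ranks
    num_samples = samples_per_step * total_steps
    print(samples_per_step, num_samples)  # 64 1280000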
+ """ self.train_ds = train_ds self.micro_num = micro_num + self.batch_count = 0 + self.num_consumed_samples_in_epoch = 0 + def __iter__(self): num_samples = len(self.train_ds) - for start in range(0, num_samples, self.micro_num): + while self.num_consumed_samples_in_epoch < num_samples: + start = self.num_consumed_samples_in_epoch end = min(start + self.micro_num, num_samples) + self.batch_count += 1 + self.num_consumed_samples_in_epoch += end - start yield list(range(start, end)) def __len__(self): return (len(self.train_ds) + self.micro_num - 1) // self.micro_num - # TODO: implement copy method that compatible with InternEvo trainstate + def state_dict(self): + states = { + "batch_count": self.batch_count, + "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch, + } + return states + + def load_state_dict(self, states): + self.batch_count = states["batch_count"] + self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"] + def copy(self): - return copy.deepcopy(self) + copy_sampler = MockedSequentialBatchSampler(self.train_ds, self.micro_num) + copy_sampler.load_state_dict(self.state_dict()) + return copy_sampler diff --git a/internlm/data/mocked/dataset.py b/internlm/data/mocked/dataset.py index 0d0e488e..88020a78 100644 --- a/internlm/data/mocked/dataset.py +++ b/internlm/data/mocked/dataset.py @@ -108,7 +108,7 @@ def __init__(self, train_folder: str, micro_bsz: int, micro_num: int, seq_len: i ] # simple sanity check: ensure loaded per-step data is equivalent to saved per-step data - self.sanity_check(tokens_list, labels_list) + self._sanity_check(tokens_list, labels_list) def __len__(self) -> int: return len(self.db_tokens) @@ -122,7 +122,7 @@ def __getitem__(self, idx: int) -> Dict[str, List[int]]: "type_ids": [0] * (self.micro_bsz * self.seq_len), } - def sanity_check(self, tokens_list: List[torch.Tensor], labels_list: List[torch.Tensor]): + def _sanity_check(self, tokens_list: List[torch.Tensor], labels_list: List[torch.Tensor]): tokens_list_tocheck = [] for i in range(len(self.db_tokens)): tokens_list_tocheck += self.db_tokens[i] diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 34c1479b..81eda53e 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -125,7 +125,7 @@ def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]): - def _check_module_pure_dp_wdp(name, module): # pylint: disable=W0613 + def _check_module_pure_dp(name, module): # pylint: disable=W0613 for param in module.parameters(): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) @@ -181,11 +181,13 @@ def _check_module(name, module): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) for _chunk in unwrap_naive_amp(model): - # special case for pure dp or pure wdp mode - if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size( - ParallelMode.WEIGHT_DATA - ) == gpc.get_world_size(ParallelMode.GLOBAL): - _check_module_func = _check_module_pure_dp_wdp + # special case for pure dp mode + if ( + isinstance(gpc.config.parallel["tensor"], dict) + and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name + and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) + ): + _check_module_func = _check_module_pure_dp else: _check_module_func = _check_module # set param parallel attribute @@ -889,22 +891,34 @@ def traverse(module): def 
diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
index 34c1479b..81eda53e 100644
--- a/internlm/train/pipeline.py
+++ b/internlm/train/pipeline.py
@@ -125,7 +125,7 @@ def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]):
 
 
 def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
-    def _check_module_pure_dp_wdp(name, module):  # pylint: disable=W0613
+    def _check_module_pure_dp(name, module):  # pylint: disable=W0613
         for param in module.parameters():
             setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
 
@@ -181,11 +181,13 @@ def _check_module(name, module):
                 setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
 
     for _chunk in unwrap_naive_amp(model):
-        # special case for pure dp or pure wdp mode
-        if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size(
-            ParallelMode.WEIGHT_DATA
-        ) == gpc.get_world_size(ParallelMode.GLOBAL):
-            _check_module_func = _check_module_pure_dp_wdp
+        # special case for pure dp mode
+        if (
+            isinstance(gpc.config.parallel["tensor"], dict)
+            and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name
+            and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL)
+        ):
+            _check_module_func = _check_module_pure_dp
         else:
             _check_module_func = _check_module
         # set param parallel attribute
@@ -889,22 +891,34 @@ def traverse(module):
 
 
 def inject_config(model: nn.Module) -> None:
+    # Compatibility for Vision-Language Models
     if hasattr(model.config, "text_config"):
-        model_config = model.config.text_config
+        llm_cfg = model.config.text_config
     else:
-        model_config = model.config
-    gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = model_config.vocab_size
-    gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = model_config.hidden_size
-    gpc.config.model.num_layers = gpc.config.NUM_LAYER = model_config.num_hidden_layers
-    gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = model_config.num_attention_heads
-    gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = model_config.intermediate_size / model_config.hidden_size
+        llm_cfg = model.config
+    gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = llm_cfg.vocab_size
+    gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = llm_cfg.hidden_size
+    gpc.config.model.num_layers = gpc.config.NUM_LAYER = llm_cfg.num_hidden_layers
+    # Compatibility for Mamba
+    if hasattr(llm_cfg, "num_attention_heads"):
+        gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = llm_cfg.num_attention_heads
+    gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = llm_cfg.intermediate_size / llm_cfg.hidden_size
     # For models that use GQA
-    if hasattr(model_config, "num_key_value_heads"):
-        gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = model_config.num_key_value_heads
+    if hasattr(llm_cfg, "num_key_value_heads"):
+        gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = llm_cfg.num_key_value_heads
 
 
 def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None:
-    # get inject_info
+    """
+    Inject model helper functions.
+
+    Args:
+        model (Union[nn.Module, nn.ModuleList]):
+            For built-in models, it is nn.Module without pp and nn.ModuleList with pp.
+            For injected models, it is nn.Module.
+        inject_info (Optional[Dict]): configurations for injected models.
+    """
+    # parse inject_info
     if inject_info is not None:
         inject = inject_info.get("inject", False)
         interactive = inject_info.get("interactive", False)
@@ -926,31 +940,37 @@ def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Opt
         "norm": inject_norm,
     }
 
+    # inject config
+    if inject:
+        inject_config(model)
+
     if not isinstance(model, nn.ModuleList):
         model = [model]
-
-    # inject modules
     for _chunk in model:
-        if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size(
-            ParallelMode.WEIGHT_DATA
-        ) == gpc.get_world_size(ParallelMode.GLOBAL):
+        # Special case for pure dp mode: skip
+        if (
+            isinstance(gpc.config.parallel["tensor"], dict)
+            and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name
+            and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL)
+        ):
            continue
+        # In-place replacement or check of the "embed", "linear" and "norm" modules:
+        # (1) if inject=True, replace them in place
+        # (2) if inject=False, only check them
         for mod in modules:
             inject_funcs[mod](_chunk, inject, interactive)
-
-    # reset parameters and move model to device
+        # reset parameters if needed; the model is expected to provide a reset_parameters() method
+        if reset_params:
+            _chunk.reset_parameters()
     for _chunk in model:
-        if inject:
-            if reset_params:
-                _chunk.reset_parameters()
+        # If the model was initialized on cpu, move it to the cuda device after injection
+        if not next(_chunk.parameters()).is_cuda:
             _chunk.to(get_current_device())
 
-    # inject configs
-    if inject:
-        inject_config(model[0])
-        if gpc.is_rank_for_log():
-            logger.info(
-                f"inject is enabled, please check the model carefully, "
-                f"if there are any problems, please report issue to us. "
-                f"The injected model is \n {model}"
-            )
+    # print injected model
+    if inject and gpc.is_rank_for_log():
+        logger.info(
+            f"inject is enabled, please check the model carefully, "
+            f"if there are any problems, please report issue to us. "
+            f"The injected model is \n {model}"
+        )