From 3bd5c9f36f90e8cf72c12b7d7c10c44ef03c1780 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 23 Apr 2024 14:12:20 +0800 Subject: [PATCH] [exampe] update llama example (#5626) * [plugin] support dp inside for hybriad parallel * [example] update llama benchmark * [example] update llama benchmark * [example] update llama readme * [example] update llama readme --- colossalai/booster/plugin/gemini_plugin.py | 1 + .../booster/plugin/hybrid_parallel_plugin.py | 28 +- examples/language/llama2/README.md | 117 +------ examples/language/llama2/attn.py | 1 - examples/language/llama2/benchmark.py | 62 +++- examples/language/llama2/finetune.py | 313 ----------------- examples/language/llama2/pretrain.py | 328 ------------------ examples/language/llama2/requirements.txt | 5 +- 8 files changed, 72 insertions(+), 783 deletions(-) delete mode 120000 examples/language/llama2/attn.py delete mode 100644 examples/language/llama2/finetune.py delete mode 100644 examples/language/llama2/pretrain.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 442ac4a8da06..a67ca18a3456 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -424,6 +424,7 @@ def __init__( ) self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None + self.dp_size = self.zero_size * self.extra_dp_size self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8d12eb80621d..95fb2def10a4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -34,7 +34,6 @@ from .pp_plugin_base import PipelinePluginBase -DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -987,6 +986,7 @@ def __init__( gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, make_vocab_size_divisible_by: int = 64, + dp_outside: bool = True, ) -> None: super().__init__() assert ( @@ -1034,7 +1034,12 @@ def __init__( self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + if dp_outside: + self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + else: + self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None self.custom_policy = custom_policy @@ -1048,7 +1053,7 @@ def __init__( assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager( self.pg_mesh, - pipeline_axis=PP_AXIS, + pipeline_axis=self.pp_axis, enable_interleave=pp_style == "interleaved", num_model_chunks=num_model_chunks, ) @@ -1072,13 +1077,13 @@ def __init__( else: raise NotImplementedError() - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - 
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: - self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) else: - self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, @@ -1169,7 +1174,7 @@ def configure( and self.sequence_parallelism_mode == "all_to_all" ) if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": - dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) + dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) else: dp_group = self.dp_group model = HybridParallelModule( @@ -1317,7 +1322,10 @@ def prepare_dataloader( _kwargs = kwargs.copy() distributed_sampler_cls = distributed_sampler_cls or DistributedSampler sampler = distributed_sampler_cls( - dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + dataset, + num_replicas=self.pg_mesh.size(self.dp_axis), + rank=self.pg_mesh.coordinate(self.dp_axis), + shuffle=shuffle, ) # Deterministic dataloader diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md index 068f15cbb041..11b2ee511a6e 100644 --- a/examples/language/llama2/README.md +++ b/examples/language/llama2/README.md @@ -1,4 +1,4 @@ -# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models +# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models ### LLaMA2

@@ -16,38 +16,10 @@ - 65-billion-parameter large model pretraining accelerated by 38% [[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining) -## Dataset - -Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed. - -A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample). - -RedPajama-Data-1T consists of seven data slices: - -| | RedPajama | LLaMA | -|---------------|--------------|---------------| -| CommonCrawl | 878 billion | 852 billion | -| C4 | 175 billion | 190 billion | -| Github | 59 billion | 100 billion | -| Books | 26 billion | 25 billion | -| ArXiv | 28 billion | 33 billion | -| Wikipedia | 24 billion | 25 billion | -| StackExchange | 20 billion | 27 billion | -| Total | 1.2 trillion | 1.25 trillion | - -## Training - -We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $beta1=0.9$ and $beta2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps. - -| params | learning rate | batch size | -|--------|---------------|------------| -| 6.7B | 3.0e-4 | 4M | -| 13.0B | 3.0e-4 | 4M | -| 32.5B | 1.5e-4 | 4M | -| 65.2B | 1.5e-4 | 4M | - ## Usage +> ⚠ This example only has benchmarking script. For training/finetuning, please refer to the [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA). + ### 1. Installation Please install the latest ColossalAI from source. @@ -62,52 +34,6 @@ Then install other dependencies. pip install -r requirements.txt ``` -Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention. - -### 2. Download the dataset - -The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. - -### 3. Command line arguments - -Yon can use colossalai run to launch multi-nodes training: -```bash -colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ -pretrain.py --OTHER_CONFIGURATIONS -``` - -Here is a sample hostfile: - -```text -hostname1 -hostname2 -hostname3 -hostname4 -``` - -Make sure master node can access all nodes (including itself) by ssh without password. - -Here is details about CLI arguments: - -- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2. -- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). -- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama. 
-- Number of epochs: `-e`, `--num_epochs`. The default value is 1. -- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. -- Learning rate: `--lr`. The default value is 3e-4. -- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. -- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000. -- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. -- Max length: `-l`, `--max_length`. The default value is 4096. -- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. -- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. -- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`. -- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. -- Gradient clipping: `--gradient_clipping`. The default value is 1.0. -- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. -- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. - - ### 4. Shell Script Examples For your convenience, we provide some shell scripts to run benchmark with various configurations. @@ -193,40 +119,3 @@ If you run the above command successfully, you will get the following results: year={2023} } ``` - - -# Fine-tune Llama2 - -We also provide a example to fine-tune llama2 in `finetune.py`, - -Make sure master node can access all nodes (including itself) by ssh without password. - -Here is details about CLI arguments: - -- Pretrained checkpoint path: `--model_path`, the path of your model checkpoint, it can be your local directory or a Hugging Face tag. -- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). -- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It support any dataset from `datasets` with the same data format as `yizhongw/self_instruct`. -- task name: `--task_name`, the task to fine-tune, it's also related to the target of loading dataset, The default value is `super_natural_instructions`. -- Number of epochs: `-e`, `--num_epochs`. The default value is 1. -- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. -- Learning rate: `--lr`. The default value is 3e-4. -- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. -- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. -- Max length: `-l`, `--max_length`. The default value is 4096. -- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. -- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. -- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`. 
-- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. -- Gradient clipping: `--gradient_clipping`. The default value is 1.0. -- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. -- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. - - -```shell -torchrun --standalone --nproc_per_node 8 finetune.py \ - --plugin "hybrid_parallel" \ - --dataset "yizhongw/self_instruct" \ - --model_path "/path/llama" \ - --task_name "super_natural_instructions" \ - --save_dir "/path/output" -``` diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py deleted file mode 120000 index 4e95c7bfa519..000000000000 --- a/examples/language/llama2/attn.py +++ /dev/null @@ -1 +0,0 @@ -../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py \ No newline at end of file diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index 832465490907..ff94891f50ec 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -3,14 +3,13 @@ from contextlib import nullcontext import torch -from attn import replace_with_flash_attention from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai from colossalai.accelerator import get_accelerator @@ -19,6 +18,7 @@ from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import PipelineGradientCheckpointConfig from examples.language.data_utils import RandomDataset from examples.language.model_utils import format_numel_str, get_model_numel from examples.language.performance_evaluator import PerformanceEvaluator @@ -78,6 +78,7 @@ def main(): parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") + parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) args = parser.parse_args() colossalai.launch_from_torch({}) @@ -86,6 +87,19 @@ def main(): def empty_init(): pass + # ckpt config for LLaMA3-70B on 64 H100 GPUs + ckpt_config = ( + PipelineGradientCheckpointConfig( + num_stages=args.pp, + num_model_chunks=1, + num_model_layers=80, + num_layers_per_stage=[19, 20, 20, 21], + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ) + if args.custom_ckpt + else None + ) + # ============================== # Initialize Booster # ============================== @@ -98,6 +112,8 @@ def empty_init(): offload_param_frac=args.offload_param_frac, tp_size=args.tp, extra_dp_size=args.extra_dp, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, ) elif 
args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -106,26 +122,34 @@ def empty_init(): warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp, + enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, ) elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), param_init_fn=empty_init(), ) else: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ) ) elif args.plugin == "fsdp_cpu": if use_empty_init: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), @@ -133,7 +157,9 @@ def empty_init(): else: plugin = TorchFSDPPlugin( mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, ), cpu_offload=CPUOffload(offload_params=True), ) @@ -141,12 +167,13 @@ def empty_init(): plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, - pp_style="interleaved", zero_stage=args.zero, - num_model_chunks=2, enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", + dp_outside=False, + gradient_checkpoint_config=ckpt_config, ) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( @@ -155,6 +182,7 @@ def empty_init(): zero_stage=args.zero, cpu_offload=True, enable_fused_normalization=torch.cuda.is_available(), + enable_flash_attention=args.xformers, microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", @@ -167,9 +195,12 @@ def empty_init(): # ============================== # Initialize Dataset and Dataloader # ============================== - dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size + dp_size = getattr(plugin, "dp_size", coordinator.world_size) - config = MODEL_CONFIGS[args.config] + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size ) @@ -184,14 +215,17 @@ def empty_init(): else nullcontext() ) + init_kwargs = {} + if config.model_type == "chatglm": + init_kwargs["empty_init"] = False + with init_ctx: - model = LlamaForCausalLM(config) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs) if args.grad_checkpoint: model.gradient_checkpointing_enable() - - if args.xformers: - replace_with_flash_attention(model) + if config.model_type == "chatglm": + model.transformer.encoder.gradient_checkpointing = True model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py deleted 
file mode 100644 index 69b4ebe42bf7..000000000000 --- a/examples/language/llama2/finetune.py +++ /dev/null @@ -1,313 +0,0 @@ -import argparse -import math -import os -import resource -from contextlib import nullcontext -from functools import partial -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from attn import replace_with_flash_attention -from data_utils import load_json, prepare_dataloader, save_json -from datasets import load_dataset -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers.models.llama.tokenization_llama import LlamaTokenizer - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam - - -def get_model_numel(model: nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample["prompt"] + sample["completion"] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) - data = {k: v.cuda() for k, v in data.items()} - data["labels"] = data["input_ids"].clone() - return data - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - tensor = tensor.data - tensor.div_(dist.get_world_size()) - return tensor - - -def save( - booster: Booster, - model: nn.Module, - optimizer: Optimizer, - lr_scheduler: _LRScheduler, - epoch: int, - step: int, - batch_size: int, - coordinator: DistCoordinator, - save_dir: str, -): - save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, "model"), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - running_states = { - "epoch": epoch, - "step": step, - "sample_start_index": step * batch_size, - } - if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, "running_states.json")) - - -def load( - booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str -) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, "model")) - booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) - running_states = load_json(os.path.join(load_dir, "running_states.json")) - return running_states["epoch"], running_states["step"], 
running_states["sample_start_index"] - - -def _criterion(outputs, inputs): - return outputs.loss - - -def main(): - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune") - parser.add_argument( - "-p", - "--plugin", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], - default="gemini", - help="Choose which plugin to use", - ) - parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path") - parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run") - parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") - parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") - parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") - parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") - parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") - parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") - parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") - parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") - args = parser.parse_args() - - # ============================== - # Initialize Distributed Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2_cpu": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip - ) - elif args.plugin == "hybrid_parallel": - # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin( - tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision="fp32", - initial_scale=1, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - print_flag = (not use_pipeline and 
coordinator.is_master()) or (use_pipeline and is_pp_last_stage) - - # ============================== - # Initialize Tensorboard - # ============================== - if print_flag: - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Model, Optimizer and LR Scheduler - # ============================== - - config = LlamaConfig.from_pretrained(args.model_path) - # use lazy init when using GeminiPlugin - init_ctx = ( - LazyInitContext(default_device=get_accelerator().get_current_device()) - if isinstance(plugin, GeminiPlugin) - else nullcontext() - ) - - with init_ctx: - model = LlamaForCausalLM(config) - - # ============================== - # Initialize Tokenizer, Dataset and Dataloader - # ============================== - tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 - tokenizer.pad_token = tokenizer.unk_token - - dataset = load_dataset(args.dataset, args.task_name) - train_ds = dataset["train"] - dataloader = prepare_dataloader( - train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length), - ) - - if args.grad_checkpoint: - model.gradient_checkpointing_enable() - if args.flash_attention: - replace_with_flash_attention(model) - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) - total_step = args.num_epochs * len(dataloader) - lr_scheduler = CosineAnnealingWarmupLR( - optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr - ) - default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler - ) - torch.set_default_dtype(torch.float) - - booster.load_model(model, args.model_path) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" - ) - - # load checkpoint if specified - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load is not None: - coordinator.print_on_master("Loading checkpoint") - start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") - - num_steps_per_epoch = len(dataloader) - - # if resume training, set the sampler start index to the correct value - dataloader.sampler.set_start_index(sampler_start_idx) - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch) - step_nums = num_steps_per_epoch - start_step - dataloader_iter = iter(dataloader) - - with tqdm( - range(step_nums), - desc=f"Epoch {epoch}", - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step in pbar: - if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) - loss = 
outputs["loss"] - else: - batch = next(dataloader_iter) - outputs = model(**batch) - loss = outputs[0] - booster.backward(loss, optimizer) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - if not use_pipeline: - all_reduce_mean(loss) - if print_flag: - pbar.set_postfix({"loss": loss.item()}) - writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) - - if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f"Saving checkpoint") - save( - booster, - model, - optimizer, - lr_scheduler, - epoch, - step + 1, - args.batch_size, - coordinator, - args.save_dir, - ) - coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(0) - start_step = 0 - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py deleted file mode 100644 index 970cd5290f9f..000000000000 --- a/examples/language/llama2/pretrain.py +++ /dev/null @@ -1,328 +0,0 @@ -import argparse -import os -import resource -from contextlib import nullcontext -from functools import partial -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from attn import replace_with_flash_attention -from data_utils import load_json, prepare_dataloader, save_json -from datasets import load_dataset -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from transformers.models.llama.tokenization_llama import LlamaTokenizer - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam - -MODEL_CONFIGS = { - "7b": LlamaConfig(max_position_embeddings=4096), - "13b": LlamaConfig( - hidden_size=5120, - intermediate_size=13824, - num_hidden_layers=40, - num_attention_heads=40, - max_position_embeddings=4096, - ), - "70b": LlamaConfig( - hidden_size=8192, - intermediate_size=28672, - num_hidden_layers=80, - num_attention_heads=64, - max_position_embeddings=4096, - num_key_value_heads=8, - ), -} - - -def get_model_numel(model: nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample["text"] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) - data = {k: v.cuda() for k, v in data.items()} - data["labels"] = data["input_ids"].clone() - 
return data - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - tensor = tensor.data - tensor.div_(dist.get_world_size()) - return tensor - - -def save( - booster: Booster, - model: nn.Module, - optimizer: Optimizer, - lr_scheduler: _LRScheduler, - epoch: int, - step: int, - batch_size: int, - coordinator: DistCoordinator, - save_dir: str, -): - save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, "model"), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - running_states = { - "epoch": epoch, - "step": step, - "sample_start_index": step * batch_size, - } - if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, "running_states.json")) - - -def load( - booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str -) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, "model")) - booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) - running_states = load_json(os.path.join(load_dir, "running_states.json")) - return running_states["epoch"], running_states["step"], running_states["sample_start_index"] - - -def _criterion(outputs, inputs): - return outputs.loss - - -def main(): - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") - parser.add_argument( - "-p", - "--plugin", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], - default="gemini", - help="Choose which plugin to use", - ) - parser.add_argument( - "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path" - ) - parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") - parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps") - parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") - parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") - parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") - parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") - parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") - parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") - parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") - args = parser.parse_args() - - # ============================== - # Initialize Distributed 
Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip - ) - elif args.plugin == "zero2_cpu": - plugin = LowLevelZeroPlugin( - stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip - ) - elif args.plugin == "hybrid_parallel": - # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin( - tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision=args.mixed_precision, - initial_scale=1, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) - - # ============================== - # Initialize Tensorboard - # ============================== - if print_flag: - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Tokenizer, Dataset and Dataloader - # ============================== - tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 - tokenizer.pad_token = tokenizer.unk_token - - dataset = load_dataset(args.dataset) - train_ds = dataset["train"] - dataloader = prepare_dataloader( - train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length), - ) - - # ============================== - # Initialize Model, Optimizer and LR Scheduler - # ============================== - config = MODEL_CONFIGS[args.config] - # use lazy init when using GeminiPlugin - init_ctx = ( - LazyInitContext(default_device=get_accelerator().get_current_device()) - if isinstance(plugin, GeminiPlugin) - else nullcontext() - ) - - with init_ctx: - model = LlamaForCausalLM(config) - - if args.grad_checkpoint: - model.gradient_checkpointing_enable() - if args.flash_attention: - replace_with_flash_attention(model) - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) - lr_scheduler = CosineAnnealingWarmupLR( - optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr - ) - default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model, optimizer, dataloader=dataloader, 
lr_scheduler=lr_scheduler - ) - torch.set_default_dtype(torch.float) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" - ) - - # load checkpoint if specified - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load is not None: - coordinator.print_on_master("Loading checkpoint") - start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") - - num_steps_per_epoch = len(dataloader) - - # if resume training, set the sampler start index to the correct value - dataloader.sampler.set_start_index(sampler_start_idx) - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch) - dataloader_iter = iter(dataloader) - - with tqdm( - range(start_step, num_steps_per_epoch), - desc=f"Epoch {epoch}", - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step in pbar: - if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) - loss = outputs["loss"] - else: - batch = next(dataloader_iter) - outputs = model(**batch) - loss = outputs[0] - booster.backward(loss, optimizer) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - if not use_pipeline: - all_reduce_mean(loss) - if print_flag: - pbar.set_postfix({"loss": loss.item()}) - writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) - - if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f"Saving checkpoint") - save( - booster, - model, - optimizer, - lr_scheduler, - epoch, - step + 1, - args.batch_size, - coordinator, - args.save_dir, - ) - coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(0) - start_step = 0 - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama2/requirements.txt index 6b475682dad0..438a4999a3fe 100644 --- a/examples/language/llama2/requirements.txt +++ b/examples/language/llama2/requirements.txt @@ -1,9 +1,8 @@ -colossalai>=0.3.2 +colossalai>=0.3.6 datasets numpy -torch>=1.12.0,<=2.0.0 tqdm transformers -flash-attn>=2.0.0,<=2.0.5 +flash-attn>=2.0.0 SentencePiece==0.1.99 tensorboard==2.14.0
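
A note on the `dp_outside` flag that this patch adds to `HybridParallelPlugin`: with the default `dp_outside=True` the process-group mesh is laid out as (dp, pp, tp, sp), while `dp_outside=False` swaps the first two axes so the pipeline dimension becomes the outermost one, which is what the updated benchmark uses for its 3D configuration. The sketch below shows how the plugin might be constructed with the new flag; the parallel sizes are illustrative only, and the model/optimizer/dataloader setup is omitted.

```python
import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Launched under torchrun / colossalai run; the empty config dict matches benchmark.py.
colossalai.launch_from_torch({})

plugin = HybridParallelPlugin(
    tp_size=2,                 # illustrative: 2-way tensor parallelism
    pp_size=2,                 # illustrative: 2-way pipeline parallelism
    zero_stage=1,              # ZeRO-1 inside the remaining data-parallel group
    microbatch_size=1,
    precision="bf16",
    enable_fused_normalization=torch.cuda.is_available(),  # same gating as benchmark.py
    dp_outside=False,          # new flag: put the pipeline axis, not data parallel, outermost
)
booster = Booster(plugin=plugin)
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
```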
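
The `--custom-ckpt` path added to `benchmark.py` wires a `PipelineGradientCheckpointConfig` into the plugin so that both the layer split across pipeline stages and the number of recomputed (checkpointed) layers per stage are pinned explicitly. The following sketch restates the values hard-coded above for LLaMA3-70B on a 4-stage pipeline; the comments on how the lists relate are inferred from those numbers rather than taken from separate documentation.

```python
from colossalai.shardformer import PipelineGradientCheckpointConfig

# 80 decoder layers split unevenly over 4 pipeline stages: 19 + 20 + 20 + 21 == 80.
num_layers_per_stage = [19, 20, 20, 21]
# How many of each stage's layers are recomputed in the backward pass; the last
# stage checkpoints only 13 of its 21 layers, trading some memory for speed.
num_ckpt_layers_per_stage = [19, 19, 19, 13]

ckpt_config = PipelineGradientCheckpointConfig(
    num_stages=4,            # matches the pipeline parallel size (--pp)
    num_model_chunks=1,      # non-interleaved schedule: one model chunk per stage
    num_model_layers=80,     # LLaMA3-70B depth
    num_layers_per_stage=num_layers_per_stage,
    num_ckpt_layers_per_stage=num_ckpt_layers_per_stage,
)

# Handed to the plugin exactly as in the benchmark:
# HybridParallelPlugin(..., gradient_checkpoint_config=ckpt_config)
```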
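
Because the benchmark now builds the model through `AutoConfig`/`AutoModelForCausalLM` instead of `LlamaForCausalLM`, the `--config` argument accepts either a preset key from `MODEL_CONFIGS` or an arbitrary Hugging Face model identifier or local path. A condensed sketch of that resolution logic is shown below; `MODEL_CONFIGS` is trimmed to a single entry here, and the ChatGLM branch simply mirrors the special case in the script.

```python
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

# Trimmed preset table; the real benchmark defines several entries (7b, 13b, 70b, ...).
MODEL_CONFIGS = {"7b": LlamaConfig(max_position_embeddings=4096)}


def build_model(name_or_preset: str):
    """Resolve a preset key or a Hugging Face identifier/path into a causal LM."""
    if name_or_preset in MODEL_CONFIGS:
        config = MODEL_CONFIGS[name_or_preset]
    else:
        config = AutoConfig.from_pretrained(name_or_preset, trust_remote_code=True)

    init_kwargs = {}
    if config.model_type == "chatglm":
        # Mirrors benchmark.py: ChatGLM's remote code accepts an `empty_init` kwarg,
        # and the benchmark passes False so the parameters are actually materialized.
        init_kwargs["empty_init"] = False

    return AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
```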