Commit

merge with main
Edenzzzz committed Aug 28, 2024
2 parents 5365117 + cc1b0ef commit d0107c3
Showing 82 changed files with 2,831 additions and 294 deletions.
1 change: 0 additions & 1 deletion .compatibility
@@ -1,4 +1,3 @@
2.1.0-12.1.0
2.2.2-12.1.0
2.3.0-12.1.0
2.4.0-12.4.1
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -89,7 +89,7 @@ jobs:
if: needs.detect.outputs.anyLibraryFileChanged == 'true'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
timeout-minutes: 90
defaults:
2 changes: 1 addition & 1 deletion .github/workflows/build_on_schedule.yml
@@ -12,7 +12,7 @@ jobs:
if: github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 90
steps:
2 changes: 1 addition & 1 deletion .github/workflows/doc_test_on_pr.yml
@@ -56,7 +56,7 @@ jobs:
needs: detect-changed-doc
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm
timeout-minutes: 30
defaults:
2 changes: 1 addition & 1 deletion .github/workflows/doc_test_on_schedule.yml
@@ -12,7 +12,7 @@ jobs:
name: Test the changed Doc
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm
timeout-minutes: 60
steps:
2 changes: 1 addition & 1 deletion .github/workflows/example_check_on_dispatch.yml
@@ -45,7 +45,7 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm
timeout-minutes: 15
steps:
5 changes: 3 additions & 2 deletions .github/workflows/example_check_on_pr.yml
@@ -9,6 +9,7 @@ on:
paths:
- "examples/**"
- "!examples/**.md"
- ".github/workflows/example_check_on_pr.yml"

jobs:
# This is to detect changed example files and output a matrix containing all the corresponding directory names.
@@ -89,7 +90,7 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm
timeout-minutes: 30
concurrency:
@@ -107,7 +108,7 @@
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
BUILD_EXT=1 pip install -v -e .
- name: Store Colossal-AI Cache
run: |
2 changes: 1 addition & 1 deletion .github/workflows/example_check_on_schedule.yml
@@ -34,7 +34,7 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm
timeout-minutes: 30
steps:
2 changes: 1 addition & 1 deletion .github/workflows/run_chatgpt_examples.yml
@@ -19,7 +19,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data --shm-size=10.24gb
timeout-minutes: 60
defaults:
2 changes: 1 addition & 1 deletion .github/workflows/run_chatgpt_unit_tests.yml
@@ -19,7 +19,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data
timeout-minutes: 30
defaults:
2 changes: 1 addition & 1 deletion .github/workflows/run_colossalqa_unit_tests.yml
@@ -19,7 +19,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
volumes:
- /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa
- /data/scratch/llama-tiny:/data/scratch/llama-tiny
2 changes: 1 addition & 1 deletion README.md
@@ -420,7 +420,7 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
## Installation

Requirements:
- PyTorch >= 2.1
- PyTorch >= 2.2
- Python >= 3.7
- CUDA >= 11.0
- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)
4 changes: 4 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -366,7 +366,9 @@ def __init__(
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
@@ -401,6 +403,8 @@ def __init__(
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
fp8_communication=fp8_communication,
use_fp8=use_fp8,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
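For context, here is a minimal usage sketch of the two flags added above. Only `use_fp8` and `fp8_communication` come from this diff; the launch and boost calls just follow the usual Booster workflow, and the model and optimizer are placeholders, so treat this as an illustration rather than documented API.

```python
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Assumes the script is started with torchrun so a distributed environment exists.
colossalai.launch_from_torch()

model = nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(
    precision="bf16",
    use_fp8=True,            # new flag added in this diff
    fp8_communication=True,  # new flag added in this diff
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```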
23 changes: 21 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -31,6 +31,7 @@
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy
@@ -66,6 +67,7 @@ def __init__(
ddp_config: dict,
custom_policy: Policy,
overlap_allgather: bool = False,
use_fp8: bool = False,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
@@ -75,6 +77,7 @@ def __init__(
self.use_ddp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
self.use_fp8 = use_fp8

shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@@ -112,6 +115,9 @@ def __init__(
module = DDP(module, process_group=dp_group, **ddp_config)

super().__init__(module)
self.op_hooks = []
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
@@ -223,7 +229,11 @@ def _force_wait_all_gather(self):
wait_all_gather_handle(p)

def _wait_all_gather(self):
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
return (
ColoParamOpHookManager.use_hooks(*self.op_hooks)
if (self.overlap_allgather or self.use_fp8)
else nullcontext()
)


def get_param_info(optim: Optimizer):
@@ -969,6 +979,7 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
@@ -1020,6 +1031,8 @@ def __init__(
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
inner_ring_size: int = None,
) -> None:
super().__init__()
@@ -1069,8 +1082,10 @@ def __init__(
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.use_fp8 = use_fp8
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
# Swap tp and sp since 2D Ring has better inter-node latency
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
@@ -1117,13 +1132,15 @@ def __init__(
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
fp8_communication=fp8_communication,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
fp8_communication=fp8_communication,
)
else:
raise NotImplementedError()
@@ -1158,6 +1175,7 @@ def __init__(
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
)
self.amp_config = dict(
@@ -1250,7 +1268,7 @@ def configure(
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1 and self.pp_size == 1
)

# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -1268,6 +1286,7 @@
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0:
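The HybridParallelModule changes above reduce to a small pattern: collect the enabled parameter-op hooks into one list and run execution inside `ColoParamOpHookManager.use_hooks(...)`, falling back to a `nullcontext` when nothing is enabled. Below is a standalone sketch of that pattern; the helper name `make_op_hook_ctx` and its `zero_op_hook` argument are hypothetical (the real module builds its own `ZeroOpHook`), and only `FP8Hook` is imported from the path introduced in this diff.

```python
from contextlib import nullcontext

from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.param_op_hook import ColoParamOpHookManager


def make_op_hook_ctx(overlap_allgather: bool, use_fp8: bool, zero_op_hook=None):
    """Hypothetical helper mirroring the hook selection in HybridParallelModule."""
    op_hooks = []
    if overlap_allgather and zero_op_hook is not None:
        op_hooks.append(zero_op_hook)  # ZeroOpHook in the real module
    if use_fp8:
        op_hooks.append(FP8Hook())     # fp8 op hook introduced by this commit
    return (
        ColoParamOpHookManager.use_hooks(*op_hooks)
        if (overlap_allgather or use_fp8)
        else nullcontext()
    )


# Usage sketch (model and inputs are placeholders):
# with make_op_hook_ctx(overlap_allgather=False, use_fp8=True):
#     output = model(inputs)
```

These hooks act through `ColoParamOpHookManager`, so in practice they only take effect on parameters handled as `ColoParameter`, which is why the plugins also convert parameter classes (see the low-level zero plugin below).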
22 changes: 19 additions & 3 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -34,6 +34,7 @@
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero import LowLevelZeroOptimizer
@@ -62,7 +63,12 @@ class OptimizerParamCheckState(enum.Enum):

class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(
self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
self,
module: nn.Module,
precision: str,
overlap_allgather: bool = False,
cast_inputs: bool = True,
use_fp8: bool = False,
) -> None:
super().__init__(module)
self.dtype = None
@@ -75,11 +81,16 @@ def __init__(
module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
self.use_fp8 = use_fp8
if self.dtype is not None and cast_inputs:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_allgather = overlap_allgather
self.op_hooks = []
if overlap_allgather:
self.op_hook = ZeroOpHook()
self.op_hooks.append(ZeroOpHook())
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather or use_fp8:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
@@ -89,7 +100,7 @@ def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
with ctx:
return super().forward(*args, **kwargs)

@@ -337,6 +348,8 @@ def __init__(
master_weights: bool = True,
verbose: bool = False,
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -360,12 +373,14 @@ def __init__(
cpu_offload=cpu_offload,
master_weights=master_weights,
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
)
self.lora_enabled = False
self.verbose = verbose
self.logger = get_dist_logger()
self.cast_inputs = cast_inputs

self.use_fp8 = use_fp8
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

@@ -484,6 +499,7 @@ def configure(
self.precision,
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
cast_inputs=self.cast_inputs,
use_fp8=self.use_fp8,
)

# TODO: Support Galore + ZeRO
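As with the Gemini plugin, the new LowLevelZeroPlugin switches can be driven through the ordinary Booster workflow. A brief sketch, assuming a distributed context has already been launched (for example via `colossalai.launch_from_torch()`); the model and optimizer are placeholders.

```python
import torch.nn as nn
from torch.optim import AdamW

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(
    stage=2,
    precision="bf16",
    use_fp8=True,            # new flag: LowLevelZeroModel installs an FP8Hook (see above)
    fp8_communication=True,  # new flag: forwarded through the ZeRO optimizer kwargs (see above)
)
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
optimizer = AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
```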
5 changes: 5 additions & 0 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -214,6 +214,8 @@ def __init__(
moe_dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
self.logger = get_dist_logger()
if overlap_communication or zero_stage == 2:
Expand Down Expand Up @@ -327,6 +329,7 @@ def __init__(
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.use_fp8 = use_fp8

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@@ -345,6 +348,7 @@ def __init__(
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
)
self.amp_config = dict(
initial_scale=initial_scale,
@@ -431,6 +435,7 @@ def configure(
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.ep_size > 1:
7 changes: 7 additions & 0 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -179,6 +179,7 @@ def __init__(
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
self.ddp_kwargs = dict(
@@ -189,6 +190,7 @@ def __init__(
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.fp8_communication = fp8_communication

def support_no_sync(self) -> bool:
return True
@@ -228,6 +230,11 @@ def configure(
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)

if self.fp8_communication:
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async

model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)

return model, optimizer, criterion, dataloader, lr_scheduler

def control_checkpoint_io(self) -> bool:
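When `fp8_communication=True`, the `configure` method above is a thin layer over plain PyTorch DDP: it registers ColossalAI's fp8 gradient-compression function as a DDP communication hook. A hedged sketch of the equivalent manual wiring follows; the process-group setup and model are placeholders, and only the hook import comes from this diff.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async

# Assumes launch with torchrun, one GPU per process.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = DDP(nn.Linear(512, 512).cuda())

# Same call the plugin makes: state is None, and the hook is expected to compress
# gradient buckets to fp8 for the all-reduce.
model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
```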