[shardformer] fix chatglm implementation (hpcaitech#5644)
* [shardformer] fix chatglm policy

* [shardformer] fix chatglm flash attn

* [shardformer] update readme

* [shardformer] fix chatglm init

* [shardformer] fix chatglm test

* [pipeline] fix chatglm merge batch
ver217 authored and wangbluo committed May 7, 2024
1 parent b65f351 commit e750916
Showing 11 changed files with 193 additions and 117 deletions.
12 changes: 9 additions & 3 deletions colossalai/pipeline/schedule/one_f_one_b.py
@@ -7,7 +7,7 @@
from torch.utils._pytree import tree_map

from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
@@ -327,7 +327,10 @@ def run_forward_only(
self.send_forward(output_obj)

if outputs is not None:
outputs = merge_batch(outputs)
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}

def run_forward_backward(
@@ -410,7 +413,10 @@ def run_forward_backward(
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

if outputs is not None:
outputs = merge_batch(outputs)
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}

def forward_backward_step(
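For reference, a minimal sketch of what the new `batch_size_dim` handling enables. This is illustrative only: `merge_batch_sketch` is a stand-in, not the actual `merge_batch` from `colossalai.pipeline.schedule`, and the shapes below are made up.

```python
# Illustrative sketch: concatenate per-microbatch outputs along the batch
# dimension declared by the (unwrapped) model. Models whose outputs are not
# batch-first can expose a `batch_size_dim` attribute so the schedule merges
# microbatch outputs along the correct axis; the default stays 0.
from typing import List

import torch


def merge_batch_sketch(outputs: List[torch.Tensor], batch_size_dim: int = 0) -> torch.Tensor:
    return torch.cat(outputs, dim=batch_size_dim)


# e.g. two microbatches of sequence-first tensors shaped [seq_len, micro_batch, hidden]
a, b = torch.randn(8, 2, 16), torch.randn(8, 2, 16)
merged = merge_batch_sketch([a, b], batch_size_dim=1)  # -> shape [8, 4, 16]
```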
121 changes: 79 additions & 42 deletions colossalai/shardformer/README.md
@@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
- [x] Unit Testing
- [ ] Policy Implementation

| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
| bert | [] | [] | [] | [] | [] | [] | [] | [] | [] |
| t5 | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
| llama V1/V2 | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
| gpt2 | [] | [] | [] | [] | [] | [] | [] | [] | [] |
| opt | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
| bloom | [] | [] | [] | [] | [] | [] | [] | [] | [] |
| chatglm2 | [] | [] | [] | [] | [] | [] | [] | [] | [] |
| vit | [] | [] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
| whisper | [] | [] | [] | [] | [] | [ ] | [] | [ ] | [ ] |
| sam | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
| blip2 | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
| falcon | [] | [] | [] | [] | [] | [ ] | [] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
| bert | [] | [] | [] | [] | [] | [] | [] | [] | [] |
| t5 | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
| llama V1/V2 | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
| gpt2 | [] | [] | [] | [] | [] | [] | [] | [] | [] |
| opt | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
| bloom | [] | [] | [] | [] | [] | [] | [] | [] | [] |
| chatglm2 | [] | [] | [] | [] | [] | [] | [] | [] | [] |
| vit | [] | [] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
| whisper | [] | [] | [] | [] | [] | [ ] | [] | [ ] | [ ] |
| sam | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
| blip2 | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
| falcon | [] | [] | [] | [] | [] | [ ] | [] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |


## 💡 API Design
@@ -391,6 +391,43 @@ _POLICY_LIST = {
}
```

#### How to support models that are on the Hugging Face model hub but not in the transformers library

There are two cases:

1. The modeling file is in the `transformers` library but the model weights are not. For example, "01-ai/Yi-34B" shares the LLaMA architecture, so it is already covered by the llama policy and no new policy is needed.
2. The modeling file is not in the `transformers` library, such as "THUDM/chatglm2-6b".

Take "THUDM/chatglm2-6b" as an example, we clearly illustrate how to support this model in the `shardformer`.

Unlike llama, which lives in the `transformers` library, the chatglm2 model cannot be imported directly. Thus, the policy key should be the class name as a string rather than the class itself.

For example, for llama:
```python
policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
```

for chatglm2:
```python
policy["GLMBlock"] = ModulePolicyDescription(...)
```

When registering such models in the autopolicy, follow the format below:
```python
"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
file_name="<policy_filename>", class_name="<policy_class_name>"
)
```

For the chatglm2 model, this is:
```python
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
)
```

When using such models, `AutoModel` works as usual, and the correct policy is loaded automatically by the autopolicy.
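A hedged end-to-end usage sketch follows; the launcher call and `ShardConfig` arguments are illustrative and may need adjusting for your colossalai version and parallel setup.

```python
# Illustrative only: shard a hub-only model whose policy is keyed by a class-name string.
import colossalai
from transformers import AutoModel

from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch()  # argument list varies across colossalai versions

# trust_remote_code=True is required because the chatglm2 modeling file lives on the hub
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

shard_config = ShardConfig(enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
# The autopolicy maps the dynamically imported class to the chatglm2 policies
# via the "transformers_modules.modeling_chatglm.*" keys registered above.
sharded_model, shared_params = shard_former.optimize(model)
```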

### Write Your Unit Testing

This section serves as the guideline for testing the `shardformer` module.
@@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

In the case of using 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |
| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |


<p align="center">
@@ -440,13 +477,13 @@

In the case of using 4 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |
| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |



@@ -475,10 +512,10 @@ warmup_fraction = 0.03


| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |
|:--------:|:-------:|:-------:|:----------:|:-------------:|
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |


Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
19 changes: 8 additions & 11 deletions colossalai/shardformer/layer/normalization.py
@@ -281,19 +281,16 @@ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *arg
)

LazyInitContext.materialize(module)
# to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
if module.__class__.__name__ in ["LlamaRMSNorm", "Qwen2RMSNorm", "MistralRMSNorm"]:
normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon
elementwise_affine = True
else:
# get the attributes of the module
normalized_shape = module.normalized_shape
eps = module.eps
elementwise_affine = module.elementwise_affine

# try to get normalized_shape, eps, elementwise_affine from the module
normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = getattr(module, "elementwise_affine", True)

rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)

rmsnorm.weight = module.weight
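For reference, a hedged sketch of what the getattr-based fallback above resolves for the two module families; `FakeRMSNorm` is a made-up stand-in for HF's `LlamaRMSNorm`/`Qwen2RMSNorm`/`MistralRMSNorm`, which expose only `weight` and `variance_epsilon`.

```python
# Illustrative only: show which attributes the fallback picks up for an
# RMSNorm-style module vs. a standard nn.LayerNorm.
import torch
import torch.nn as nn


class FakeRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


for module in (FakeRMSNorm(4096), nn.LayerNorm(4096)):
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    print(type(module).__name__, normalized_shape, eps, elementwise_affine)
```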
21 changes: 15 additions & 6 deletions colossalai/shardformer/modeling/chatglm2.py
@@ -12,7 +12,6 @@
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel


def get_flash_core_attention_forward():
@@ -31,7 +30,12 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
device=query_layer.device,
)
temp_mask = (
torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
torch.ones(
query_layer.shape[2],
key_layer.shape[2],
dtype=torch.bool,
device=query_layer.device,
)
.tril(diagonal=0)
.expand(query_layer.shape[0], 1, -1, -1)
)
@@ -49,6 +53,7 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
attention_mask=attn_bias,
attention_mask_type=attention_mask_type,
dropout_p=dropout_p,
scale=1.0 / self.norm_factor,
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:

@staticmethod
def chatglm_model_forward(
self: ChatGLMModel,
self: "ChatGLMModel",
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
@@ -194,7 +199,9 @@ def chatglm_model_forward(
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
@@ -224,7 +231,9 @@ def chatglm_model_forward(
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -254,7 +263,7 @@ def chatglm_model_forward(

@staticmethod
def chatglm_for_conditional_generation_forward(
self: ChatGLMForConditionalGeneration,
self: "ChatGLMForConditionalGeneration",
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
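The `scale=1.0 / self.norm_factor` argument added to the flash-attention call above matters because ChatGLM's `CoreAttention` does not necessarily use the kernel default of `1/sqrt(head_dim)` (its norm factor can carry an extra layer-dependent coefficient). Below is a rough stand-in using PyTorch SDPA, purely for illustration; the real code calls colossalai's `ColoAttention.attention`.

```python
# Illustrative only: pass an explicit softmax scale instead of relying on the
# attention kernel's default 1/sqrt(head_dim). Requires PyTorch >= 2.1 for `scale`.
import math

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 16, 64
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

norm_factor = math.sqrt(head_dim)  # ChatGLM may fold additional scaling into this factor
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=1.0 / norm_factor)
```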
13 changes: 10 additions & 3 deletions colossalai/shardformer/policies/auto_policy.py
@@ -151,10 +151,10 @@ class PolicyLocation:
file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
),
# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
"transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMModelPolicy"
),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Falcon
@@ -212,6 +212,13 @@ def _fullname(obj):
module = klass.__module__
if module == "builtins":
return klass.__qualname__ # avoid outputs like 'builtins.str'
# patch custom models which are not in transformers
# it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
# or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
if module.startswith("transformers_modules"):
split_module = module.split(".")
if len(split_module) >= 2:
module = f"{split_module[0]}.{split_module[-1]}"
return module + "." + klass.__qualname__


@@ -230,7 +237,7 @@ def get_autopolicy(model: nn.Module) -> Policy:

if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location)
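For reference, a standalone sketch of the module-name normalization that `_fullname` now performs; the helper name below is illustrative.

```python
# Maps both hub-style and local-style dynamic module paths onto the keys used in
# _POLICY_LIST, e.g.
#   'transformers_modules.THUDM.chatglm3-6b.<hash>.modeling_chatglm' -> 'transformers_modules.modeling_chatglm'
#   'transformers_modules.chatglm.modeling_chatglm'                  -> 'transformers_modules.modeling_chatglm'
def normalize_transformers_module(module: str) -> str:
    if module.startswith("transformers_modules"):
        parts = module.split(".")
        if len(parts) >= 2:
            module = f"{parts[0]}.{parts[-1]}"
    return module


assert (
    normalize_transformers_module("transformers_modules.chatglm.modeling_chatglm")
    == "transformers_modules.modeling_chatglm"
)
```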