From 6c619c9992222a203db5952e93ac8adb0cf96ef1 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Fri, 8 Nov 2024 03:30:21 +0000
Subject: [PATCH 01/18] update best answer function

---
 .../coati/reasoner/guided_search/mcts.py      | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
index a87211da210c..438d0ed5fdb6 100644
--- a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
+++ b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -120,17 +120,16 @@ def simulate(self):
             self.back_propagation(child)
 
         return self.get_best_answer()
-
-    def get_best_answer(self):
+
+    def _iter_nodes(self):
         to_visit = deque([self.root])
-        best_node = self.root
-
         while to_visit:
             current_node = to_visit.popleft()
-            if current_node.Q > best_node.Q:
-                best_node = current_node
+            yield current_node
             to_visit.extend(current_node.children)
-
+
+    def get_best_answer(self):
+        best_node = max(self._iter_nodes(), key=lambda node: node.Q, default=self.root)
         return best_node.answer
 
     def self_refine(self, node: MCTSNode):

From 308960534f3a5c67ba2abbc4a325e574610eee08 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 8 Nov 2024 12:38:37 +0000
Subject: [PATCH 02/18] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../ColossalChat/coati/reasoner/guided_search/mcts.py        | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
index 438d0ed5fdb6..100335ba96a4 100644
--- a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
+++ b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -120,14 +120,14 @@ def simulate(self):
             self.back_propagation(child)
 
         return self.get_best_answer()
-
+
     def _iter_nodes(self):
         to_visit = deque([self.root])
         while to_visit:
             current_node = to_visit.popleft()
             yield current_node
             to_visit.extend(current_node.children)
-
+
     def get_best_answer(self):
         best_node = max(self._iter_nodes(), key=lambda node: node.Q, default=self.root)
         return best_node.answer

From ed817a29f9a182cd5e346bc801a44fccf83a2133 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Fri, 8 Nov 2024 12:42:48 +0000
Subject: [PATCH 03/18] update

---
 applications/ColossalChat/coati/reasoner/guided_search/mcts.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
index 438d0ed5fdb6..64503630446e 100644
--- a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
+++ b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -1,10 +1,9 @@
 """
 Implementation of MCTS + Self-refine algorithm.
 
-Reference:
+Structure is adapted from https://github.com/BrendanGraham14/mcts-llm/ with the following references:
 1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
 Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
-2. https://github.com/BrendanGraham14/mcts-llm/
 3. https://github.com/trotsky1997/MathBlackBox/
 4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
 """
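Note on the refactor in PATCH 01-03: `_iter_nodes` turns tree traversal into a lazy breadth-first generator, and `get_best_answer` reduces over it with `max`, so no intermediate list of nodes is built. A minimal standalone sketch of the same pattern (the node class here is a stand-in for `MCTSNode`, with only the attributes the patch actually touches):

    from collections import deque
    from dataclasses import dataclass, field


    @dataclass
    class ToyNode:  # stand-in for MCTSNode: Q is the node value, answer its text
        Q: float
        answer: str
        children: list = field(default_factory=list)


    def iter_nodes(root):
        to_visit = deque([root])
        while to_visit:  # breadth-first walk over the whole tree
            node = to_visit.popleft()
            yield node
            to_visit.extend(node.children)


    root = ToyNode(0.1, "draft", [ToyNode(0.9, "refined"), ToyNode(0.4, "alternative")])
    best = max(iter_nodes(root), key=lambda n: n.Q, default=root)
    assert best.answer == "refined"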
From 1210dbea975019f51e49a24c7cfc8e3507eadb0a Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Mon, 11 Nov 2024 07:26:32 +0000
Subject: [PATCH 04/18] update tokenization function

---
 .../coati/dataset/tokenization_utils.py      | 46 ++++++++++++++++++-
 1 file changed, 45 insertions(+), 1 deletion(-)

diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 020432b9ec3c..533d0acadae9 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -1,13 +1,14 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-tokenization utils for constructing dataset for ppo, dpo, sft, rm
+Tokenization Utils for Constructing Dataset for RL.
 """
 
 import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Union
 
+import torch
 from coati.dataset.conversation import Conversation
 from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
 from datasets import dataset_dict
@@ -393,3 +394,46 @@ def tokenize_kto(
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
     }
+
+
+def tokenize_process_reward(
+    data_point: Dict[str, str],
+    tokenizer: PreTrainedTokenizer,
+    conversation_template: Conversation = None,
+    max_length: int = 4096,
+) -> Dict[str, Union[int, str, List[int]]]:
+    """
+    Tokenize function designed for tokenizing Math-Shepherd dataset.
+
+    The datapoint has the following format:
+        {
+            "input": problem + step-by-step solution,
+            "label": problem + step-by-step solution with automatic label,
+            "task": GSM8K or MATH
+        }
+
+    """
+    input = data_point["input"]
+    label = data_point["label"]
+
+    template = deepcopy(conversation_template)
+    template.append_message("user", input)
+    template.append_message("assistant", label)
+    prompt = template.get_prompt(add_generation_prompt=True)
+    reward_signal_id = tokenizer.convert_tokens_to_ids(template.reward_signal)
+    tokenized = tokenizer(prompt, add_special_tokens=False)["input_ids"]
+
+    tokenized_tensor = torch.tensor(tokenized)
+    loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))
+
+    label = (tokenized_tensor * loss_mask).tolist()
+    decoded_input = tokenizer.decode(tokenized, skip_special_tokens=False)
+    decoded_label = tokenizer.decode(label, skip_special_tokens=False)
+
+    return {
+        "input": tokenized,
+        "label": label,
+        "loss_mask": loss_mask,
+        "decoded_input": decoded_input,
+        "decoded_label": decoded_label,
+    }
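To make the masking logic in the new `tokenize_process_reward` concrete, here is a hedged toy walkthrough of its core tensor steps; the token ids and reward-signal ids below are invented for illustration, only the `torch.isin`/multiply pattern comes from the patch:

    import torch

    reward_signal_id = [32001, 32002]  # hypothetical ids for "+" and "-"
    tokenized = [5, 17, 32001, 9, 32002]  # hypothetical prompt token ids

    tokenized_tensor = torch.tensor(tokenized)
    # True exactly at positions holding a reward-signal token.
    loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))
    # Multiplying zeroes out every non-signal position in the label sequence.
    label = (tokenized_tensor * loss_mask).tolist()
    assert label == [0, 0, 32001, 0, 32002]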
From 73ebbef3a36ac01ea0086b0fef1a8e59e94c8f5f Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Mon, 11 Nov 2024 10:17:31 +0000
Subject: [PATCH 05/18] add prm dataset example

---
 .../ColossalChat/conversation_template/tiny-llama.json | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/applications/ColossalChat/conversation_template/tiny-llama.json b/applications/ColossalChat/conversation_template/tiny-llama.json
index 59196159f930..49a1ec2e3d2f 100644
--- a/applications/ColossalChat/conversation_template/tiny-llama.json
+++ b/applications/ColossalChat/conversation_template/tiny-llama.json
@@ -4,5 +4,10 @@
     "stop_ids": [
         2
     ],
-    "end_of_assistant": "</s>"
+    "end_of_assistant": "</s>",
+    "step_score_signal": "ки",
+    "reward_signal": [
+        "+",
+        "-"
+    ]
 }
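For context, a Math-Shepherd-style datapoint that these new template fields target looks roughly as follows; the values are paraphrased for illustration, with `ки` marking the end of each step in `input` and `+`/`-` as the automatic per-step labels in `label`:

    data_point = {
        "input": "Janet has 3 apples... Step 1: ... ки Step 2: ... ки",
        "label": "Janet has 3 apples... Step 1: ... + Step 2: ... -",
        "task": "GSM8K",
    }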
From 794e0d4f4aa2d8b1dd33b830227195447d1d87a4 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Mon, 11 Nov 2024 11:38:14 +0000
Subject: [PATCH 06/18] update conversation

---
 .../ColossalChat/coati/dataset/__init__.py    |  3 +-
 .../coati/dataset/conversation.py             | 58 +++++++++++--------
 .../prepare_dataset.py                        | 57 +++----------------
 3 files changed, 45 insertions(+), 73 deletions(-)

diff --git a/applications/ColossalChat/coati/dataset/__init__.py b/applications/ColossalChat/coati/dataset/__init__.py
index 8e9060a1a1f9..78bd463591ea 100755
--- a/applications/ColossalChat/coati/dataset/__init__.py
+++ b/applications/ColossalChat/coati/dataset/__init__.py
@@ -7,7 +7,7 @@
     StatefulDistributedSampler,
     load_tokenized_dataset,
 )
-from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
+from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft, tokenize_process_reward
 
 __all__ = [
     "tokenize_prompt",
@@ -23,4 +23,5 @@
     "tokenize_kto",
     "setup_conversation_template",
     "Conversation",
+    "tokenize_process_reward"
 ]

diff --git a/applications/ColossalChat/coati/dataset/conversation.py b/applications/ColossalChat/coati/dataset/conversation.py
index a77c220d34af..f66deb885254 100755
--- a/applications/ColossalChat/coati/dataset/conversation.py
+++ b/applications/ColossalChat/coati/dataset/conversation.py
@@ -1,6 +1,6 @@
-import dataclasses
 import json
 import os
+from dataclasses import dataclass, field
 from typing import Any, Dict, List
 
 import torch.distributed as dist
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
 from colossalai.logging import get_dist_logger
@@ -11,14 +11,17 @@
 logger = get_dist_logger()
 
 
-@dataclasses.dataclass
+@dataclass
 class Conversation:
     tokenizer: PreTrainedTokenizer
     system_message: str
     chat_template: str
     stop_ids: List[int]
     end_of_assistant: str
-    roles = ["user", "assistant"]
+    messages: List[Dict[str, str]] = field(default_factory=list)
+    roles: List[str] = field(default_factory=lambda: ["user", "assistant"])
+    step_score_signal: str = None
+    reward_signal: List[str] = None
 
     @classmethod
     def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
@@ -26,18 +29,26 @@ def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
         Setup the conversation template from config
         """
         tokenizer.chat_template = config["chat_template"]
-        conv = cls(
-            tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"]
-        )
-        conv.clear()
-        return conv
+        conversation = cls(tokenizer, **config)
+
+        special_tokens = []
+        if conversation.step_score_signal is not None:
+            special_tokens.extend(conversation.step_score_signal)
+
+        if conversation.reward_signal is not None:
+            special_tokens.extend(conversation.reward_signal)
+
+        if special_tokens:
+            conversation.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
+
+        return conversation
 
     def clear(self):
         self.messages = []
 
     @classmethod
     def get_conversation_template_keys(cls):
-        return ["system_message", "chat_template"]
+        return ["system_message", "chat_template", "end_of_assistant"]
 
     def __str__(self):
         return json.dumps(
@@ -46,12 +57,12 @@ def __str__(self):
             indent=4,
         )
 
-    def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any:
+    def get_prompt(self, num_messages: int = None, add_generation_prompt=False) -> Any:
         """
         Retrieves the prompt for the conversation.
 
         Args:
-            length (int, optional): The number of messages to include in the prompt. Defaults to None.
+            num_messages (int, optional): The number of messages to include in the prompt. Defaults to None.
             get_seps_info (bool, optional): Whether to include separator information in the output. Defaults to False.
             add_generation_prompt (bool, optional): Whether to add the assistant line start token in generation (for generation only). Defaults to False.
 
@@ -59,22 +70,19 @@ def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any:
             str or tuple: The prompt string if get_seps_info is False, otherwise a tuple containing the prompt
                 string and separator information.
         """
-        if length is None:
-            length = len(self.messages)
+        if num_messages is None:
+            num_messages = len(self.messages)
 
-        assert length <= len(self.messages)
+        assert num_messages <= len(self.messages)
         if self.system_message is not None:
-            messages = [{"role": "system", "content": self.system_message}] + self.messages[:length]
+            messages = [{"role": "system", "content": self.system_message}] + self.messages[:num_messages]
         else:
-            messages = self.messages[:length]
+            messages = self.messages[:num_messages]
         prompt = self.tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=add_generation_prompt
         )
         return prompt
 
-    def save_prompt(self):
-        return self.get_prompt()
-
     def append_message(self, role: str, message: str):
         """
         Append a message to the conversation.
@@ -141,9 +149,11 @@ def setup_conversation_template(
                 pass
         except ValueError as e:
             raise ValueError(e)
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        os.makedirs(os.path.dirname(save_path), exist_ok=True)
-        with open(save_path, "w", encoding="utf8") as f:
-            logger.info(f"Successfully generated a conversation template config, save to {save_path}.")
-            json.dump(chat_template_config, f, indent=4, ensure_ascii=False)
+
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    with open(save_path, "w", encoding="utf8") as f:
+        logger.info(f"Successfully generated a conversation template config, save to {save_path}.")
+        json.dump(chat_template_config, f, indent=4, ensure_ascii=False)
+        f.write("\n")
+
     return Conversation.from_config(tokenizer, chat_template_config)

diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
index b551497b9cce..ede6fa5314e5 100644
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
@@ -1,35 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Prepare dataset scripts
-
-Usage:
-- For SFT dataset preparation (SFT)
-python prepare_dataset.py --type sft \
-    --data_input_dirs /PATH/TO/SFT/DATASET \
-    --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
-    --tokenizer_dir  "" \
-    --data_cache_dir $SAVE_DIR/cache \
-    --data_jsonl_output_dir $SAVE_DIR/jsonl \
-    --data_arrow_output_dir $SAVE_DIR/arrow \
-
-- For prompt dataset preparation (PPO)
-python prepare_dataset.py --type prompt \
-    --data_input_dirs /PATH/TO/SFT/DATASET \
-    --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
-    --tokenizer_dir  "" \
-    --data_cache_dir $SAVE_DIR/cache \
-    --data_jsonl_output_dir $SAVE_DIR/jsonl \
-    --data_arrow_output_dir $SAVE_DIR/arrow \
-
-- For Preference dataset preparation (DPO and Reward model training)
-python prepare_dataset.py --type preference \
-    --data_input_dirs /PATH/TO/SFT/DATASET \
-    --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
-    --tokenizer_dir  "" \
-    --data_cache_dir $SAVE_DIR/cache \
-    --data_jsonl_output_dir $SAVE_DIR/jsonl \
-    --data_arrow_output_dir $SAVE_DIR/arrow \
+Prepare Dataset for RL Algorithm.
 """
 
 import argparse
@@ -40,7 +12,7 @@
 import time
 from multiprocessing import cpu_count
 
-from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
+from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft, tokenize_process_reward
 from datasets import dataset_dict, load_dataset
 from transformers import AutoTokenizer
 
@@ -56,7 +28,7 @@ def main():
         type=str,
         required=True,
         default=None,
-        choices=["sft", "prompt", "preference", "kto"],
+        choices=["sft", "prompt", "preference", "kto", 'prm'],
         help="Type of dataset, choose from 'sft', 'prompt', 'preference'. 'kto'",
     )
     parser.add_argument(
@@ -205,8 +177,10 @@ def main():
         preparation_function = tokenize_rlhf
     elif args.type == "kto":
         preparation_function = tokenize_kto
+    elif args.type == "prm":
+        preparation_function = tokenize_process_reward
     else:
-        raise ValueError("Unknown dataset type. Please choose one from ['sft', 'prompt', 'preference']")
+        raise ValueError("Unknown dataset type. Please choose one from ['sft', 'prompt', 'preference', 'kto', 'prm']")
 
     for index, dataset in enumerate(list_dataset):
         assert isinstance(dataset, dataset_dict.Dataset)
@@ -218,6 +192,7 @@ def main():
             dataset = dataset.select(
                 random.sample(range(len(dataset)), min(args.num_samples_per_datafile, len(dataset)))
             )
+
         logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
         dataset = dataset.map(
             function=preparation_function,
@@ -229,13 +204,6 @@ def main():
             keep_in_memory=False,
             num_proc=min(len(dataset), cpu_count()),
         )
-        if args.type == "kto":
-            filter_by = "completion"
-        elif args.type == "preference":
-            filter_by = "chosen_input_ids"
-        else:
-            filter_by = "input_ids"
-        dataset = dataset.filter(lambda data: data[filter_by] is not None)
 
         # Save each jsonl spliced dataset.
         output_index = "0" * (5 - len(str(index))) + str(index)
@@ -249,22 +217,15 @@ def main():
                 logger.info(f"processing {count} spliced data points for {fp_writer.name}")
             count += 1
             fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
+
         logger.info(
             f"Current file {fp_writer.name}; "
             f"Data size: {len(dataset)}; "
            f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
         )
+
         # Save each arrow spliced dataset
         output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
-        logger.info(f"Start to save {output_arrow_path}")
-        dataset = load_dataset(
-            path="json",
-            data_files=[output_jsonl_path],
-            cache_dir=os.path.join(args.data_cache_dir, "tokenized"),
-            keep_in_memory=False,
-            num_proc=cpu_count(),
-            split="train",
-        )
         dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(dataset), cpu_count()))
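A short sketch of how the reworked `from_config` is exercised end to end, assuming the tiny-llama template above; the tokenizer name and file path are placeholders:

    import json

    from coati.dataset.conversation import Conversation
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder
    with open("conversation_template/tiny-llama.json", encoding="utf8") as f:  # placeholder path
        config = json.load(f)

    conversation = Conversation.from_config(tokenizer, config)
    # step_score_signal and reward_signal were registered as additional special
    # tokens, so each reward signal now maps to a single token id.
    print(tokenizer.convert_tokens_to_ids(conversation.reward_signal))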
""" import argparse @@ -40,7 +12,7 @@ import time from multiprocessing import cpu_count -from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft +from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft, tokenize_process_reward from datasets import dataset_dict, load_dataset from transformers import AutoTokenizer @@ -56,7 +28,7 @@ def main(): type=str, required=True, default=None, - choices=["sft", "prompt", "preference", "kto"], + choices=["sft", "prompt", "preference", "kto", 'prm'], help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'", ) parser.add_argument( @@ -205,8 +177,10 @@ def main(): preparation_function = tokenize_rlhf elif args.type == "kto": preparation_function = tokenize_kto + elif args.type == "prm": + preparation_function = tokenize_process_reward else: - raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']") + raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference', 'kto', 'prm']") for index, dataset in enumerate(list_dataset): assert isinstance(dataset, dataset_dict.Dataset) @@ -218,6 +192,7 @@ def main(): dataset = dataset.select( random.sample(range(len(dataset)), min(args.num_samples_per_datafile, len(dataset))) ) + logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.") dataset = dataset.map( function=preparation_function, @@ -229,13 +204,6 @@ def main(): keep_in_memory=False, num_proc=min(len(dataset), cpu_count()), ) - if args.type == "kto": - filter_by = "completion" - elif args.type == "preference": - filter_by = "chosen_input_ids" - else: - filter_by = "input_ids" - dataset = dataset.filter(lambda data: data[filter_by] is not None) # Save each jsonl spliced dataset. output_index = "0" * (5 - len(str(index))) + str(index) @@ -249,22 +217,15 @@ def main(): logger.info(f"processing {count} spliced data points for {fp_writer.name}") count += 1 fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n") + logger.info( f"Current file {fp_writer.name}; " f"Data size: {len(dataset)}; " f"Time cost: {round((time.time() - st) / 60, 6)} minutes." 
From b6ec337f3dbede0f90a27cf354d618bf9ae23e64 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 14 Nov 2024 08:28:02 +0000
Subject: [PATCH 08/18] update tokenize function

---
 applications/ColossalChat/coati/dataset/tokenization_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 533d0acadae9..0548a1454121 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -431,8 +431,8 @@ def tokenize_process_reward(
     decoded_label = tokenizer.decode(label, skip_special_tokens=False)
 
     return {
-        "input": tokenized,
-        "label": label,
+        "input_ids": tokenized,
+        "labels": label,
         "loss_mask": loss_mask,
         "decoded_input": decoded_input,
         "decoded_label": decoded_label,

From c606d1101c5df0dc1c5d3892e051aa80f50a1935 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 14 Nov 2024 08:31:33 +0000
Subject: [PATCH 09/18] add tokenize func

---
 applications/ColossalChat/coati/dataset/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/applications/ColossalChat/coati/dataset/__init__.py b/applications/ColossalChat/coati/dataset/__init__.py
index 78bd463591ea..f1a56b7bb934 100755
--- a/applications/ColossalChat/coati/dataset/__init__.py
+++ b/applications/ColossalChat/coati/dataset/__init__.py
@@ -23,5 +23,5 @@
     "tokenize_kto",
     "setup_conversation_template",
     "Conversation",
-    "tokenize_process_reward"
+    "tokenize_process_reward",
 ]
From 9995119c28650a80fade1a9481119e7cd41735a5 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 14 Nov 2024 08:48:08 +0000
Subject: [PATCH 10/18] update init

---
 applications/ColossalChat/coati/models/__init__.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/models/__init__.py b/applications/ColossalChat/coati/models/__init__.py
index fba0949e3fb8..7a78ede58bcd 100755
--- a/applications/ColossalChat/coati/models/__init__.py
+++ b/applications/ColossalChat/coati/models/__init__.py
@@ -2,7 +2,7 @@
 from .critic import Critic
 from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
 from .lora import LoraConfig, convert_to_lora_module, lora_manager
-from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
+from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss, PRMLoss
 from .reward_model import RewardModel
 from .utils import disable_dropout
 
@@ -18,9 +18,11 @@
     "lora_manager",
     "convert_to_lora_module",
     "DpoLoss",
-    "KTOLoss" "generate",
+    "KTOLoss",
+    "generate",
     "generate_streaming",
     "disable_dropout",
     "update_model_kwargs_fn",
     "prepare_inputs_fn",
+    "PRMLoss"
 ]

From 38a7f3846dc8fbd1ceb4edcdb50f04556db0b686 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 14 Nov 2024 08:49:36 +0000
Subject: [PATCH 11/18] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 applications/ColossalChat/coati/models/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/models/__init__.py b/applications/ColossalChat/coati/models/__init__.py
index 7a78ede58bcd..b74475e68ee4 100755
--- a/applications/ColossalChat/coati/models/__init__.py
+++ b/applications/ColossalChat/coati/models/__init__.py
@@ -2,7 +2,7 @@
 from .critic import Critic
 from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
 from .lora import LoraConfig, convert_to_lora_module, lora_manager
-from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss, PRMLoss
+from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, PRMLoss, ValueLoss
 from .reward_model import RewardModel
 from .utils import disable_dropout
 
@@ -24,5 +24,5 @@
     "disable_dropout",
     "update_model_kwargs_fn",
     "prepare_inputs_fn",
-    "PRMLoss"
+    "PRMLoss",
 ]
From 797a81a8e20d2bae820fa57adf5363bcd8f0b483 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 14 Nov 2024 08:53:26 +0000
Subject: [PATCH 12/18] add loss

---
 .../ColossalChat/coati/models/loss.py        | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index 927dfd5a89b6..cda5b033a462 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -280,3 +280,24 @@ def forward(
         losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
 
         return losses, chosen_rewards, rejected_rewards, kl
+
+
+class PRMLoss(nn.Module):
+    def __init__(self, reward_signal_id: Optional[list[int]] = None):
+        super().__init__()
+        self.IGNORE_INDEX = -100
+        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
+        self.reward_signal_id = reward_signal_id
+
+    def forward(self, labels: torch.Tensor, logits: torch.Tensor):
+        loss_mask = torch.isin(labels, torch.tensor(self.reward_signal_id).to(labels.device))
+
+        logits = logits[loss_mask]
+        labels = labels[loss_mask]
+        logits = logits[..., self.reward_signal_id]
+
+        label_mapping = {token: i for i, token in enumerate(self.reward_signal_id)}
+        labels = torch.tensor([label_mapping.get(label.item(), label.item()) for label in labels], device=labels.device)
+        loss = self.loss(logits, labels)
+
+        return loss
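What `PRMLoss` computes, spelled out on a hedged toy example (the ids and logits are invented; the selection and remapping steps mirror the class): keep only the positions holding a reward-signal token, restrict the vocabulary axis to the two signal columns, remap token ids to class indices, then apply cross-entropy.

    import torch
    import torch.nn.functional as F

    reward_signal_id = [32001, 32002]  # hypothetical ids for "+" and "-"
    labels = torch.tensor([0, 0, 32001, 0, 32002])  # as produced by tokenize_process_reward
    logits = torch.randn(5, 32100)  # (sequence_length, vocab_size)

    loss_mask = torch.isin(labels, torch.tensor(reward_signal_id))
    picked = logits[loss_mask][..., reward_signal_id]  # shape (2, 2): one row per labeled step
    targets = torch.tensor([0, 1])  # 32001 -> class 0 ("+"), 32002 -> class 1 ("-")
    loss = F.cross_entropy(picked, targets)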
From f2f5ff5e24434f8d40b44e0ac55570d3fd8f26d2 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 14 Nov 2024 08:56:24 +0000
Subject: [PATCH 13/18] update init

---
 applications/ColossalChat/coati/trainer/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/applications/ColossalChat/coati/trainer/__init__.py b/applications/ColossalChat/coati/trainer/__init__.py
index 6d0900153e8a..0c85f7d3d73e 100755
--- a/applications/ColossalChat/coati/trainer/__init__.py
+++ b/applications/ColossalChat/coati/trainer/__init__.py
@@ -5,6 +5,7 @@
 from .ppo import PPOTrainer
 from .rm import RewardModelTrainer
 from .sft import SFTTrainer
+from .prm import ProcessRewardModelTrainer
 
 __all__ = [
     "SLTrainer",
@@ -15,4 +16,5 @@
     "DPOTrainer",
     "ORPOTrainer",
     "KTOTrainer",
+    "ProcessRewardModelTrainer"
 ]

From 852333423ddbca1d60fadcf40fe59970fe2b14d1 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 14 Nov 2024 08:58:43 +0000
Subject: [PATCH 14/18] update prm

---
 .../examples/training_scripts/train_prm.py   | 288 ++++++++++++++++++
 1 file changed, 288 insertions(+)
 create mode 100644 applications/ColossalChat/examples/training_scripts/train_prm.py

diff --git a/applications/ColossalChat/examples/training_scripts/train_prm.py b/applications/ColossalChat/examples/training_scripts/train_prm.py
new file mode 100644
index 000000000000..f77539c041db
--- /dev/null
+++ b/applications/ColossalChat/examples/training_scripts/train_prm.py
@@ -0,0 +1,288 @@
+"""
+Train Process Reward Model.
+"""
+
+import argparse
+import json
+import math
+import os
+import resource
+from contextlib import nullcontext
+
+import torch
+from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
+from coati.trainer import ProcessRewardModelTrainer
+from coati.utils import load_checkpoint
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+
+
+def load_dataset(args, plugin, tokenizer):
+    dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
+
+    train_dataloader = plugin.prepare_dataloader(
+        dataset=dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        drop_last=True,
+        collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
+    )
+
+    eval_dataloader = None
+    if args.eval_dataset:
+        eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
+        eval_data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
+
+        eval_dataloader = plugin.prepare_dataloader(
+            dataset=eval_dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            drop_last=True,
+            collate_fn=eval_data_collator,
+            distributed_sampler_cls=StatefulDistributedSampler,
+        )
+
+    return train_dataloader, eval_dataloader
+
+
+def initialize_plugin(args):
+    if args.plugin == "ddp":
+        """
+        Default torch ddp plugin without any acceleration, for debugging purposes.
+        """
+        plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
+    elif args.plugin == "gemini":
+        plugin = GeminiPlugin(
+            precision=args.mixed_precision,
+            placement_policy="static",
+            initial_scale=2**16,
+            max_norm=args.grad_clip,
+            enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
+            enable_flash_attention=args.use_flash_attn,
+        )
+    elif args.plugin == "gemini_auto":
+        plugin = GeminiPlugin(
+            precision=args.mixed_precision,
+            placement_policy="auto",
+            initial_scale=2**16,
+            max_norm=args.grad_clip,
+            enable_flash_attention=args.use_flash_attn,
+        )
+    elif args.plugin == "zero2":
+        plugin = LowLevelZeroPlugin(
+            stage=2,
+            precision=args.mixed_precision,
+            initial_scale=2**16,
+            max_norm=args.grad_clip,
+        )
+    elif args.plugin == "zero2_cpu":
+        plugin = LowLevelZeroPlugin(
+            stage=2,
+            precision=args.mixed_precision,
+            initial_scale=2**16,
+            cpu_offload=True,
+            max_norm=args.grad_clip,
+        )
+    elif args.plugin == "3d":
+        plugin = HybridParallelPlugin(
+            tp_size=args.tp,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=args.enable_sequence_parallelism,
+            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
+            parallel_output=False,
+            max_norm=args.grad_clip,
+            precision=args.mixed_precision,
+            microbatch_size=args.microbatch_size,
+        )
+    else:
+        raise ValueError(f"Unknown plugin {args.plugin}")
+
+    return plugin
+
+
+def train(args):
+
+    colossalai.launch_from_torch()
+    coordinator = DistCoordinator()
+
+    init_ctx = nullcontext()
+    with init_ctx:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
+        )
+
+    plugin = initialize_plugin(args=args)
+    booster = Booster(plugin=plugin)
+    optimizer = HybridAdam(
+        model_params=model.parameters(),
+        lr=args.lr,
+        betas=(0.9, 0.95),
+        weight_decay=args.weight_decay,
+        adamw_mode=True,
+    )
+
+    model.train()
+    if args.grad_checkpoint:
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+
+    ##
+    coordinator.print_on_master(
+        f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+    )
+    train_dataloader, eval_dataloader = load_dataset(args=args, plugin=plugin, tokenizer=tokenizer)
+
+    coordinator.print_on_master(
+        f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+    )
+
+    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
+    math.ceil(args.max_epochs * num_update_steps_per_epoch)
+
+    if args.warmup_steps is None:
+        args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
+
+    lr_scheduler = CosineAnnealingWarmupLR(
+        optimizer=optimizer,
+        total_steps=args.max_epochs * num_update_steps_per_epoch,
+        warmup_steps=args.warmup_steps,
+        eta_min=0.1 * args.lr,
+    )
+
+    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+    torch.set_default_dtype(default_dtype)
+    model, optimizer, _, _, _ = booster.boost(model=model, optimizer=optimizer)
+    torch.set_default_dtype(torch.float)
+
+    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+    coordinator.print_on_master(
+        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+    )
+
+    start_epoch = 0
+    sampler_start_idx = 0
+    start_step = 0
+    if args.checkpoint_path is not None:
+        if "modeling" in args.checkpoint_path:
+            coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
+            booster.load_model(model, args.checkpoint_path)
+        else:
+            coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
+            start_epoch, start_step, sampler_start_idx = load_checkpoint(
+                load_dir=args.checkpoint_path,
+                booster=booster,
+                model=model,
+                optimizer=optimizer,
+                lr_scheduler=lr_scheduler,
+            )
+            train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+            coordinator.print_on_master(
+                f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
+            )
+            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+        coordinator.print_on_master(
+            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+        )
+        coordinator.print_on_master(
+            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+        )
+        coordinator.print_on_master(
+            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+        )
+
+    trainer = ProcessRewardModelTrainer(
+        model=model,
+        booster=booster,
+        optimizer=optimizer,
+        plugin=plugin,
+        lr_scheduler=lr_scheduler,
+        tokenizer=tokenizer,
+        max_epochs=args.max_epochs,
+        accumulation_steps=args.accumulation_steps,
+        start_epoch=start_epoch,
+        save_interval=args.save_interval,
+        save_dir=args.save_path,
+        coordinator=coordinator,
+        reward_signal_ids=tokenizer.convert_tokens_to_ids(args.reward_signal),
+    )
+
+    trainer.fit(
+        train_dataloader=train_dataloader,
+        eval_dataloader=eval_dataloader,
+        log_dir=args.log_dir,
+        use_wandb=args.use_wandb,
+    )
+
+
+if __name__ == "__main__":
+    # ==============================
+    # Parse Arguments
+    # ==============================
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--plugin",
+        type=str,
+        default="gemini",
+        choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
+        help="Choose which plugin to use",
+    )
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+    parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
+    parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
+    parser.add_argument("--pretrain", type=str, default=None)
+    parser.add_argument("--tokenizer_dir", type=str, default=None)
+    parser.add_argument("--dataset", nargs="+", default=[])
+    parser.add_argument("--eval_dataset", nargs="+", default=[])
+    parser.add_argument(
+        "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training from a checkpoint"
+    )
+    parser.add_argument("--save_path", type=str, default=None)
+    parser.add_argument("--max_epochs", type=int, default=3)
+    parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument("--max_length", type=int, default=512)
+    parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
+    parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
+    parser.add_argument("--lr", type=float, default=5e-6)
+    parser.add_argument("--config_file", type=str, default=None, help="Config file")
+    parser.add_argument("--accumulation_steps", type=int, default=8)
+    parser.add_argument("--log_dir", default=None, type=str)
+    parser.add_argument("--use_wandb", default=False, action="store_true")
+    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+    parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument("--microbatch_size", type=int, default=1)
+    parser.add_argument("--reward_signal", nargs="+", default=["+", "-"])
+    args = parser.parse_args()
+    if args.config_file is not None:
+        os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
+        with open(args.config_file, "w") as f:
+            json.dump(args.__dict__, f, indent=4)
+    train(args)

From ab992b89e4a1e252d828867bf80fdaa208492efa Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 14 Nov 2024 08:59:30 +0000
Subject: [PATCH 15/18] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 applications/ColossalChat/coati/trainer/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/trainer/__init__.py b/applications/ColossalChat/coati/trainer/__init__.py
index 0c85f7d3d73e..2fc47e17de47 100755
--- a/applications/ColossalChat/coati/trainer/__init__.py
+++ b/applications/ColossalChat/coati/trainer/__init__.py
@@ -3,9 +3,9 @@
 from .kto import KTOTrainer
 from .orpo import ORPOTrainer
 from .ppo import PPOTrainer
+from .prm import ProcessRewardModelTrainer
 from .rm import RewardModelTrainer
 from .sft import SFTTrainer
-from .prm import ProcessRewardModelTrainer
 
 __all__ = [
     "SLTrainer",
@@ -16,5 +16,5 @@
     "DPOTrainer",
     "ORPOTrainer",
     "KTOTrainer",
-    "ProcessRewardModelTrainer"
+    "ProcessRewardModelTrainer",
 ]
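The scheduler wiring in `train_prm.py` is easy to misread, so here is its arithmetic with hedged example numbers (the dataset size and flags are invented):

    # Hypothetical run: 8000 batches per epoch, accumulation_steps=8, max_epochs=3.
    len_train_dataloader, accumulation_steps, max_epochs = 8000, 8, 3

    num_update_steps_per_epoch = len_train_dataloader // accumulation_steps  # 1000
    total_steps = max_epochs * num_update_steps_per_epoch  # 3000, fed to CosineAnnealingWarmupLR

    # Default warmup when --warmup_steps is omitted: 2.5% of all update steps.
    warmup_steps = int(max_epochs * 0.025 * (len_train_dataloader // accumulation_steps))
    assert (num_update_steps_per_epoch, total_steps, warmup_steps) == (1000, 3000, 75)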
From a8b4afb747671fa6e71c3a2e80f9d727df66a7c5 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 14 Nov 2024 09:06:59 +0000
Subject: [PATCH 16/18] add prm

---
 .../ColossalChat/coati/trainer/prm.py        | 134 ++++++++++++++++++
 1 file changed, 134 insertions(+)
 create mode 100644 applications/ColossalChat/coati/trainer/prm.py

diff --git a/applications/ColossalChat/coati/trainer/prm.py b/applications/ColossalChat/coati/trainer/prm.py
new file mode 100644
index 000000000000..73fdb6c5e603
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/prm.py
@@ -0,0 +1,134 @@
+"""
+Trainer for Process Reward Model.
+"""
+
+import os
+import time
+from typing import Any, Callable, List, Optional
+
+import torch
+import tqdm
+from coati.models import PRMLoss
+from coati.trainer.utils import all_reduce_mean
+from coati.utils import AccumulativeMeanMeter, save_checkpoint
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerBase
+
+from colossalai.booster import Booster, Plugin
+from colossalai.cluster import DistCoordinator
+from colossalai.utils import get_current_device
+
+from .base import SLTrainer
+from .utils import is_rank_0, to_device
+
+
+class ProcessRewardModelTrainer(SLTrainer):
+    """
+    Trainer for Process Reward Model.
+    """
+
+    def __init__(
+        self,
+        model: Any,
+        booster: Booster,
+        optimizer: Optimizer,
+        plugin: Plugin,
+        lr_scheduler: _LRScheduler,
+        tokenizer: PreTrainedTokenizerBase,
+        loss_fn: Optional[Callable] = None,
+        max_epochs: int = 1,
+        accumulation_steps: int = 1,
+        start_epoch: int = 0,
+        save_interval: int = 0,
+        save_dir: str = None,
+        coordinator: DistCoordinator = None,
+        reward_signal_ids: List[int] = [],
+    ) -> None:
+        super().__init__(
+            booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch
+        )
+        self.lr_scheduler = lr_scheduler
+        self.tokenizer = tokenizer
+        self.reward_signal_ids = reward_signal_ids
+        self.loss_fn = loss_fn if loss_fn is not None else PRMLoss(self.reward_signal_ids)
+        self.save_interval = save_interval
+        self.coordinator = coordinator
+        self.save_dir = save_dir
+        self.num_train_step = 0
+        self.accumulation_steps = accumulation_steps
+        self.device = get_current_device()
+        self.accumulative_meter = AccumulativeMeanMeter()
+
+    def _before_fit(
+        self,
+        train_dataloader: DataLoader = None,
+        eval_dataloader: DataLoader = None,
+        log_dir: Optional[str] = None,
+        use_wandb: bool = False,
+    ):
+        self.train_dataloader = train_dataloader
+        self.eval_dataloader = eval_dataloader
+        self.writer = None
+        if log_dir is not None and is_rank_0():
+            from torch.utils.tensorboard import SummaryWriter
+
+            log_dir = os.path.join(log_dir, "PRM", time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            self.writer = SummaryWriter(log_dir=log_dir)
+
+        if use_wandb:
+            import wandb
+
+            self.wandb_run = wandb.init(project="Coati-PRM", sync_tensorboard=True)
+
+    def _train(self, epoch):
+        self.model.train()
+        step_bar = tqdm.trange(
+            len(self.train_dataloader) // self.accumulation_steps,
+            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+            disable=not is_rank_0(),
+        )
+        for i, batch in enumerate(self.train_dataloader):
+            batch = to_device(batch, self.device)
+            batch_size = batch["input_ids"].size(0)
+            logits = self.model(batch["input_ids"])["logits"]
+            loss = self.loss_fn(batch["labels"], logits)
+            self.booster.backward(loss=loss, optimizer=self.optimizer)
+            loss_mean = all_reduce_mean(tensor=loss)
+            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+            if (i + 1) % self.accumulation_steps == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                self.lr_scheduler.step()
+                step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
+                if self.writer:
+                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                self.num_train_step += 1
+                step_bar.update()
+
+                # Save checkpoint
+                if (
+                    self.save_dir is not None
+                    and self.save_interval is not None
+                    and (self.num_train_step + 1) % self.save_interval == 0
+                ):
+                    save_checkpoint(
+                        save_dir=self.save_dir,
+                        booster=self.booster,
+                        model=self.model,
+                        optimizer=self.optimizer,
+                        lr_scheduler=self.scheduler,
+                        epoch=epoch,
+                        step=self.num_train_step + 1,
+                        batch_size=batch_size,
+                        coordinator=self.coordinator,
+                    )
+                    self.coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+                    )
+
+    def _eval(epoch: int):
+        # TODO
+        pass

From 375e356a1632bb242efd3dd51bcdcb57de0ea293 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Fri, 15 Nov 2024 03:18:41 +0000
Subject: [PATCH 17/18] update prm

---
 .../ColossalChat/coati/trainer/prm.py        | 33 ++++++++++++++++---
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/applications/ColossalChat/coati/trainer/prm.py b/applications/ColossalChat/coati/trainer/prm.py
index 73fdb6c5e603..3c916c3083d0 100644
--- a/applications/ColossalChat/coati/trainer/prm.py
+++ b/applications/ColossalChat/coati/trainer/prm.py
@@ -82,7 +82,7 @@ def _before_fit(
 
             self.wandb_run = wandb.init(project="Coati-PRM", sync_tensorboard=True)
 
-    def _train(self, epoch):
+    def _train(self, epoch: int):
         self.model.train()
         step_bar = tqdm.trange(
             len(self.train_dataloader) // self.accumulation_steps,
@@ -129,6 +129,31 @@ def _train(self, epoch):
                     f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
                 )
 
-    def _eval(epoch: int):
-        # TODO
-        pass
+    def _eval(self, epoch: int):
+        self.model.eval()
+
+        step_bar = tqdm.trange(
+            len(self.eval_dataloader),
+            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+            disable=not is_rank_0(),
+        )
+        for batch in self.eval_dataloader:
+            batch = to_device(batch, self.device)
+            logits = self.model(batch["input_ids"])["logits"]
+            loss = self.loss_fn(batch["labels"], logits)
+            loss_mean = all_reduce_mean(tensor=loss)
+            self.accumulative_meter.add(
+                "loss", loss_mean.to(torch.float16).item(), count_update=batch["input_ids"].size(0)
+            )
+            step_bar.update()
+
+        loss_mean = self.accumulative_meter.get("loss")
+        msg = "Evaluation Result:\n"
+        for tag in ["loss"]:
+            msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+        self.coordinator.print_on_master(msg)
+        if self.save_dir is not None:
+            os.makedirs(self.save_dir, exist_ok=True)
+            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                f.write(msg)
+        step_bar.close()
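The `count_update` argument in the new `_eval` makes the reported loss a sample-weighted mean over batches rather than a plain per-batch mean. A sketch of the bookkeeping this implies, with illustrative numbers (not the meter's actual implementation):

    # Weighted running mean, mirroring how _eval feeds the meter.
    total, weight = 0.0, 0
    for batch_loss, batch_size in [(0.6, 8), (0.4, 2)]:  # invented values
        total += batch_loss * batch_size
        weight += batch_size
    mean_loss = total / weight  # 0.56 here, vs. 0.50 for an unweighted per-batch mean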
From dacc04ef7588026a6a2d7e0a3ae1105f84d4c1c6 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Fri, 15 Nov 2024 03:55:55 +0000
Subject: [PATCH 18/18] fix bug

---
 applications/ColossalChat/coati/dataset/tokenization_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 0548a1454121..8264f6684361 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -419,7 +419,7 @@ def tokenize_process_reward(
     template = deepcopy(conversation_template)
     template.append_message("user", input)
     template.append_message("assistant", label)
-    prompt = template.get_prompt(add_generation_prompt=True)
+    prompt = template.get_prompt(add_generation_prompt=False)
     reward_signal_id = tokenizer.convert_tokens_to_ids(template.reward_signal)
     tokenized = tokenizer(prompt, add_special_tokens=False)["input_ids"]
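The closing fix matters because the assistant turn is already the final message when scoring: with `add_generation_prompt=True`, the chat template appends an extra assistant header after the labeled solution, which pollutes the tokenized sequence. A schematic before/after (the template markers below are illustrative, not the real tiny-llama template):

    # add_generation_prompt=True: a dangling assistant header invites a new turn.
    #   "<|user|> problem </s> <|assistant|> Step 1 ... + Step 2 ... - </s> <|assistant|>"
    # add_generation_prompt=False: exactly the finished, labeled solution is tokenized.
    #   "<|user|> problem </s> <|assistant|> Step 1 ... + Step 2 ... - </s>"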