[WIP] Add PRM and refactor MCTS #6119

Draft
wants to merge 21 commits into base: main
3 changes: 2 additions & 1 deletion applications/ColossalChat/coati/dataset/__init__.py
@@ -7,7 +7,7 @@
StatefulDistributedSampler,
load_tokenized_dataset,
)
from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
from .tokenization_utils import tokenize_kto, tokenize_process_reward, tokenize_prompt, tokenize_rlhf, tokenize_sft

__all__ = [
"tokenize_prompt",
@@ -23,4 +23,5 @@
"tokenize_kto",
"setup_conversation_template",
"Conversation",
"tokenize_process_reward",
]
59 changes: 34 additions & 25 deletions applications/ColossalChat/coati/dataset/conversation.py
@@ -1,43 +1,53 @@
import dataclasses
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch.distributed as dist
from transformers import AutoTokenizer, PreTrainedTokenizer

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


@dataclasses.dataclass
@dataclass
class Conversation:
tokenizer: PreTrainedTokenizer
system_message: str
chat_template: str
stop_ids: List[int]
end_of_assistant: str
roles = ["user", "assistant"]
messages: List[Dict[str, str]] = field(default_factory=list)
roles: List[str] = field(default_factory=lambda: ["user", "assistant"])
step_score_signal: str = None
reward_signal: List[str] = None

@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
"""
Setup the conversation template from config
"""
tokenizer.chat_template = config["chat_template"]
conv = cls(
tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"]
)
conv.clear()
return conv
conversation = cls(tokenizer, **config)

special_tokens = []
if conversation.step_score_signal is not None:
special_tokens.append(conversation.step_score_signal)

if conversation.reward_signal is not None:
special_tokens.extend(conversation.reward_signal)

if special_tokens:
conversation.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

return conversation

def clear(self):
self.messages = []

@classmethod
def get_conversation_template_keys(cls):
return ["system_message", "chat_template"]
return ["system_message", "chat_template", "end_of_assistant"]

def __str__(self):
return json.dumps(
@@ -46,35 +56,32 @@ def __str__(self):
indent=4,
)

def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any:
def get_prompt(self, num_messages: int = None, add_generation_prompt=False) -> Any:
"""
Retrieves the prompt for the conversation.

Args:
length (int, optional): The number of messages to include in the prompt. Defaults to None.
num_messages (int, optional): The number of messages to include in the prompt. Defaults to None.
add_generation_prompt (bool, optional): Whether to add the assistant line start token in generation (for generation only). Defaults to False.

Returns:
str: The prompt string.
"""

if length is None:
length = len(self.messages)
if num_messages is None:
num_messages = len(self.messages)

assert length <= len(self.messages)
assert num_messages <= len(self.messages)
if self.system_message is not None:
messages = [{"role": "system", "content": self.system_message}] + self.messages[:length]
messages = [{"role": "system", "content": self.system_message}] + self.messages[:num_messages]
else:
messages = self.messages[:length]
messages = self.messages[:num_messages]
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=add_generation_prompt
)
return prompt

def save_prompt(self):
return self.get_prompt()

def append_message(self, role: str, message: str):
"""
Append a message to the conversation.
@@ -141,9 +148,11 @@ def setup_conversation_template(
pass
except ValueError as e:
raise ValueError(e)
if not dist.is_initialized() or dist.get_rank() == 0:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf8") as f:
logger.info(f"Successfully generated a conversation template config, saved to {save_path}.")
json.dump(chat_template_config, f, indent=4, ensure_ascii=False)

os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf8") as f:
logger.info(f"Successfully generated a conversation template config, saved to {save_path}.")
json.dump(chat_template_config, f, indent=4, ensure_ascii=False)
f.write("\n")

return Conversation.from_config(tokenizer, chat_template_config)
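For orientation, a minimal sketch of a template config that exercises the new optional fields. The signal tokens ("ки", "+", "-"), the chat template, and the model name are illustrative assumptions patterned on the Math-Shepherd labeling scheme, not values taken from this PR:

from coati.dataset.conversation import Conversation
from transformers import AutoTokenizer

# Hypothetical template config; step_score_signal / reward_signal values
# are assumptions modeled on Math-Shepherd.
chat_template_config = {
    "system_message": "You are a helpful assistant.",
    "chat_template": "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}",
    "stop_ids": [2],
    "end_of_assistant": "</s>",
    "step_score_signal": "ки",    # marks the end of each reasoning step
    "reward_signal": ["+", "-"],  # per-step correct / incorrect labels
}

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
conv = Conversation.from_config(tokenizer, chat_template_config)
# from_config registers the signal tokens as additional special tokens,
# so each one tokenizes to a single dedicated id.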
46 changes: 45 additions & 1 deletion applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -1,13 +1,14 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tokenization utils for constructing dataset for ppo, dpo, sft, rm
Tokenization Utils for Constructing Dataset for RL.
"""

import warnings
from copy import deepcopy
from typing import Any, Dict, List, Union

import torch
from coati.dataset.conversation import Conversation
from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
from datasets import dataset_dict
@@ -393,3 +394,46 @@ def tokenize_kto(
"input_id_decode": decoded_full_prompt,
"completion_decode": decoded_completion,
}


def tokenize_process_reward(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
Tokenize function designed for the Math-Shepherd dataset.

The datapoint has the following format:
{
"input": problem + step-by-step solution,
"label": problem + step-by-step solution with automatic label,
"task": GSM8K or MATH
}

"""
input = data_point["input"]
label = data_point["label"]

template = deepcopy(conversation_template)
template.append_message("user", input)
template.append_message("assistant", label)
prompt = template.get_prompt(add_generation_prompt=False)
reward_signal_id = tokenizer.convert_tokens_to_ids(template.reward_signal)
tokenized = tokenizer(prompt, add_special_tokens=False)["input_ids"]

tokenized_tensor = torch.tensor(tokenized)
loss_mask = torch.isin(tokenized_tensor, torch.tensor(reward_signal_id))

label = (tokenized_tensor * loss_mask).tolist()
decoded_input = tokenizer.decode(tokenized, skip_special_tokens=False)
decoded_label = tokenizer.decode(label, skip_special_tokens=False)

return {
"input_ids": tokenized,
"labels": label,
"loss_mask": loss_mask,
"decoded_input": decoded_input,
"decoded_label": decoded_label,
}
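A minimal usage sketch, assuming the tokenizer and Conversation built via from_config as sketched earlier. The datapoint is invented in the Math-Shepherd style: "ки" tags each step in the input, and the label replaces it with the "+" (correct) or "-" (incorrect) signal token:

# Invented Math-Shepherd-style datapoint, not taken from the real dataset.
data_point = {
    "input": "There are 2 baskets with 3 apples each. How many apples in total? "
    "Step 1: 2 * 3 = 6. ки Step 2: The answer is 6. ки",
    "label": "There are 2 baskets with 3 apples each. How many apples in total? "
    "Step 1: 2 * 3 = 6. + Step 2: The answer is 6. +",
    "task": "GSM8K",
}

sample = tokenize_process_reward(data_point, tokenizer, conversation_template=conv)
# sample["loss_mask"] is True only at the reward-token positions, and
# sample["labels"] keeps the reward token ids there (zeros elsewhere).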
6 changes: 4 additions & 2 deletions applications/ColossalChat/coati/models/__init__.py
@@ -2,7 +2,7 @@
from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, PRMLoss, ValueLoss
from .reward_model import RewardModel
from .utils import disable_dropout

@@ -18,9 +18,11 @@
"lora_manager",
"convert_to_lora_module",
"DpoLoss",
"KTOLoss" "generate",
"KTOLoss",
"generate",
"generate_streaming",
"disable_dropout",
"update_model_kwargs_fn",
"prepare_inputs_fn",
"PRMLoss",
]
21 changes: 21 additions & 0 deletions applications/ColossalChat/coati/models/loss.py
@@ -280,3 +280,24 @@ def forward(
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()

return losses, chosen_rewards, rejected_rewards, kl


class PRMLoss(nn.Module):
"""
Process Reward Model (PRM) loss: cross-entropy evaluated only at the
positions of the reward-signal tokens, with those tokens as the classes.
"""

def __init__(self, reward_signal_id: Optional[list[int]] = None):
super().__init__()
self.IGNORE_INDEX = -100
self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
self.reward_signal_id = reward_signal_id

def forward(self, labels: torch.Tensor, logits: torch.Tensor):
# Select only the positions where a reward-signal token appears.
loss_mask = torch.isin(labels, torch.tensor(self.reward_signal_id).to(labels.device))

logits = logits[loss_mask]
labels = labels[loss_mask]
# Restrict the prediction space to the reward-signal tokens.
logits = logits[..., self.reward_signal_id]

# Remap reward token ids to class indices in [0, len(reward_signal_id)).
label_mapping = {token: i for i, token in enumerate(self.reward_signal_id)}
labels = torch.tensor([label_mapping.get(label.item(), label.item()) for label in labels], device=labels.device)
loss = self.loss(logits, labels)

return loss
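A self-contained smoke test for the new loss. The reward-token ids here are hypothetical; in practice they come from the tokenizer after the signal tokens are registered:

import torch

reward_signal_id = [32000, 32001]  # hypothetical ids for "+" / "-"
loss_fn = PRMLoss(reward_signal_id=reward_signal_id)

batch, seq_len, vocab = 2, 8, 32002
logits = torch.randn(batch, seq_len, vocab)
labels = torch.zeros(batch, seq_len, dtype=torch.long)  # 0 = unscored position
labels[0, 3] = 32000  # first sequence: step labeled "+"
labels[1, 5] = 32001  # second sequence: step labeled "-"

loss = loss_fn(labels, logits)  # scalar loss over the two scored positions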
12 changes: 5 additions & 7 deletions applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -1,10 +1,9 @@
"""
Implementation of MCTS + Self-refine algorithm.

Reference:
Structure is adapted from https://github.com/BrendanGraham14/mcts-llm/ with the following references:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""
@@ -121,16 +120,15 @@ def simulate(self):

return self.get_best_answer()

def get_best_answer(self):
def _iter_nodes(self):
"""Yield every node of the search tree in breadth-first order."""
to_visit = deque([self.root])
best_node = self.root

while to_visit:
current_node = to_visit.popleft()
if current_node.Q > best_node.Q:
best_node = current_node
yield current_node
to_visit.extend(current_node.children)

def get_best_answer(self):
best_node = max(self._iter_nodes(), key=lambda node: node.Q, default=self.root)
return best_node.answer

def self_refine(self, node: MCTSNode):
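The refactor swaps manual best-node bookkeeping for a breadth-first generator consumed by max(..., key=..., default=...). A stand-alone sketch of the pattern, using a stand-in Node type since MCTSNode's definition sits outside this diff:

from collections import deque
from dataclasses import dataclass, field
from typing import List

@dataclass
class Node:  # stand-in for MCTSNode; only Q, answer, children matter here
    Q: float
    answer: str
    children: List["Node"] = field(default_factory=list)

def iter_nodes(root: Node):
    """Yield every node in breadth-first order."""
    to_visit = deque([root])
    while to_visit:
        node = to_visit.popleft()
        yield node
        to_visit.extend(node.children)

root = Node(Q=0.1, answer="draft", children=[Node(Q=0.9, answer="refined")])
best = max(iter_nodes(root), key=lambda n: n.Q, default=root)
assert best.answer == "refined"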
2 changes: 2 additions & 0 deletions applications/ColossalChat/coati/trainer/__init__.py
@@ -3,6 +3,7 @@
from .kto import KTOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer
from .prm import ProcessRewardModelTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer

@@ -15,4 +16,5 @@
"DPOTrainer",
"ORPOTrainer",
"KTOTrainer",
"ProcessRewardModelTrainer",
]