From 6e28934e7277337ba17e100e00b2b8955a03fff7 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Sat, 2 Nov 2024 10:27:12 +0800 Subject: [PATCH] add a new parameter fix_data --- angle_emb/angle.py | 7 +++++-- angle_emb/angle_trainer.py | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 09d1c31..ea693e5 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -393,6 +393,7 @@ class AngleDataTokenizer: :param dataset_format: Optional[str]. Specify dataset_format from DatasetFormats. Default None. It will automatically detect the dataset format. :param end_with_eos: bool. Specify whether ends with the eos token. Default False. + :param fix_data: bool. Specify whether fix the data. Only works when prompt_template is not None. Default True. Example:: @@ -415,7 +416,8 @@ def __init__(self, template_placeholders: Optional[List[str]] = None, extra_columns: Optional[List[str]] = None, dataset_format: Optional[str] = None, - end_with_eos: bool = False): + end_with_eos: bool = False, + fix_data: bool = True): self.tokenizer = tokenizer self.max_length = max_length self.prompt_template = prompt_template @@ -423,6 +425,7 @@ def __init__(self, self.extra_columns = extra_columns self.dataset_format = dataset_format self.end_with_eos = end_with_eos + self.fix_data = fix_data if template_placeholders is None: template_placeholders = ['condition', 'text'] if prompt_template is not None: @@ -492,7 +495,7 @@ def __call__(self, data: Dict) -> Dict: for text_column in text_columns: toks.append(self.tokenizer(data[text_column], max_length=self.max_length, truncation=True)) - if self.prompt_template_tok is not None: + if self.prompt_template_tok is not None and self.fix_data: for tok in toks: if tok['input_ids'][-1] != self.prompt_template_tok['input_ids'][-1]: logger.info(f"data data: token ids={tok['input_ids']}, prompt_token_ids={self.prompt_template_tok['input_ids']}") # NOQA diff --git a/angle_emb/angle_trainer.py b/angle_emb/angle_trainer.py index bfeb7ca..9c54f47 100644 --- a/angle_emb/angle_trainer.py +++ b/angle_emb/angle_trainer.py @@ -47,6 +47,8 @@ 'This prompt will be applied for all text columns.' 'If you want to specify different prompts for different text columns,' 'please handle it in the preprocessing step.') +parser.add_argument('--fix_data', type=int, default=1, choices=[0, 1], + help='Whether fix data (only works when prompt_template is not None), choices [0, 1], defaut 1') parser.add_argument('--filter_duplicate', type=int, default=1, choices=[0, 1], help='Specify filter_duplicate, choices [0, 1], defaut 1') parser.add_argument('--save_dir', type=str, default=None, @@ -221,11 +223,15 @@ def main(): logger.info('Processing train...') if args.streaming: train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map( - AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), + AngleDataTokenizer(model.tokenizer, model.max_length, + prompt_template=args.prompt_template, + fix_data=args.fix_data), num_proc=args.workers) else: train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map( - AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), + AngleDataTokenizer(model.tokenizer, model.max_length, + prompt_template=args.prompt_template, + fix_data=args.fix_data), num_proc=args.workers) valid_ds = None @@ -239,7 +245,9 @@ def main(): else: valid_ds = load_dataset(args.valid_name_or_path, num_proc=args.workers) valid_ds = valid_ds[args.valid_split_name or 'train'].map( - AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), + AngleDataTokenizer(model.tokenizer, model.max_length, + prompt_template=args.prompt_template, + fix_data=args.fix_data), num_proc=args.workers) valid_ds_for_callback = None @@ -258,7 +266,9 @@ def main(): valid_ds_for_callback = load_dataset( args.valid_name_or_path_for_callback, num_proc=args.workers) valid_ds_for_callback = valid_ds_for_callback[args.valid_split_name_for_callback or 'train'].map( - AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), + AngleDataTokenizer(model.tokenizer, model.max_length, + prompt_template=args.prompt_template, + fix_data=args.fix_data), num_proc=args.workers) argument_kwargs = {}