Skip to content

Commit

Permalink
add a new parameter fix_data
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Nov 2, 2024
1 parent ca71df7 commit 6e28934
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
7 changes: 5 additions & 2 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ class AngleDataTokenizer:
:param dataset_format: Optional[str]. Specify dataset_format from DatasetFormats. Default None.
It will automatically detect the dataset format.
:param end_with_eos: bool. Specify whether ends with the eos token. Default False.
:param fix_data: bool. Specify whether fix the data. Only works when prompt_template is not None. Default True.
Example::
Expand All @@ -415,14 +416,16 @@ def __init__(self,
template_placeholders: Optional[List[str]] = None,
extra_columns: Optional[List[str]] = None,
dataset_format: Optional[str] = None,
end_with_eos: bool = False):
end_with_eos: bool = False,
fix_data: bool = True):
self.tokenizer = tokenizer
self.max_length = max_length
self.prompt_template = prompt_template
self.prompt_template_tok = None
self.extra_columns = extra_columns
self.dataset_format = dataset_format
self.end_with_eos = end_with_eos
self.fix_data = fix_data
if template_placeholders is None:
template_placeholders = ['condition', 'text']
if prompt_template is not None:
Expand Down Expand Up @@ -492,7 +495,7 @@ def __call__(self, data: Dict) -> Dict:
for text_column in text_columns:
toks.append(self.tokenizer(data[text_column], max_length=self.max_length, truncation=True))

if self.prompt_template_tok is not None:
if self.prompt_template_tok is not None and self.fix_data:
for tok in toks:
if tok['input_ids'][-1] != self.prompt_template_tok['input_ids'][-1]:
logger.info(f"data data: token ids={tok['input_ids']}, prompt_token_ids={self.prompt_template_tok['input_ids']}") # NOQA
Expand Down
18 changes: 14 additions & 4 deletions angle_emb/angle_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
'This prompt will be applied for all text columns.'
'If you want to specify different prompts for different text columns,'
'please handle it in the preprocessing step.')
parser.add_argument('--fix_data', type=int, default=1, choices=[0, 1],
help='Whether fix data (only works when prompt_template is not None), choices [0, 1], defaut 1')
parser.add_argument('--filter_duplicate', type=int, default=1, choices=[0, 1],
help='Specify filter_duplicate, choices [0, 1], defaut 1')
parser.add_argument('--save_dir', type=str, default=None,
Expand Down Expand Up @@ -221,11 +223,15 @@ def main():
logger.info('Processing train...')
if args.streaming:
train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template,
fix_data=args.fix_data),
num_proc=args.workers)
else:
train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template,
fix_data=args.fix_data),
num_proc=args.workers)

valid_ds = None
Expand All @@ -239,7 +245,9 @@ def main():
else:
valid_ds = load_dataset(args.valid_name_or_path, num_proc=args.workers)
valid_ds = valid_ds[args.valid_split_name or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template,
fix_data=args.fix_data),
num_proc=args.workers)

valid_ds_for_callback = None
Expand All @@ -258,7 +266,9 @@ def main():
valid_ds_for_callback = load_dataset(
args.valid_name_or_path_for_callback, num_proc=args.workers)
valid_ds_for_callback = valid_ds_for_callback[args.valid_split_name_for_callback or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template,
fix_data=args.fix_data),
num_proc=args.workers)

argument_kwargs = {}
Expand Down

0 comments on commit 6e28934

Please sign in to comment.