diff --git a/syncode/infer.py b/syncode/infer.py
index c3ff2d0c..a3ae3583 100644
--- a/syncode/infer.py
+++ b/syncode/infer.py
@@ -23,18 +23,29 @@ class Syncode:
     Args:
         mode (str, optional): Mode for inference. Defaults to "grammar_mask". "original" for original model, "grammar_mask" for grammar mask, "grammar_strict" for strict grammar mask.
+        model (str): Model id for huggingface model hub or model name if stored locally.
+        quantize (bool, optional): Quantize model. Defaults to True.
+        device (str, optional): Device to use. Defaults to "cuda".
+        num_samples (int, optional): Number of samples. Defaults to 1.
+        grammar (str, optional): Language. Defaults to "input". "input" is used for user input. Other options currently supported are "python", "go", "calc".
+        dataset (str, optional): Dataset. Defaults to "humaneval".
+        new_mask_store (bool, optional): Use new DFA mask store. Defaults to False.
+        num_few_shot (int, optional): Number of examples for few-shot prompting. Defaults to 0.
+        chat_mode (bool, optional): Parse only the output (and not prompt+output) in chat mode. Defaults to False.
+        dev_mode (bool, optional): Development mode. Defaults to False.
         log_level (int, optional): Log level. Defaults to 2. 0 for no logs, 1 for minimal logs, 2 for all logs including time.
+        parser (str, optional): Parser to use. Defaults to "lalr".
         task_id (int, optional): For debugging a specific task. Defaults to None.
     """
@@ -109,8 +120,8 @@ def __init__(
             mode=mode,
         )

-        # Set LLM generation args e.g. max_new_tokens, do_sample, etc.
-        self.set_generation_args(kwargs, tokenizer)
+        # Set LLM max new tokens to 200 by default
+        kwargs['max_new_tokens'] = kwargs.get('max_new_tokens', 200)

         self.model = HuggingFaceModel(
             model,
@@ -152,18 +163,6 @@ def get_output_path(self):
         out_path = out_dir + 'samples_' + str(self.num_samples) + '_mode_' + str(self.mode) + "_eval.jsonl"
         os.makedirs(out_dir, exist_ok=True)
         return out_dir,out_path
-
-    def set_generation_args(self, kwargs, tokenizer):
-        kwargs['max_new_tokens'] = kwargs.get('max_new_tokens', 200)
-        kwargs['do_sample'] = kwargs.get('do_sample', False)
-        kwargs['use_cache'] = kwargs.get('use_cache', True)
-        kwargs['eos_token_id'] = kwargs.get('eos_token_id', tokenizer.eos_token_id)
-        kwargs['pad_token_id'] = kwargs.get('pad_token_id', tokenizer.eos_token_id) # model has no pad token
-        if kwargs['do_sample'] or self.num_samples > 1: # If sampling, set temperature, top_k, top_p
-            kwargs['temperature'] = kwargs.get('temperature', 0.2)
-            kwargs['top_k'] = kwargs.get('top_k', self.num_samples)
-            kwargs['top_p'] = kwargs.get('top_p', 0.95)
-        print(f"Generation args: {kwargs}")

     def user_input(self, prompt:str, stop_words=[]):
         """
diff --git a/syncode/language_model.py b/syncode/language_model.py
index 96551496..935c44c9 100644
--- a/syncode/language_model.py
+++ b/syncode/language_model.py
@@ -38,7 +38,6 @@ def __init__(
         logger: common.Logger,
         tokenizer=None,
         prompt_template: str = '',
-        api_key: str = None,
         best_of: int = 1,
         before_prediction_hook=lambda: None,
         device='cuda',
@@ -67,7 +66,7 @@ def get_grammar_decoder(self):
         return None

     @torch.inference_mode()
-    def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=[]) -> Iterable[str]:
+    def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None) -> Iterable[str]:
        '''
        Generates batch_size completions for the given prompt.
       '''
@@ -87,7 +86,7 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=[]) -> Iterable[str]:
        gen_mode = self._get_generation_mode(gen_config)

        # Create stopping criteria
-        if len(stop_words) > 0:
+        if stop_words is not None:
            stop_criteria = StoppingCriteriaList([KeywordsStoppingCriteria(self.tokenizer, stop_words=stop_words)])
        else:
            stop_criteria = []
@@ -103,13 +102,11 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=[]) -> Iterable[str]:
                )
        else:
            # Use generate from transformers library for other modes
-            if stop_criteria is not None:
-                print('Warning: Stopping criteria is not supported for batch generation')
-
            generated_ids = self.model.generate(
                **inputs,
                logits_processor=self.logit_processors,
-                stoppings=stop_words,
+                stop_strings=stop_words,
+                tokenizer=self.tokenizer,
                **self.gen_args
            )
        batch_completions = []
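
Note on the last hunk: the patch replaces the custom `stoppings=` keyword with the library's built-in string-based stopping in Hugging Face transformers, which takes `stop_strings=` together with `tokenizer=` (available in recent transformers releases). Below is a minimal sketch, outside this patch, of that API; the model id and prompt are illustrative assumptions, not taken from the PR.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative model id and prompt, chosen only for this sketch.
    model_id = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=50,
        # generate() decodes the running output and halts when a stop string
        # appears; it needs the tokenizer for this, which is why the diff also
        # passes tokenizer=self.tokenizer alongside stop_strings.
        stop_strings=["\n\n"],
        tokenizer=tokenizer,
    )
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

With the new `stop_words=None` default, callers that pass no stop words end up sending `stop_strings=None`, so no string-based stopping criterion is installed for that generation call.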