
Commit

Some dead code cleanup
shubhamugare committed Aug 5, 2024
1 parent b8d1198 commit 98f31ab
Showing 2 changed files with 17 additions and 21 deletions.
27 changes: 13 additions & 14 deletions syncode/infer.py
@@ -23,18 +23,29 @@ class Syncode:
Args:
mode (str, optional): Mode for inference. Defaults to "grammar_mask".
"original" for original model, "grammar_mask" for grammar mask, "grammar_strict" for strict grammar mask.
model (str): Model id for huggingface model hub or model name if stored locally.
quantize (bool, optional): Quantize model. Defaults to True.
device (str, optional): Device to use. Defaults to "cuda".
num_samples (int, optional): Number of samples. Defaults to 1.
grammar (str, optional): Language. Defaults to "input". "input" is used for user input.
other options currently supported are "python", "go", "calc"
dataset (str, optional): Dataset. Defaults to "humaneval".
new_mask_store (bool, optional): Use new DFA mask store. Defaults to False.
num_few_shot (int, optional): Number of examples for few shot prompting. Defaults to 0.
chat_mode (bool, optional): Parse only the (output) and not (prompt+output) in chat mode. Defaults to False.
dev_mode (bool, optional): Development mode. Defaults to False.
log_level (int, optional): Log level. Defaults to 2. 0 for no logs, 1 for minimal logs, 2 for all logs including time.
parser (str, optional): Parser to use. Defaults to "lalr".
task_id (int, optional): For debugging a specific task. Defaults to None.
"""
@@ -109,8 +120,8 @@ def __init__(
mode=mode,
)

# Set LLM generation args e.g. max_new_tokens, do_sample, etc.
self.set_generation_args(kwargs, tokenizer)
# Set LLM max new tokens to 200 by default
kwargs['max_new_tokens'] = kwargs.get('max_new_tokens', 200)

self.model = HuggingFaceModel(
model,
@@ -152,18 +163,6 @@ def get_output_path(self):
out_path = out_dir + 'samples_' + str(self.num_samples) + '_mode_' + str(self.mode) + "_eval.jsonl"
os.makedirs(out_dir, exist_ok=True)
return out_dir,out_path

def set_generation_args(self, kwargs, tokenizer):
kwargs['max_new_tokens'] = kwargs.get('max_new_tokens', 200)
kwargs['do_sample'] = kwargs.get('do_sample', False)
kwargs['use_cache'] = kwargs.get('use_cache', True)
kwargs['eos_token_id'] = kwargs.get('eos_token_id', tokenizer.eos_token_id)
kwargs['pad_token_id'] = kwargs.get('pad_token_id', tokenizer.eos_token_id) # model has no pad token
if kwargs['do_sample'] or self.num_samples > 1: # If sampling, set temperature, top_k, top_p
kwargs['temperature'] = kwargs.get('temperature', 0.2)
kwargs['top_k'] = kwargs.get('top_k', self.num_samples)
kwargs['top_p'] = kwargs.get('top_p', 0.95)
print(f"Generation args: {kwargs}")

def user_input(self, prompt:str, stop_words=[]):
"""
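
A side note on the constructor documented above: below is a minimal usage sketch, assuming Syncode is importable from syncode.infer, that it accepts the documented keyword arguments, and that user_input() returns the generated completion. The model id, prompt, and stop words are illustrative only and are not part of this commit.

from syncode.infer import Syncode

# Hypothetical model id; any Hugging Face causal LM id should work per the docstring.
syn = Syncode(
    model="microsoft/phi-2",
    mode="grammar_mask",      # default mode per the docstring
    grammar="python",         # documented options include "python", "go", "calc", "input"
    max_new_tokens=200,       # matches the new default applied in __init__
)

# user_input() takes a prompt and optional stop words, per the diff above;
# its return value is assumed here to be the generated text.
completion = syn.user_input("def fibonacci(n):", stop_words=["\ndef"])
print(completion)
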
11 changes: 4 additions & 7 deletions syncode/language_model.py
@@ -38,7 +38,6 @@ def __init__(
logger: common.Logger,
tokenizer=None,
prompt_template: str = '',
api_key: str = None,
best_of: int = 1,
before_prediction_hook=lambda: None,
device='cuda',
@@ -67,7 +66,7 @@ def get_grammar_decoder(self):
return None

@torch.inference_mode()
def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=[]) -> Iterable[str]:
def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=None) -> Iterable[str]:
'''
Generates batch_size completions for the given prompt.
'''
@@ -87,7 +86,7 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=[]) -
gen_mode = self._get_generation_mode(gen_config)

# Create stopping criteria
if len(stop_words) > 0:
if stop_words is not None:
stop_criteria = StoppingCriteriaList([KeywordsStoppingCriteria(self.tokenizer, stop_words=stop_words)])
else:
stop_criteria = []
@@ -103,13 +102,11 @@ def generate_batch_completion_grammar(self, prompt, batch_size, stop_words=[]) -
)
else:
# Use generate from transformers library for other modes
if stop_criteria is not None:
print('Warning: Stopping criteria is not supported for batch generation')

generated_ids = self.model.generate(
**inputs,
logits_processor=self.logit_processors,
stoppings=stop_words,
stop_strings=stop_words,
tokenizer=self.tokenizer,
**self.gen_args
)
batch_completions = []
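
A note on the stop-word handling changed above: the default argument moves from a mutable list ([]) to None, the conventional Python idiom for optional list parameters, and the stoppings keyword is replaced with stop_strings, which recent versions of the transformers generate() API accept for string-based stopping (it also requires tokenizer= to be passed, as the diff does). Below is a minimal standalone sketch of that pattern, assuming transformers 4.41 or newer; the model id, prompt, and stop words are illustrative only.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_with_stops(prompt, stop_words=None, model_id="gpt2"):
    # None (not []) as the default avoids sharing one list object across calls.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    inputs = tokenizer(prompt, return_tensors="pt")

    gen_kwargs = dict(max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
    if stop_words is not None:
        # stop_strings needs the tokenizer so generate() can detect the strings in the output.
        gen_kwargs.update(stop_strings=stop_words, tokenizer=tokenizer)

    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(generate_with_stops("def add(a, b):", stop_words=["\ndef"]))
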
