Skip to content

Commit

Permalink
Refactor evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamugare committed Nov 4, 2024
1 parent bfeb0c2 commit cd377ff
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 104 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ Simply install SynCode via PyPi using the following command:
pip install git+https://github.com/uiuc-focal-lab/syncode.git
```

SynCode depends on HuggingFace [transformers](https://github.com/huggingface/transformers):
Note: SynCode depends on HuggingFace [transformers](https://github.com/huggingface/transformers):
| SynCode version | Recommended transformers version |
| -------------- | -------------------------------- |
| `v0.1.3` (latest) | `v4.44.0` |
| `v0.1.4` (latest) | `v4.44.0` |
| `v0.1.2` | `v4.42.0` |


Expand Down
85 changes: 74 additions & 11 deletions syncode/evaluation/code_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tqdm import tqdm
from typing import Optional
from syncode import common
from syncode.evaluation.mxeval_evaluation import check_coorectness
from syncode.evaluation.mxeval_evaluation import check_corectness
from mxeval.data import write_jsonl


Expand All @@ -17,46 +17,74 @@ def run_code_eval(
num_samples_per_task: int,
out_path: Optional[str]=None,
format_tabs: bool = False,
debug_task_id: Optional[int] = None,
logger=common.EmptyLogger()
debug_task_id: Optional[int]=None,
logger=common.EmptyLogger(),
num_tasks: Optional[int]=None
):
problems = syncode.dataset.problems

samples = []
outputs = []

if syncode.language == "python":
stop_words = ["\n\n\n"]
elif syncode.language == "go":
stop_words = ["\n\n\n\n"]
else:
stop_words = None

pbar = tqdm(total=len(problems) * num_samples_per_task)
if debug_task_id is None:
time1 = time.time()
for task_id in problems:
outputs.append(CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id))

# Run evaluation for all tasks
for task_id in list(problems.keys())[:num_tasks]:
outputs.append(CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, stop_words=stop_words))

if out_path is not None: write_jsonl(out_path, samples)

avg_time = (time.time() - time1) / len(problems)
functional_result = check_coorectness(out_path, logger=logger)
functional_result = check_corectness(out_path, logger=logger)
logger.log(f"Functional result: {functional_result}")

# Also log these results in a separate file
CodeEval.write_results(syncode, out_path, avg_time, functional_result)
else: # Debugging a specific task
debug_task_id = list(problems.keys())[debug_task_id]
return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger)
return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger, stop_words=stop_words)
return outputs

def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, logger=common.EmptyLogger()):
def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, logger=common.EmptyLogger(), stop_words=None):
"""
run evaluation for a specific task
"""
logger.log(f"Running eval for task {task_id}")

if format_tabs:
# In this case, we replace 4 spaces with a tab
prompt = problems[task_id]["prompt"].replace(" ", "\t")
problems[task_id]["prompt"] = prompt
else:
prompt = problems[task_id]["prompt"]

batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task)
batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task, stop_words=stop_words, return_token_ids=True)
batch_size = len(batch_completions)

for _, completion in enumerate(batch_completions):
for i, (_, generated_ids, input_ids) in enumerate(batch_completions):
input_ids_cutoff = input_ids.size(1)

# We tokenize the whole thing together since tokenizer just the generated_ids messes up with the
# indentation and removes the initial whitespaces in some cases
raw_completion = syncode.model.tokenizer.decode(generated_ids, skip_special_tokens=True)

# Post-processing to filter out using stop word
if syncode.model.grammar != None and syncode.model.grammar.name == "python":
completion = CodeEval.postproces_completion_python(syncode.model, i, batch_size, input_ids_cutoff, generated_ids, syncode.model.grammar_decoder, raw_completion, stop_words)
elif syncode.model.grammar != None and syncode.model.grammar.name == "go":
completion = CodeEval.postproces_completion_go(syncode.model, i, batch_size, raw_completion, generated_ids, syncode.model.grammar_decoder, input_ids_cutoff)
else: # TODO: handle the case for other grammars
completion = raw_completion

result = dict(
task_id=task_id,
language=problems[task_id]["language"],
Expand All @@ -77,4 +105,39 @@ def write_results(self, out_path, avg_time, functional_result):
f.write(f"Functional result: {functional_result}\n")
f.write(f"Output path: {out_path}\n")
f.write(f"Averge time taken for each task: {avg_time:.2f}s\n")
f.write("\n")
f.write("\n")

def postproces_completion_python(hf_model, i, batch_size, input_ids_cutoff, generated_ids, grammar_decoder, raw_completion, stop_words):
generated_output = hf_model.tokenizer.decode(generated_ids[input_ids_cutoff:])

if all(stop_word not in generated_output for stop_word in stop_words) and hf_model.tokenizer.eos_token_id != generated_ids[-1] and grammar_decoder is not None:
# Use when the stop word does not exist in the completion and grammar_decoder is used
function_incomplete = [False for _ in range(batch_size)]
completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)
else:
completion = raw_completion
return completion

def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_ids, grammar_decoder, input_ids_cutoff):
if hf_model.mode != "original":
# When the grammar_decoder is used
function_incomplete = [False for _ in range(batch_size)]
completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)

if function_incomplete[i]:
completion += "}"

return completion

def compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion):
if grammar_decoder.function_ends[i] is not None and len(grammar_decoder.function_ends[i]) > 1:
# if the function end is not None, then the last valid state is the function end
last_valid_state = grammar_decoder.function_ends[i][1]
else:
# otherwise, the last valid state is the last valid state
function_incomplete[i] = True
last_valid_state = grammar_decoder.last_valid_state[i]

# Use when the stop word does not exist in the completion
backup_completion = raw_completion[:last_valid_state]
return backup_completion
15 changes: 6 additions & 9 deletions syncode/evaluation/mxeval_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def evaluate_functional_correctness(
sample_file: str,
k: List[int] = [1, 10, 100],
n_workers: int = os.cpu_count() - 1,
timeout: float = 10.0,
timeout: float = 100.0,
problem_file: str = '',
logger: common.Logger = common.EmptyLogger(),
):
Expand Down Expand Up @@ -59,7 +59,7 @@ def evaluate_functional_correctness(
completion_id[task_id] += 1
n_samples += 1

assert len(completion_id) == len(problems), "Some problems are not attempted."
# assert len(completion_id) == len(problems), "Some problems are not attempted."

logger.log_eval("Running test suites...")
i = 0
Expand Down Expand Up @@ -148,12 +148,9 @@ def unsafe_execute():
# Disable functionalities that can make destructive changes to the test.
reliability_guard()

prompt = problem["prompt"].replace("\t", " ")

# Construct the check program and run it.
check_program = (
prompt
+ completion
completion
+ "\n"
+ problem["test"]
+ "\n"
Expand Down Expand Up @@ -236,7 +233,7 @@ def check_correctness_helper(
logger: common.Logger = common.EmptyLogger()
):
current_dir = os.path.dirname(os.path.realpath(__file__))
entire_string = problem["prompt"] + completion + problem["test"]
entire_string = completion + problem["test"]

language_dirname = f"{language}_exec_eval"

Expand Down Expand Up @@ -401,7 +398,7 @@ def log_comparison(p, c1, c2, task_id):
logger.log_eval(f"Total: {count_total}")


def check_coorectness(
def check_corectness(
filename: str,
compare: bool = False,
logger: common.Logger = common.EmptyLogger(),
Expand Down Expand Up @@ -435,4 +432,4 @@ def check_coorectness(

if __name__ == "__main__":
logger = common.EmptyLogger()
fire.Fire(check_coorectness)
fire.Fire(check_corectness)
48 changes: 42 additions & 6 deletions syncode/evaluation/sql_eval.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import os
import time
from typing import Optional
from tqdm import tqdm
from mxeval.data import write_jsonl


class SQLEval:
"""
Run evaluation on SQL dataset
Class for running evaluation on SQL spider dataset
"""
@staticmethod
def run_eval(syncode, out_path: Optional[str], debug_task_id: Optional[int] = None):
problems = syncode.dataset.problems[:100]
def run_eval(syncode, out_path: Optional[str], num_tasks: Optional[int]=None, debug_task_id: Optional[int] = None):
"""
Run evaluation on SQL dataset
"""
problems = syncode.dataset.problems

# Run only for num_tasks
if num_tasks is not None:
problems = problems[:num_tasks]

samples = []
pbar = tqdm(total=len(problems) * syncode.num_samples)
results = {}
Expand All @@ -26,15 +35,30 @@ def run_eval(syncode, out_path: Optional[str], debug_task_id: Optional[int] = No
with open(predict_file, 'w') as f:
for task_id, problem in enumerate(problems):
results[task_id] = []
start_time = time.time()
batch_completions = syncode.model.generate_batch_completion_grammar(
problem['prompt'],
syncode.num_samples
)
end_time = time.time()
raw_completion = batch_completions[0]
completion = syncode.dataset.post_process_answer(raw_completion)

extract = False
# If this flag is set try to extract the SQL query from the completion
if extract:
# We consider possibilities where the output is in either ``` {output} ``` or ```sql {output} ``` format
if '```' in completion:
completion = completion.split('```')[1]
if 'sql' in completion:
completion = completion.split('sql')[1]
print(f"Extracted completion: {completion}")

res = dict(
task_id=task_id,
completion=completion,
total_tokens=syncode.model.total_tokens,
total_time=end_time - start_time
)
samples += [res]
f.write(completion + '\n')
Expand All @@ -47,7 +71,7 @@ def run_eval(syncode, out_path: Optional[str], debug_task_id: Optional[int] = No
if out_path is not None and debug_task_id is None: write_jsonl(out_path, samples)

@staticmethod
def compute_accuracy(samples, predict_file):
def compute_accuracy(results_jsonl, predict_file):
from syncode.utils.sql_spider_eval.evaluation import evaluate

# Get current dir path
Expand All @@ -58,5 +82,17 @@ def compute_accuracy(samples, predict_file):
tables = f"{current_dir}/..//utils/sql_spider_eval/evaluation_examples/examples/tables.json"
databses = f"{current_dir}/..//utils/sql_spider_eval/databases"

scores, error_types = evaluate(predict_file, gold_file, databses, etype="all", table=tables, result_jsonl=samples)
print(f"Scores: {scores['all']}\n Error types: {error_types}")
scores, error_types = evaluate(predict_file, gold_file, databses, etype="all", table=tables, result_jsonl=results_jsonl)
print(f"Scores: {[(lvl, scores[lvl]['exec']) for lvl in scores.keys()]}\nError types: {dict(error_types)}\nCounts: {[(lvl, scores[lvl]['count']) for lvl in scores.keys()]}")

print("Execution accuracy:", scores['all']['exec'])
print(f"Compilation error types: {dict(error_types)}")

# Average token count
total_tokens = sum([r['total_tokens'] for r in results_jsonl])
print(f"Average token count: {total_tokens/len(results_jsonl)}")

# Average time
total_time = sum([r['total_time'] for r in results_jsonl])
print(f"Average time: {total_time/len(results_jsonl)}")
return scores, error_types, results_jsonl
32 changes: 17 additions & 15 deletions syncode/grammar_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self,

# For backtracking to syntactically valid completions
self.last_valid_state: list = []
self.function_end: list = []
self.function_ends: list = []

# We use this when only the LLM output is parsed and not (input+output)
self.parse_output_only = parse_output_only
Expand Down Expand Up @@ -89,7 +89,7 @@ def reset(self, prompt: str):
Resets the decoder state on every new prompt.
"""
self.last_valid_state = [0 for _ in range(self.batch_size)]
self.function_end = [None for _ in range(self.batch_size)]
self.function_ends = [None for _ in range(self.batch_size)]

prompt_tokens = self.tokenizer.encode(prompt, return_tensors='pt')[0]
if self.parse_output_only:
Expand Down Expand Up @@ -121,17 +121,19 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
self.logger.log(f"Exception while parsing:\n {e}")
return False

self.update_valid_state(input_ids, 0, r)

if input_ids[0, -1] == self.tokenizer.eos_token_id:
return AcceptSequence(['$END']) in r.accept_sequences
is_valid = AcceptSequence(['$END']) in r.accept_sequences

if r.remainder_state == RemainderState.COMPLETE or r.remainder_state == RemainderState.MAYBE_COMPLETE:
return True
is_valid = True

# Check if the remainder is a valid prefix for the last terminal
out = self.dfa_mask_store.is_valid_prefix(r)
return out
is_valid = self.dfa_mask_store.is_valid_prefix(r)

if is_valid:
self.update_valid_state(partial_code, 0, r)

return is_valid


def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
Expand All @@ -142,13 +144,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
## Parsing
try: # returns the accept sequences that are currently accepted.
r = self.inc_parser.get_acceptable_next_terminals(partial_code)
self.update_valid_state(partial_code, idx, r)
except Exception as e:
if self.dev_mode == True:
raise e
self.logger.log(f"Exception while parsing:\n {e}")
continue # Skip altering the scores for this batch

self.update_valid_state(input_ids, idx, r)

accept_mask = self.dfa_mask_store.get_accept_mask(r, logger=self.logger)

Expand Down Expand Up @@ -176,20 +177,21 @@ def _get_partial_codes(self, input_ids: torch.LongTensor):
partial_codes = self.tokenizer.batch_decode(input_ids[:, self.start_from:], skip_special_tokens=True)
return partial_codes

def update_valid_state(self, input_ids, idx: int, r: ParseResult):
def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
"""
This a simple heuristic to cut off the generated output at the end of the function.
TODO: Put this under a flag to enable/disable this heuristic.
"""
if idx < len(self.function_end):
if r.function_end and self.function_end[idx] == None: # If the function end is not None, then the last valid state is the function end
self.function_end[idx] = len(input_ids[idx])-1
if idx < len(self.function_ends):
if r.function_end: # If the function end is not None, then the last valid state is the function end
if self.function_ends[idx] is None: self.function_ends[idx] = []
self.function_ends[idx].append(len(partial_code) - len(r.remainder))

if idx < len(self.last_valid_state):
for accept_seq in r.accept_sequences:
# 'EOF' is special terminal since $END does not work with python
if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
self.last_valid_state[idx] = len(input_ids[idx])-1
self.last_valid_state[idx] = len(partial_code) - len(r.remainder)

def _debug_greedy(self, scores, idx, partial_code, r, greedy_token):
greedy_grammar_token = self.tokenizer.decode(scores[idx].argmax(dim=-1))
Expand Down
Loading

0 comments on commit cd377ff

Please sign in to comment.