Refactor evaluation #123
Merged (1 commit, merged Nov 4, 2024)
94 changes: 82 additions & 12 deletions syncode/evaluation/code_eval.py
@@ -1,9 +1,10 @@
import os
import time
import torch
from tqdm import tqdm
from typing import Optional
from syncode import common
from syncode.evaluation.mxeval_evaluation import check_coorectness
from syncode.evaluation.mxeval_evaluation import check_corectness
from mxeval.data import write_jsonl


@@ -17,54 +18,87 @@ def run_code_eval(
num_samples_per_task: int,
out_path: Optional[str]=None,
format_tabs: bool = False,
debug_task_id: Optional[int] = None,
logger=common.EmptyLogger()
debug_task_id: Optional[int]=None,
logger=common.EmptyLogger(),
num_tasks: Optional[int]=None
):
problems = syncode.dataset.problems

samples = []
outputs = []

if syncode.language == "python":
stop_words = ["\n\n\n"]
elif syncode.language == "go":
stop_words = ["\n\n\n"]
else:
stop_words = None

pbar = tqdm(total=len(problems) * num_samples_per_task)
if debug_task_id is None:
time1 = time.time()
for task_id in problems:
outputs.append(CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id))

# Run evaluation for all tasks
for task_id in list(problems.keys())[:num_tasks]:
outputs.append(CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, stop_words=stop_words))

if out_path is not None: write_jsonl(out_path, samples)

avg_time = (time.time() - time1) / len(problems)
functional_result = check_coorectness(out_path, logger=logger)
functional_result = check_corectness(out_path, logger=logger)
logger.log(f"Functional result: {functional_result}")

# Also log these results in a separate file
CodeEval.write_results(syncode, out_path, avg_time, functional_result)
else: # Debugging a specific task
debug_task_id = list(problems.keys())[debug_task_id]
return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger)
return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger, stop_words=stop_words)
return outputs

def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, logger=common.EmptyLogger()):
def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, logger=common.EmptyLogger(), stop_words=None):
"""
run evaluation for a specific task
"""
logger.log(f"Running eval for task {task_id}")

if format_tabs:
# In this case, we replace 4 spaces with a tab
prompt = problems[task_id]["prompt"].replace("    ", "\t")
problems[task_id]["prompt"] = prompt
else:
prompt = problems[task_id]["prompt"]

batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task)
batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task, stop_words=stop_words, return_token_ids=True)
batch_size = len(batch_completions)
all_completions = []

for _, completion in enumerate(batch_completions):
for i, (_, generated_ids, input_ids) in enumerate(batch_completions):
input_ids_cutoff = input_ids.size(1)

# We decode the prompt and the generation together, since decoding just the generated ids
# can mess up the indentation and drop the initial whitespace in some cases
raw_completion = syncode.model.tokenizer.decode(generated_ids, skip_special_tokens=True)

# Post-process the completion: truncate at stop words or backtrack to a syntactically valid state
if syncode.model.grammar is not None and syncode.model.grammar.name == "python":
completion = CodeEval.postproces_completion_python(syncode.model, i, batch_size, input_ids_cutoff, generated_ids, syncode.model.grammar_decoder, raw_completion, stop_words)
elif syncode.model.grammar is not None and syncode.model.grammar.name == "go":
completion = CodeEval.postproces_completion_go(syncode.model, i, batch_size, raw_completion, generated_ids, syncode.model.grammar_decoder, input_ids_cutoff)
else: # TODO: handle the case for other grammars
completion = raw_completion

result = dict(
task_id=task_id,
language=problems[task_id]["language"],
completion=completion
)
samples += [result]
all_completions.append(completion)
pbar.update(num_samples_per_task)
return batch_completions

# Clear the cache
torch.cuda.empty_cache()
return all_completions

def write_results(self, out_path, avg_time, functional_result):
"""
@@ -77,4 +111,40 @@ def write_results(self, out_path, avg_time, functional_result):
f.write(f"Functional result: {functional_result}\n")
f.write(f"Output path: {out_path}\n")
f.write(f"Average time taken for each task: {avg_time:.2f}s\n")
f.write("\n")
f.write("\n")

def postproces_completion_python(hf_model, i, batch_size, input_ids_cutoff, generated_ids, grammar_decoder, raw_completion, stop_words):
generated_output = hf_model.tokenizer.decode(generated_ids[input_ids_cutoff:])

if all(stop_word not in generated_output for stop_word in stop_words) and hf_model.tokenizer.eos_token_id != generated_ids[-1] and grammar_decoder is not None:
# Backtrack only when no stop word appears in the output, EOS was not generated, and a grammar decoder is in use
function_incomplete = [False for _ in range(batch_size)]
completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)
else:
completion = raw_completion
return completion

def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_ids, grammar_decoder, input_ids_cutoff):
if hf_model.mode != "original":
# When the grammar_decoder is used
function_incomplete = [False for _ in range(batch_size)]
completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)

if function_incomplete[i]:
completion += "}"

return completion

def compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion):
fn_ends = sorted(set(grammar_decoder.function_ends[i])) if grammar_decoder.function_ends[i] is not None else []
if len(fn_ends) > 1:
# If the parser recorded more than one function end, cut the completion at the second recorded end
last_valid_state = fn_ends[1]
else:
# Otherwise fall back to the parser's last valid state and mark the function as incomplete
function_incomplete[i] = True
last_valid_state = grammar_decoder.last_valid_state[i]

# Truncate the raw completion at the chosen cutoff
backup_completion = raw_completion[:last_valid_state]
return backup_completion
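
Note on the backtracking logic above: compute_backup_completion cuts the raw completion back to a character offset recorded by the grammar decoder. Below is a minimal, self-contained sketch of that idea; the helper name backup_completion and the example values are illustrative only, not part of this PR.

def backup_completion(raw_completion, function_ends, last_valid_state):
    """Truncate a raw completion at a recorded syntactically valid character offset."""
    fn_ends = sorted(set(function_ends or []))
    if len(fn_ends) > 1:
        # More than one function-end offset was recorded: cut at the second one
        return raw_completion[:fn_ends[1]], False
    # No usable function end: fall back to the parser's last valid state
    return raw_completion[:last_valid_state], True

completion, incomplete = backup_completion("def f():\n    return 1\nprint(", None, 22)
print(repr(completion), incomplete)  # 'def f():\n    return 1\n' True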
15 changes: 6 additions & 9 deletions syncode/evaluation/mxeval_evaluation.py
@@ -16,7 +16,7 @@ def evaluate_functional_correctness(
sample_file: str,
k: List[int] = [1, 10, 100],
n_workers: int = os.cpu_count() - 1,
timeout: float = 10.0,
timeout: float = 100.0,
problem_file: str = '',
logger: common.Logger = common.EmptyLogger(),
):
@@ -59,7 +59,7 @@ def evaluate_functional_correctness(
completion_id[task_id] += 1
n_samples += 1

assert len(completion_id) == len(problems), "Some problems are not attempted."
# assert len(completion_id) == len(problems), "Some problems are not attempted."

logger.log_eval("Running test suites...")
i = 0
@@ -148,12 +148,9 @@ def unsafe_execute():
# Disable functionalities that can make destructive changes to the test.
reliability_guard()

prompt = problem["prompt"].replace("\t", "    ")

# Construct the check program and run it.
check_program = (
prompt
+ completion
completion
+ "\n"
+ problem["test"]
+ "\n"
@@ -236,7 +233,7 @@ def check_correctness_helper(
logger: common.Logger = common.EmptyLogger()
):
current_dir = os.path.dirname(os.path.realpath(__file__))
entire_string = problem["prompt"] + completion + problem["test"]
entire_string = completion + problem["test"]

language_dirname = f"{language}_exec_eval"

@@ -401,7 +398,7 @@ def log_comparison(p, c1, c2, task_id):
logger.log_eval(f"Total: {count_total}")


def check_coorectness(
def check_corectness(
filename: str,
compare: bool = False,
logger: common.Logger = common.EmptyLogger(),
@@ -435,4 +432,4 @@ def check_coorectness(

if __name__ == "__main__":
logger = common.EmptyLogger()
fire.Fire(check_coorectness)
fire.Fire(check_corectness)
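
Because completions are now decoded together with the prompt (see code_eval.py above), the prompt is no longer prepended when the check program is assembled here. A minimal sketch of that construction, with placeholder completion and test strings that are not taken from any dataset:

sample_completion = (
    "def add(a, b):\n"
    "    return a + b\n"
)
sample_test = (
    "def check(candidate):\n"
    "    assert candidate(2, 3) == 5\n"
    "check(add)\n"
)

# The completion already contains the function signature, so only the test is appended
check_program = sample_completion + "\n" + sample_test + "\n"
exec(check_program, {})  # raises AssertionError if the completion is wrong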
48 changes: 42 additions & 6 deletions syncode/evaluation/sql_eval.py
@@ -1,16 +1,25 @@
import os
import time
from typing import Optional
from tqdm import tqdm
from mxeval.data import write_jsonl


class SQLEval:
"""
Run evaluation on SQL dataset
Class for running evaluation on the Spider SQL dataset
"""
@staticmethod
def run_eval(syncode, out_path: Optional[str], debug_task_id: Optional[int] = None):
problems = syncode.dataset.problems[:100]
def run_eval(syncode, out_path: Optional[str], num_tasks: Optional[int]=None, debug_task_id: Optional[int] = None):
"""
Run evaluation on SQL dataset
"""
problems = syncode.dataset.problems

# Run only for num_tasks
if num_tasks is not None:
problems = problems[:num_tasks]

samples = []
pbar = tqdm(total=len(problems) * syncode.num_samples)
results = {}
@@ -26,15 +35,30 @@ def run_eval(syncode, out_path: Optional[str], debug_task_id: Optional[int] = No
with open(predict_file, 'w') as f:
for task_id, problem in enumerate(problems):
results[task_id] = []
start_time = time.time()
batch_completions = syncode.model.generate_batch_completion_grammar(
problem['prompt'],
syncode.num_samples
)
end_time = time.time()
raw_completion = batch_completions[0]
completion = syncode.dataset.post_process_answer(raw_completion)

extract = False
# If this flag is set, try to extract the SQL query from the completion
if extract:
# We consider possibilities where the output is in either ``` {output} ``` or ```sql {output} ``` format
if '```' in completion:
completion = completion.split('```')[1]
if 'sql' in completion:
completion = completion.split('sql')[1]
print(f"Extracted completion: {completion}")

res = dict(
task_id=task_id,
completion=completion,
total_tokens=syncode.model.total_tokens,
total_time=end_time - start_time
)
samples += [res]
f.write(completion + '\n')
Expand All @@ -47,7 +71,7 @@ def run_eval(syncode, out_path: Optional[str], debug_task_id: Optional[int] = No
if out_path is not None and debug_task_id is None: write_jsonl(out_path, samples)

@staticmethod
def compute_accuracy(samples, predict_file):
def compute_accuracy(results_jsonl, predict_file):
from syncode.utils.sql_spider_eval.evaluation import evaluate

# Get current dir path
@@ -58,5 +82,17 @@ def compute_accuracy(samples, predict_file):
tables = f"{current_dir}/..//utils/sql_spider_eval/evaluation_examples/examples/tables.json"
databses = f"{current_dir}/..//utils/sql_spider_eval/databases"

scores, error_types = evaluate(predict_file, gold_file, databses, etype="all", table=tables, result_jsonl=samples)
print(f"Scores: {scores['all']}\n Error types: {error_types}")
scores, error_types = evaluate(predict_file, gold_file, databses, etype="all", table=tables, result_jsonl=results_jsonl)
print(f"Scores: {[(lvl, scores[lvl]['exec']) for lvl in scores.keys()]}\nError types: {dict(error_types)}\nCounts: {[(lvl, scores[lvl]['count']) for lvl in scores.keys()]}")

print("Execution accuracy:", scores['all']['exec'])
print(f"Compilation error types: {dict(error_types)}")

# Average token count
total_tokens = sum([r['total_tokens'] for r in results_jsonl])
print(f"Average token count: {total_tokens/len(results_jsonl)}")

# Average time
total_time = sum([r['total_time'] for r in results_jsonl])
print(f"Average time: {total_time/len(results_jsonl)}")
return scores, error_types, results_jsonl
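
The extraction branch in run_eval above is gated behind a hard-coded extract = False. A small sketch of the same idea as a standalone helper; this variant strips a leading "sql" tag instead of splitting on the substring "sql", so queries containing that substring are not mangled. The helper name and example query are illustrative only.

def extract_sql(completion: str) -> str:
    # Handle completions wrapped in ``` ... ``` or ```sql ... ``` fences
    if '```' in completion:
        completion = completion.split('```')[1]
        if completion.startswith('sql'):
            completion = completion[len('sql'):]
    return completion.strip()

print(extract_sql("```sql\nSELECT name FROM singer WHERE age > 30\n```"))
# SELECT name FROM singer WHERE age > 30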
32 changes: 17 additions & 15 deletions syncode/grammar_decoder.py
@@ -41,7 +41,7 @@ def __init__(self,

# For backtracking to syntactically valid completions
self.last_valid_state: list = []
self.function_end: list = []
self.function_ends: list = []

# We use this when only the LLM output is parsed and not (input+output)
self.parse_output_only = parse_output_only
@@ -89,7 +89,7 @@ def reset(self, prompt: str):
Resets the decoder state on every new prompt.
"""
self.last_valid_state = [0 for _ in range(self.batch_size)]
self.function_end = [None for _ in range(self.batch_size)]
self.function_ends = [None for _ in range(self.batch_size)]

prompt_tokens = self.tokenizer.encode(prompt, return_tensors='pt')[0]
if self.parse_output_only:
@@ -121,17 +121,19 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
self.logger.log(f"Exception while parsing:\n {e}")
return False

self.update_valid_state(input_ids, 0, r)

if input_ids[0, -1] == self.tokenizer.eos_token_id:
return AcceptSequence(['$END']) in r.accept_sequences
is_valid = AcceptSequence(['$END']) in r.accept_sequences

if r.remainder_state == RemainderState.COMPLETE or r.remainder_state == RemainderState.MAYBE_COMPLETE:
return True
is_valid = True

# Check if the remainder is a valid prefix for the last terminal
out = self.dfa_mask_store.is_valid_prefix(r)
return out
is_valid = self.dfa_mask_store.is_valid_prefix(r)

if is_valid:
self.update_valid_state(partial_code, 0, r)

return is_valid


def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
Expand All @@ -142,13 +144,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
## Parsing
try: # returns the accept sequences that are currently accepted.
r = self.inc_parser.get_acceptable_next_terminals(partial_code)
self.update_valid_state(partial_code, idx, r)
except Exception as e:
if self.dev_mode == True:
raise e
self.logger.log(f"Exception while parsing:\n {e}")
continue # Skip altering the scores for this batch

self.update_valid_state(input_ids, idx, r)

accept_mask = self.dfa_mask_store.get_accept_mask(r, logger=self.logger)

@@ -176,20 +177,21 @@ def _get_partial_codes(self, input_ids: torch.LongTensor):
partial_codes = self.tokenizer.batch_decode(input_ids[:, self.start_from:], skip_special_tokens=True)
return partial_codes

def update_valid_state(self, input_ids, idx: int, r: ParseResult):
def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
"""
This a simple heuristic to cut off the generated output at the end of the function.
TODO: Put this under a flag to enable/disable this heuristic.
"""
if idx < len(self.function_end):
if r.function_end and self.function_end[idx] == None: # If the function end is not None, then the last valid state is the function end
self.function_end[idx] = len(input_ids[idx])-1
if idx < len(self.function_ends):
if r.function_end: # Record the character offset of every function end the parser detects
if self.function_ends[idx] is None: self.function_ends[idx] = []
self.function_ends[idx].append(len(partial_code) - len(r.remainder))

if idx < len(self.last_valid_state):
for accept_seq in r.accept_sequences:
# 'EOF' is a special terminal, since $END does not work with Python
if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
self.last_valid_state[idx] = len(input_ids[idx])-1
self.last_valid_state[idx] = len(partial_code) - len(r.remainder)

def _debug_greedy(self, scores, idx, partial_code, r, greedy_token):
greedy_grammar_token = self.tokenizer.decode(scores[idx].argmax(dim=-1))
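
The central bookkeeping change in grammar_decoder.py is that last_valid_state and function_ends now store character offsets into the decoded partial code (len(partial_code) - len(r.remainder)) instead of token indices, so the evaluation code can slice completions directly as strings. A minimal illustration with made-up values:

partial_code = "def add(a, b):\n    return a + b\nde"
remainder = "de"  # unparsed trailing fragment reported by the incremental parser

# Offset of the last syntactically valid character, as recorded by update_valid_state
last_valid_state = len(partial_code) - len(remainder)
print(partial_code[:last_valid_state])
# def add(a, b):
#     return a + b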