Skip to content

Commit

Permalink
Decouple infer and evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamugare committed Aug 8, 2024
1 parent eab96b7 commit 0f423a1
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 110 deletions.
86 changes: 86 additions & 0 deletions notebooks/eval_json.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 11.83it/s]\n"
]
}
],
"source": [
"# Running Gemma-2-2b-it on JSON evaluation from nousResearch/json-mode-eval\n",
"import sys, os\n",
"sys.path.append(os.getcwd() + '/../')\n",
"from syncode import Syncode\n",
"\n",
"syn_llm = Syncode(\n",
" mode=\"grammar_mask\",\n",
" model=\"google/gemma-2-2b-it\",\n",
" grammar=\"json\",\n",
" max_new_tokens=400,\n",
" parser=\"lr\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [07:50<00:00, 4.71s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result: {'pass@1': 0.99}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"result = syn_llm.evaluate(dataset=\"json_eval\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "codex",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
33 changes: 24 additions & 9 deletions notebooks/example_date.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,33 @@
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 1,
"id": "ffd1e5da",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.61it/s]\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 3.69it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating DFA mask store for CodeGenTokenizerFast and custom, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/CodeGenTokenizerFast/grammar_strict_1140931734_50257.pkl.\n",
"Ignore whitespace tokens is False\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 4.00it/s]\n",
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 4.08it/s]\n",
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
"100%|██████████| 96/96 [00:24<00:00, 3.89it/s]\n"
]
}
],
Expand All @@ -45,10 +53,17 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 2,
"id": "f3f9f3b8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand All @@ -73,7 +88,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 3,
"id": "4bf95c2b",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -101,7 +116,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 4,
"id": "3ad2158b",
"metadata": {},
"outputs": [
Expand Down
18 changes: 11 additions & 7 deletions syncode/common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import os
from transformers import (
LlamaTokenizer,
LlamaForCausalLM,
AutoTokenizer,
AutoModelForCausalLM,
)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Remove this in future and add instruction to set the HF_CACHE env variable
RESULTS_DIR = os.environ['RESULTS_DIR'] if 'RESULTS_DIR' in os.environ else 'results/'
Expand Down Expand Up @@ -52,6 +47,12 @@ def load_tokenizer(model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)
return tokenizer

def get_output_path(model_name, grammar, dataset, num_samples, mode):
    """Return the output directory and sample-file path for an evaluation run.

    The directory ``results/<model_name>/<grammar>/<dataset>/`` is created
    (relative to the current working directory) if it does not already exist.

    Args:
        model_name: Model identifier; any ``/`` in it becomes a nested directory.
        grammar: Grammar name used as the second path component.
        dataset: Dataset name used as the third path component.
        num_samples: Number of samples, embedded in the file name.
        mode: Generation mode, embedded in the file name.

    Returns:
        A ``(out_dir, out_path)`` tuple, where ``out_dir`` ends with ``/`` and
        ``out_path`` is the ``.jsonl`` file inside it.
    """
    # NOTE(review): the root is hard-coded to "results/" rather than using the
    # module-level RESULTS_DIR setting -- confirm whether that is intentional.
    out_dir = f"results/{model_name}/{grammar}/{dataset}/"
    os.makedirs(out_dir, exist_ok=True)
    # ``!s`` keeps the explicit str() conversion of the original concatenation.
    out_path = f"{out_dir}samples_{num_samples!s}_mode_{mode!s}_eval.jsonl"
    return out_dir, out_path

class Logger:
"""
Logger class for logging the output of the model
Expand Down Expand Up @@ -135,4 +136,7 @@ def log_error(self, msg):
pass
def close(self):
    """Do nothing; this logger variant holds no resource to release."""
    return None

def is_closed(self):
    """Report whether the logger is closed; this variant always returns False."""
    return False
def open(self):
    """Do nothing; this logger variant has nothing to open."""
    return None
20 changes: 12 additions & 8 deletions syncode/evaluation/code_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from tqdm import tqdm
from typing import Optional
from syncode import common
from syncode.evaluation.mxeval_evaluation import check_coorectness
from mxeval.data import write_jsonl

Expand All @@ -14,9 +15,10 @@ class CodeEval:
def run_code_eval(
syncode,
num_samples_per_task: int,
out_path: str,
out_path: Optional[str]=None,
format_tabs: bool = False,
debug_task_id: Optional[int] = None
debug_task_id: Optional[int] = None,
logger=common.EmptyLogger()
):
problems = syncode.dataset.problems

Expand All @@ -27,23 +29,25 @@ def run_code_eval(
time1 = time.time()
for task_id in problems:
outputs.append(CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id))
write_jsonl(out_path, samples)

if out_path is not None: write_jsonl(out_path, samples)

avg_time = (time.time() - time1) / len(problems)
functional_result = check_coorectness(out_path, logger=syncode.logger)
syncode.logger.log(f"Functional result: {functional_result}")
functional_result = check_coorectness(out_path, logger=logger)
logger.log(f"Functional result: {functional_result}")

# Also log these results in a separate file
CodeEval.write_results(syncode, out_path, avg_time, functional_result)
else: # Debugging a specific task
debug_task_id = list(problems.keys())[debug_task_id]
return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id)
return CodeEval.run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, debug_task_id, logger=logger)
return outputs

def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id):
def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samples, pbar, task_id, logger=common.EmptyLogger()):
"""
run evaluation for a specific task
"""
syncode.logger.log(f"Running eval for task {task_id}")
logger.log(f"Running eval for task {task_id}")

if format_tabs:
prompt = problems[task_id]["prompt"].replace(" ", "\t")
Expand Down
8 changes: 5 additions & 3 deletions syncode/evaluation/fol_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import random
import re
from typing import Optional
from mxeval.data import write_jsonl
from tqdm import tqdm
import signal
Expand Down Expand Up @@ -77,7 +78,7 @@

class FOLEval:
@staticmethod
def run_eval(syncode, debug_task_id=None):
def run_eval(syncode, out_path: Optional[str]=None, debug_task_id=None):
problems = syncode.dataset.problems[:100]
if debug_task_id is not None:
problems = [problems[debug_task_id]]
Expand Down Expand Up @@ -152,8 +153,9 @@ def run_eval(syncode, debug_task_id=None):
count_syn_error += (not is_parsed)
samples += [res]
pbar.update(syncode.num_samples)

write_jsonl(syncode.out_path, samples)

if out_path is not None: write_jsonl(out_path, samples)

print(f"Pass rate: {count_pass}/{len(problems)}")
print(f"Compilation error rate: {count_compile_error}/{len(problems)}")
print(f"Execution error rate: {count_exec_error}/{len(problems)}")
Expand Down
Loading

0 comments on commit 0f423a1

Please sign in to comment.