Merge pull request #103 from uiuc-focal-lab/json_eval3
Update JSON eval notebook with original and explicit prompts
shubhamugare authored Aug 11, 2024
2 parents b9d70a8 + 694a452 commit 1515d6e
Showing 3 changed files with 91 additions and 8 deletions.
77 changes: 76 additions & 1 deletion notebooks/eval_json.ipynb
@@ -11,16 +11,27 @@
"text": [
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 11.83it/s]\n"
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 11.82it/s]\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 10.95it/s]\n"
]
}
],
"source": [
"# We compare the original model with SynCode in generating JSON with valid schema\n",
"# Running Gemma-2-2b-it on JSON evaluation from nousResearch/json-mode-eval\n",
"import sys, os\n",
"sys.path.append(os.getcwd() + '/../')\n",
"from syncode import Syncode\n",
"\n",
"# Load original model\n",
"llm = Syncode(\n",
" mode=\"original\",\n",
" model=\"google/gemma-2-2b-it\",\n",
" grammar=\"json\",\n",
" max_new_tokens=400,\n",
")\n",
"\n",
"# Load model with grammar mask\n",
"syn_llm = Syncode(\n",
" mode=\"grammar_mask\",\n",
" model=\"google/gemma-2-2b-it\",\n",
@@ -30,6 +41,70 @@
")"
]
},
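For orientation, here is a minimal, self-contained sketch of how the two instances above might be exercised side by side. The sample prompt and the use of infer are illustrative assumptions, not part of this diff:

from syncode import Syncode

# Unconstrained baseline: decoding proceeds with no grammar mask
llm = Syncode(
    mode="original",
    model="google/gemma-2-2b-it",
    grammar="json",
    max_new_tokens=400,
)

# Constrained variant: tokens that cannot extend a valid JSON prefix are masked out
syn_llm = Syncode(
    mode="grammar_mask",
    model="google/gemma-2-2b-it",
    grammar="json",
    max_new_tokens=400,
)

prompt = "Extract name and age from: 'Alice is 30.'\nJSON:\n"
print(llm.infer(prompt))      # may wrap the JSON in extra prose
print(syn_llm.infer(prompt))  # constrained to stay within the JSON grammar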
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [07:11<00:00, 4.32s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result: {'pass@1': 0.41}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Accuracy of original model with original prompt ending \"JSON:\"\n",
"_ = llm.evaluate(dataset=\"json_eval\", prompt_type=\"original\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [05:54<00:00, 3.54s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result: {'pass@1': 0.43}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Accuracy of the model with additionally explicitly specifying \"Generate only JSON\" in the prompt\n",
"_ = llm.evaluate(dataset=\"json_eval\", prompt_type=\"explicit\")"
]
},
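With the original prompt the unconstrained model reaches pass@1 = 0.41; explicitly instructing it to output only JSON raises this to 0.43. A pass@1-style check presumably validates each completion against the per-problem schema. A minimal sketch of such a check (the field names follow the NousResearch/json-mode-eval dataset card; SynCode's actual scoring code may differ):

import json

import jsonschema
from datasets import load_dataset

ds = load_dataset("NousResearch/json-mode-eval", split="train")

def passes(output: str, schema_str: str) -> bool:
    """Return True iff output parses as JSON and satisfies the schema."""
    try:
        obj = json.loads(output)
        jsonschema.validate(obj, json.loads(schema_str))
        return True
    except (json.JSONDecodeError, jsonschema.ValidationError):
        return False

# Example: score one hypothetical completion against the first problem's schema
print(passes('{"name": "Alice", "age": 30}', ds[0]["schema"]))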
{
"cell_type": "code",
"execution_count": 2,
14 changes: 10 additions & 4 deletions syncode/evaluation/json_eval.py
@@ -17,7 +17,8 @@ def run_json_eval(
syncode,
out_path: Optional[str],
debug_task_id: Optional[int] = None,
logger=common.EmptyLogger()
logger=common.EmptyLogger(),
prompt_type='original'
):
problems = syncode.dataset.problems
if syncode.grammar_decoder is not None:
@@ -33,7 +34,7 @@
results = defaultdict(list)

for task_id, problem in enumerate(problems):
output = JSONEval.run_eval_for_task(syncode, syncode.num_samples, problem, samples, pbar, task_id)
output = JSONEval.run_eval_for_task(syncode, syncode.num_samples, problem, samples, pbar, task_id, prompt_type=prompt_type)
if debug_task_id is not None:
return output
outputs.append(output)
@@ -53,12 +54,17 @@
logger.close()
return outputs

def run_eval_for_task(syncode, num_samples_per_task, problem, samples, pbar, task_id):
def run_eval_for_task(syncode, num_samples_per_task, problem, samples, pbar, task_id, prompt_type='original'):
"""
Run evaluation for a single task; prompt_type selects the suffix appended to the prompt.
"""

if prompt_type == 'original':
problem["prompt"][0]['content'] = f"{problem['prompt'][0]['content']}\nJSON:\n"
else:
problem["prompt"][0]['content'] = f"{problem['prompt'][0]['content']}\nOnly output JSON.\nJSON:\n"

prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False)
prompt = f"{prompt}\nJSON:\n"

batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task)
for completion_id, completion in enumerate(batch_completions):
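The prompt construction above is the core of this change. A self-contained restatement, with a hypothetical sample problem (the two suffixes are taken verbatim from the diff):

# How the two prompt_type variants rewrite the user message before chat templating
problem = {"prompt": [{"role": "user", "content": "Extract name and age from: 'Alice is 30.'"}]}

def build_user_content(problem: dict, prompt_type: str = "original") -> str:
    content = problem["prompt"][0]["content"]
    if prompt_type == "original":
        return f"{content}\nJSON:\n"
    # "explicit": additionally instruct the model to emit nothing but JSON
    return f"{content}\nOnly output JSON.\nJSON:\n"

print(build_user_content(problem, "original"))
print(build_user_content(problem, "explicit"))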
8 changes: 5 additions & 3 deletions syncode/infer.py
@@ -96,6 +96,7 @@ def __init__(
self.parse_output_only = parse_output_only

# Set the grammar
self.language = grammar
self.grammar = Grammar(grammar) if self.is_grammar_mode() else None

# Load model
@@ -143,7 +144,8 @@ def evaluate(
out_path: str=None,
num_few_shot:int=0,
logger=common.EmptyLogger(),
task_id=None
task_id=None,
prompt_type='original' # For JSONEval: "original" or "explicit"
) -> dict:
"""
Run evaluation on the model:
@@ -159,7 +161,7 @@
logger.open()

# Load the dataset
self.dataset = Dataset(dataset, language=self.grammar.name, num_few_shot=num_few_shot)
self.dataset = Dataset(dataset, language=self.language, num_few_shot=num_few_shot)

if self.dataset.type == "code":
output = CodeEval.run_code_eval(self, self.num_samples, out_path, format_tabs=True, debug_task_id=task_id, logger=logger)
@@ -170,7 +172,7 @@
elif self.dataset.type == "fol":
output = FOLEval.run_eval(self, out_path, debug_task_id=task_id)
elif self.dataset.type == "json":
output = JSONEval.run_json_eval(self, out_path, debug_task_id=task_id, logger=logger)
output = JSONEval.run_json_eval(self, out_path, debug_task_id=task_id, logger=logger, prompt_type=prompt_type)
else:
raise ValueError(f"Dataset type {self.dataset.type} not supported")
logger.close()
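The self.language = grammar line is what makes evaluate usable in mode="original": in that mode no grammar object is built, so the old self.grammar.name access would fail. A minimal reproduction of the failure and the fix (class names hypothetical):

class Grammar:
    def __init__(self, name: str):
        self.name = name

class Before:
    def __init__(self, grammar: str, grammar_mode: bool):
        self.grammar = Grammar(grammar) if grammar_mode else None

    def dataset_language(self) -> str:
        return self.grammar.name  # AttributeError when grammar_mode is False

class After:
    def __init__(self, grammar: str, grammar_mode: bool):
        self.language = grammar  # keep the raw grammar name in both modes
        self.grammar = Grammar(grammar) if grammar_mode else None

    def dataset_language(self) -> str:
        return self.language  # safe even when self.grammar is None

print(After("json", grammar_mode=False).dataset_language())  # prints "json"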
