Update incrementality handling of the parser #113

Merged
merged 1 commit on Sep 18, 2024
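This PR reworks how the incremental parser caches and restores its state so that a single parser instance can serve multiple samples or beams. As reflected in the diffs below: `SyncodeLogitsProcessor` now keeps one shared `IncrementalParser` instead of a list of per-sample parsers; cached parser states are keyed by a hash of the lexer-token prefix rather than by token position, and restoration resumes from the longest previously stored prefix; `Syncode` and `compile_and_run` gain an optional `seed` argument that is forwarded to `torch.manual_seed`; and a new notebook test, `notebooks/tests/beam_search2.ipynb`, exercises constrained generation with `num_beams=2`. Short hedged sketches follow the relevant file diffs below.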
144 changes: 144 additions & 0 deletions notebooks/tests/beam_search2.ipynb
@@ -0,0 +1,144 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}\n"
]
}
],
"source": [
"from syncode import SyncodeLogitsProcessor, Grammar\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"import lark\n",
"\n",
"grammar_str = r\"\"\"\n",
"start: item (\",\" item)* \n",
"\n",
"item: \"'\" name \"'\"\n",
" | \"\\\"\" name \"\\\"\"\n",
"\n",
"name: \"Alice\" \n",
" | \"Bob\" \n",
" | \"Carol\" \n",
" | \"Dave\"\n",
" | \"Eve\"\n",
"\"\"\"\n",
"\n",
"device = \"cuda\"\n",
"model_name = \"deepseek-ai/deepseek-coder-1.3b-base\"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\").eval()\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"syncode_grammar = Grammar(grammar_str)\n",
"parser = lark.Lark(grammar_str)\n",
"\n",
"prompt = \"A list of male first names:\\n\"\n",
"\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
"\n",
"constrain = True\n",
"\n",
"args = {\n",
" \"max_new_tokens\" : 128,\n",
" \"do_sample\" : True,\n",
" \"num_beams\" : 2,\n",
" \"num_return_sequences\" : 2,\n",
" \"pad_token_id\" : tokenizer.eos_token_id,\n",
"}\n",
"\n",
"syncode_logits_processor = SyncodeLogitsProcessor(\n",
" grammar=syncode_grammar, \n",
" tokenizer=tokenizer, \n",
" parse_output_only=True, \n",
" num_samples=args[\"num_beams\"],\n",
" mode=\"grammar_strict\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\"'Alice'\", '\"Alice\"']\n",
">>> COMPLETION 0\n",
"\n",
"CAN PARSE\n",
"\n",
"'Alice'\n",
"\n",
"[\"'\", 'A', 'lic', 'e', \"'\", '<|end▁of▁sentence|>']\n",
"\n",
">>> COMPLETION 1\n",
"\n",
"CAN PARSE\n",
"\n",
"\"Alice\"\n",
"\n",
"['\"', 'A', 'lic', 'e', '\"', '<|end▁of▁sentence|>']\n",
"\n"
]
}
],
"source": [
"syncode_logits_processor.reset(prompt)\n",
"\n",
"outputs = model.generate(\n",
" inputs,\n",
" logits_processor=[syncode_logits_processor] if constrain else None,\n",
" **args,\n",
")\n",
"\n",
"outputs = [o[len(inputs[0]):] for o in outputs]\n",
"completions = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
"completions_tokens = [[tokenizer.decode(tok) for tok in output] for output in outputs]\n",
"print(completions)\n",
"\n",
"\n",
"for i, (c, toks) in enumerate(zip(completions, completions_tokens)):\n",
" print(f\">>> COMPLETION {i}\\n\")\n",
" try:\n",
" tree = parser.parse(c)\n",
" print(\"CAN PARSE\\n\")\n",
" except:\n",
" print(\"CANNOT PARSE\\n\") \n",
" print(f\"{c}\\n\")\n",
" print(f\"{toks}\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "codex",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
12 changes: 5 additions & 7 deletions syncode/grammar_decoder.py
@@ -59,8 +59,8 @@ def __init__(self,
mode=mode,
)

# Create parsers
self.inc_parsers: Iterator[IncrementalParser] = [create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace) for _ in range(self.batch_size)]
# Create parser
self.inc_parser: IncrementalParser = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)


def _log_current_status(self, partial_code, r: ParseResult):
@@ -97,8 +97,7 @@ def reset(self, prompt: str):
else:
self.start_from = 0

for p in self.inc_parsers:
p.reset()
self.inc_parser.reset()


def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) -> bool:
@@ -117,7 +116,7 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
partial_code = self._get_partial_codes(input_ids)[0]

try:
r = self.inc_parsers[0].get_acceptable_next_terminals(partial_code)
r = self.inc_parser.get_acceptable_next_terminals(partial_code)
except Exception as e:
self.logger.log(f"Exception while parsing:\n {e}")
return False
@@ -135,12 +134,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
# start_from is used for choosing where the parsing should start
debug = True
partial_codes = self._get_partial_codes(input_ids)
assert len(partial_codes) == len(self.inc_parsers), "Number of partial codes should match the number of parsers. Make sure that the argument `num_samples` is set correctly in the SyncodeLogitsProcessor."

for idx, partial_code in enumerate(partial_codes):
## Parsing
try: # returns the accept sequences that are currently accepted.
r = self.inc_parsers[idx].get_acceptable_next_terminals(partial_code)
r = self.inc_parser.get_acceptable_next_terminals(partial_code)
except Exception as e:
if self.dev_mode == True:
raise e
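With this change the logits processor parses every beam's partial code with the same `IncrementalParser`, and the old assert tying `num_samples` to the number of parsers is gone; correctness instead comes from the prefix-keyed state cache introduced in `incremental_parser.py` below. A self-contained toy illustration (plain Python tuples, not SynCode objects) of why an integer position is an ambiguous cache key across beams while a prefix hash is not:

```python
# Two beams can reach the same length with different token prefixes.
beam_a = ("'", "Alice", "'")
beam_b = ('"', "Alice", '"')

# Position-based key: identical for both beams, so their cached parser states
# would overwrite each other.
print(len(beam_a) - 1 == len(beam_b) - 1)   # True

# Prefix-hash key (what the new code uses): distinct, barring hash collisions.
print(hash(beam_a) == hash(beam_b))         # False
```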
10 changes: 8 additions & 2 deletions syncode/infer.py
@@ -2,6 +2,7 @@
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import fire
import syncode.common as common
import torch
from syncode.language_model import HuggingFaceModel
from syncode.grammar_decoder import SyncodeLogitsProcessor
from typing import Optional, Literal
@@ -14,9 +15,9 @@
from syncode.evaluation.fol_eval import FOLEval


def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", num_samples=1, grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", task_id=None, **kwargs):
def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", num_samples=1, grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", task_id=None, seed=None, **kwargs):

syncode = Syncode(model, mode=mode, quantize=quantize, device=device, num_samples=num_samples, grammar=grammar, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, **kwargs)
syncode = Syncode(model, mode=mode, quantize=quantize, device=device, num_samples=num_samples, grammar=grammar, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, **kwargs)

if dataset == "input":
syncode.infer()
@@ -71,6 +72,7 @@ def __init__(
log_level: int = 1,
new_mask_store: bool = False,
parser: Literal["lr", "lalr"] = "lalr",
seed: Optional[int] = None,
**kwargs
):
# Check inputs
@@ -90,6 +92,10 @@ def __init__(
self.chat_mode = chat_mode
self.log_level = log_level

# Set seed
if seed is not None:
torch.manual_seed(seed)

if self.chat_mode:
self.parse_output_only = True
else:
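`infer.py` now accepts an optional `seed` in both `compile_and_run` and `Syncode.__init__` and calls `torch.manual_seed(seed)` when it is set. A minimal usage sketch; note that the top-level `Syncode` import and the `infer(prompt)` call are assumptions about the public API, since this diff only shows the constructor change (model and grammar are reused from the notebook above):

```python
from syncode import Syncode  # assumed import path

# Toy grammar in the style of the notebook's name grammar.
grammar_str = r"""
start: "Alice" | "Bob" | "Carol"
"""

# `seed` is the new argument: it is forwarded to torch.manual_seed, so
# sampling-based (do_sample / beam search) generations repeat across runs.
syncode = Syncode(
    "deepseek-ai/deepseek-coder-1.3b-base",
    grammar=grammar_str,
    num_samples=2,
    seed=0,
)
output = syncode.infer("A list of male first names:\n")  # assumed call signature
```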
3 changes: 1 addition & 2 deletions syncode/parsers/go_parser.py
@@ -23,8 +23,6 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
# Restore the previous state of the parser
self._restore_recent_parser_state(lexer_tokens)

self.prev_lexer_tokens = lexer_tokens # Set the previous lexer tokens

# Parse the tokens
self.time_accepts = 0
parse_incomplete = False
@@ -44,6 +42,7 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
# Store the current state of the parser
self._store_parser_state(
self.cur_pos-1,
lexer_tokens,
interactive.parser_state.copy(),
self._accepts(interactive)
)
49 changes: 28 additions & 21 deletions syncode/parsers/incremental_parser.py
@@ -22,17 +22,17 @@ def __init__(self, base_parser, logger: Optional[common.Logger]=None, ignore_whi
self.logger = logger if logger is not None else common.EmptyLogger()
self.interactive = self.base_parser.parse_interactive('')
self.parsed_lexer_tokens: list = []
self.prev_lexer_tokens: list[Token] = [] # To enable going back to old state of the parser
self.cur_pos_to_parser_state: dict[int, Tuple[Any, set, set, Optional[list], list]] = {} # parser_state, cur_ac_terminals, next_ac_terminals, indent_levels (optional), dedent_queue

# parser_state, cur_ac_terminals, next_ac_terminals, indent_levels (optional), dedent_queue
self.cur_pos_to_parser_state: dict[int, Tuple[Any, set, set, Optional[list], list]] = {}

self.cur_ac_terminals: set = set()
self.next_ac_terminals: set = self._accepts(self.interactive)

def reset(self):
"""
Resets the parser to the initial state.
"""
self.prev_lexer_tokens = []
self.cur_pos_to_parser_state = {}
self.lexer_pos = 0

@@ -47,18 +47,21 @@ def _set_initial_parser_state(self):
self.cur_ac_terminals = set()
self.next_ac_terminals = self._accepts(self.interactive)

def _store_parser_state(self, pos: int, parser_state, accepts: set, indent_levels: Optional[list] = None):
def _store_parser_state(self, pos: int, lexer_tokens: Iterable[Token], parser_state, accepts: set, indent_levels: Optional[list] = None):
cur_ac_terminals = self.next_ac_terminals
next_ac_terminals = accepts

# Create a hash of lexer tokens till position pos
key = self._get_hash(lexer_tokens[:pos+1])

# parser_state, cur_ac_terminals, next_ac_terminals, indent_levels, dedent_queue
self.cur_pos_to_parser_state[pos] = (copy.deepcopy(self.parsed_lexer_tokens), parser_state, cur_ac_terminals, next_ac_terminals, indent_levels, copy.deepcopy(self.dedent_queue))
self.cur_pos_to_parser_state[key] = (copy.deepcopy(self.parsed_lexer_tokens), parser_state, cur_ac_terminals, next_ac_terminals, indent_levels, copy.deepcopy(self.dedent_queue))

self.cur_ac_terminals = copy.deepcopy(cur_ac_terminals)
self.next_ac_terminals = copy.deepcopy(next_ac_terminals)

def _restore_parser_state(self, pos: int):
parsed_lexer_tokens, parser_state, cur_ac_terminals, next_ac_terminals, indent_levels, dedent_queue = self.cur_pos_to_parser_state[pos]
def _restore_parser_state(self, key: int):
parsed_lexer_tokens, parser_state, cur_ac_terminals, next_ac_terminals, indent_levels, dedent_queue = self.cur_pos_to_parser_state[key]

self.interactive.parser_state = parser_state.copy()
self.parsed_lexer_tokens = copy.deepcopy(parsed_lexer_tokens)
@@ -98,20 +101,26 @@ def _restore_recent_parser_state(self, lexer_tokens):
"""
Restores the parser state to the most recent prefix matching state that was stored.
"""
max_matching_index = -1
for i in range(min(len(self.prev_lexer_tokens), len(lexer_tokens))):
if self.prev_lexer_tokens[i] != lexer_tokens[i]:
max_stored_index = -1
idx = len(lexer_tokens)-1

while idx >= 0:
# TODO: This is not the best way to hash the lexer tokens. We should use a better hashing mechanism with some sliding window.
key = self._get_hash(lexer_tokens[:idx+1])
if key in self.cur_pos_to_parser_state:
max_stored_index = idx
break
if i in self.cur_pos_to_parser_state:
max_matching_index = i
idx -= 1

if max_matching_index != -1:
self.cur_pos = max_matching_index + 1
assert (max_matching_index) in self.cur_pos_to_parser_state
self._restore_parser_state(max_matching_index)
if max_stored_index != -1:
self.cur_pos = max_stored_index + 1
key = self._get_hash(lexer_tokens[:max_stored_index+1])
self._restore_parser_state(key)
else:
self._set_initial_parser_state()

def _get_hash(self, lexer_tokens: Iterable[Token]) -> int:
return hash(tuple(lexer_tokens))

def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
"""
@@ -123,10 +132,7 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
self.next_ac_terminals = self._accepts(interactive)

# Restore the previous state of the parser
if len(self.prev_lexer_tokens) > 0:
self._restore_recent_parser_state(lexer_tokens)

self.prev_lexer_tokens = lexer_tokens # Set the previous lexer tokens
self._restore_recent_parser_state(lexer_tokens)

# Parse the tokens
self.time_accepts = 0
Expand All @@ -141,7 +147,8 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:

# Store the current state of the parser
self._store_parser_state(
self.cur_pos-1,
self.cur_pos-1,
lexer_tokens,
interactive.parser_state.copy(),
self._accepts(interactive))

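This file carries the core of the change: `_store_parser_state` now keys each saved state by a hash of the lexer-token prefix (`_get_hash`), and `_restore_recent_parser_state` walks back from the full prefix to the longest stored one instead of diffing against `prev_lexer_tokens` (the TODO notes that a sliding-window hash would be cheaper than rehashing every prefix). A self-contained toy sketch of the same caching scheme, using a placeholder string in place of a real Lark parser state:

```python
from typing import Any, Dict, Optional, Sequence, Tuple


class PrefixStateCache:
    """Toy analogue of the cache in IncrementalParser: states are keyed by a
    hash of the lexer-token prefix rather than by an absolute position."""

    def __init__(self) -> None:
        self._states: Dict[int, Any] = {}

    @staticmethod
    def _key(tokens: Sequence[str]) -> int:
        # Same idea as IncrementalParser._get_hash: hash the whole prefix.
        return hash(tuple(tokens))

    def store(self, tokens: Sequence[str], pos: int, state: Any) -> None:
        # Cache the state reached after consuming tokens[:pos + 1].
        self._states[self._key(tokens[:pos + 1])] = state

    def restore_longest(self, tokens: Sequence[str]) -> Tuple[int, Optional[Any]]:
        # Walk back from the full prefix until a stored prefix is found,
        # mirroring _restore_recent_parser_state; fall back to a fresh start.
        for idx in range(len(tokens) - 1, -1, -1):
            key = self._key(tokens[: idx + 1])
            if key in self._states:
                return idx + 1, self._states[key]  # resume parsing at idx + 1
        return 0, None


# Beam A stores a state for each of its prefixes; beam B shares only the first
# two tokens, yet still resumes from that common prefix instead of reparsing.
cache = PrefixStateCache()
beam_a = ["'", "Alice", "'"]
for pos in range(len(beam_a)):
    cache.store(beam_a, pos, state=f"state after {beam_a[:pos + 1]}")

beam_b = ["'", "Alice", ","]
resume_at, state = cache.restore_longest(beam_b)
print(resume_at, state)  # -> 2 state after ["'", 'Alice']
```

Because the key is the whole prefix, states stored while extending one beam can be reused by any other beam that shares that prefix, which is what allows `grammar_decoder.py` to keep a single shared parser.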
4 changes: 2 additions & 2 deletions syncode/parsers/python_parser.py
@@ -40,7 +40,6 @@ def get_acceptable_next_terminals(self, code) -> ParseResult:
# Restore the previous state of the parser
self._restore_recent_parser_state(lexer_tokens)

self.prev_lexer_tokens = lexer_tokens # Set the previous lexer tokens for retrieving the state of the parser in next iterations
next_ac_indents = None

# Parse the tokens
@@ -69,7 +68,8 @@ def get_acceptable_next_terminals(self, code) -> ParseResult:

# Store the current state of the parser
self._store_parser_state(
self.cur_pos-1,
self.cur_pos-1,
lexer_tokens,
interactive.parser_state.copy(),
self._accepts(interactive),
indent_levels=copy.copy(self.indent_level)