Skip to content

Commit

Permalink
Improve precision of grammar_strict
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamugare committed Aug 4, 2024
1 parent b8d1198 commit d681190
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
7 changes: 7 additions & 0 deletions syncode/dfa_mask_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,13 @@ def _lookup_next_tokens(self, dfa_states: Iterable[DFAState], r: ParseResult) ->
overapprox_token_ids |= self._lookup_table.complete_case_lookup(dfa_state)
elif len(accept_sequence) == 2:
overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(dfa_state, accept_sequence[1])
elif len(accept_sequence) == 3:
# This is useful in under-approximating `grammar_strict` mode as they help improve the precision of SynCode
if self._mode == 'grammar_strict':
# If the DFA state is a final state we can jump to the start of next terminal
if self._dfas.is_final(dfa_state):
ignore_init_state = self._dfas.initial(accept_sequence[1])
overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(ignore_init_state, accept_sequence[2])
else:
raise ValueError(f"Invalid accept sequence: {accept_sequence}")
return overapprox_token_ids
Expand Down
10 changes: 8 additions & 2 deletions syncode/parse_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,14 @@ def from_accept_terminals(cur_accept_terminals, next_accept_terminals, remainder
for t2 in next_accept_terminals:
accept_sequences.add(AcceptSequence([final_terminal, t2]))
if ignore_terminals is not None:
for t2 in ignore_terminals:
accept_sequences.add(AcceptSequence([final_terminal, t2]))
for tignore in ignore_terminals:
accept_sequences.add(AcceptSequence([final_terminal, tignore]))

# These 3 length accept sequences are useful in under-approximating
# `grammar_strict` mode as they help improve the precision of SynCode
for tignore in ignore_terminals:
for t2 in next_accept_terminals:
accept_sequences.add(AcceptSequence([final_terminal, tignore, t]))
else:
accept_sequences.add(AcceptSequence([t]))

Expand Down
28 changes: 28 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,35 @@ def test_mask_store_misc(self):
dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger())
mask = dfa_mask.get_accept_mask(r, get_list=True)
self.assertNotIn(' (', mask)

@staticmethod
def essay_grammar():
# A Lark grammar for paragraphs in text
return """
start: paragraph+
?paragraph: sentence+
?sentence: word+ punctuation
word: /[a-zA-Z0-9]+/ | COMMA | SINGLE_QUOTE | ESCAPED_DOUBLE_QUOTE
punctuation: /[.!?]/
COMMA: ","
SINGLE_QUOTE: "'"
ESCAPED_DOUBLE_QUOTE: "\\\""
%import common.WS
%ignore WS
"""

def test_mask_store_misc2(self):
grammar = Grammar(TestParserMisc.essay_grammar())
model = 'microsoft/Phi-3-mini-128k-instruct'
tokenizer = common.load_tokenizer(model)
inc_parser = create_parser(grammar)
r = inc_parser.get_acceptable_next_terminals("I")
dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger())
mask = dfa_mask.get_accept_mask(r, get_list=True)
self.assertIn(' have', mask)


def test_parser_calc(self):
inc_parser = create_parser(Grammar('calc'))
partial_code = "113 + 235 + 17"
Expand Down

0 comments on commit d681190

Please sign in to comment.