diff --git a/syncode/dfa_mask_store.py b/syncode/dfa_mask_store.py index a51f815b..402d7855 100644 --- a/syncode/dfa_mask_store.py +++ b/syncode/dfa_mask_store.py @@ -453,6 +453,13 @@ def _lookup_next_tokens(self, dfa_states: Iterable[DFAState], r: ParseResult) -> overapprox_token_ids |= self._lookup_table.complete_case_lookup(dfa_state) elif len(accept_sequence) == 2: overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(dfa_state, accept_sequence[1]) + elif len(accept_sequence) == 3: + # This is useful in under-approximating `grammar_strict` mode as they help improve the precision of SynCode + if self._mode == 'grammar_strict': + # If the DFA state is a final state we can jump to the start of next terminal + if self._dfas.is_final(dfa_state): + ignore_init_state = self._dfas.initial(accept_sequence[1]) + overapprox_token_ids |= self._lookup_next_tokens_for_dfa_state(ignore_init_state, accept_sequence[2]) else: raise ValueError(f"Invalid accept sequence: {accept_sequence}") return overapprox_token_ids diff --git a/syncode/parse_result.py b/syncode/parse_result.py index 26b1aeb3..830aca5e 100644 --- a/syncode/parse_result.py +++ b/syncode/parse_result.py @@ -62,8 +62,14 @@ def from_accept_terminals(cur_accept_terminals, next_accept_terminals, remainder for t2 in next_accept_terminals: accept_sequences.add(AcceptSequence([final_terminal, t2])) if ignore_terminals is not None: - for t2 in ignore_terminals: - accept_sequences.add(AcceptSequence([final_terminal, t2])) + for tignore in ignore_terminals: + accept_sequences.add(AcceptSequence([final_terminal, tignore])) + + # These 3 length accept sequences are useful in under-approximating + # `grammar_strict` mode as they help improve the precision of SynCode + for tignore in ignore_terminals: + for t2 in next_accept_terminals: + accept_sequences.add(AcceptSequence([final_terminal, tignore, t])) else: accept_sequences.add(AcceptSequence([t])) diff --git a/tests/test_misc.py b/tests/test_misc.py index 17a9f617..ab74a2ae 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -32,7 +32,35 @@ def test_mask_store_misc(self): dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger()) mask = dfa_mask.get_accept_mask(r, get_list=True) self.assertNotIn(' (', mask) + + @staticmethod + def essay_grammar(): + # A Lark grammar for paragraphs in text + return """ + start: paragraph+ + ?paragraph: sentence+ + ?sentence: word+ punctuation + word: /[a-zA-Z0-9]+/ | COMMA | SINGLE_QUOTE | ESCAPED_DOUBLE_QUOTE + punctuation: /[.!?]/ + COMMA: "," + SINGLE_QUOTE: "'" + ESCAPED_DOUBLE_QUOTE: "\\\"" + + %import common.WS + %ignore WS + """ + + def test_mask_store_misc2(self): + grammar = Grammar(TestParserMisc.essay_grammar()) + model = 'microsoft/Phi-3-mini-128k-instruct' + tokenizer = common.load_tokenizer(model) + inc_parser = create_parser(grammar) + r = inc_parser.get_acceptable_next_terminals("I") + dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=grammar, tokenizer=tokenizer, use_cache=False, logger=common.EmptyLogger()) + mask = dfa_mask.get_accept_mask(r, get_list=True) + self.assertIn(' have', mask) + def test_parser_calc(self): inc_parser = create_parser(Grammar('calc')) partial_code = "113 + 235 + 17"