From bc6960d593a535c54344482f16e3bdba278ebaaf Mon Sep 17 00:00:00 2001 From: Shubham Ugare Date: Tue, 3 Sep 2024 11:51:27 -0500 Subject: [PATCH] Fix SQL grammar issues --- syncode/evaluation/sql_eval.py | 21 ++++++++++++++++----- syncode/parsers/grammars/sql_grammar.lark | 10 +++++----- tests/test_grammar_sql.py | 7 +++++++ 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/syncode/evaluation/sql_eval.py b/syncode/evaluation/sql_eval.py index e1d02441..911fcf15 100644 --- a/syncode/evaluation/sql_eval.py +++ b/syncode/evaluation/sql_eval.py @@ -1,3 +1,4 @@ +import os from typing import Optional from tqdm import tqdm from mxeval.data import write_jsonl @@ -41,11 +42,21 @@ def run_eval(syncode, out_path: Optional[str], debug_task_id: Optional[int] = No pbar.close() # Run evaluation script + SQLEval.compute_accuracy(samples, predict_file) + + if out_path is not None and debug_task_id is None: write_jsonl(out_path, samples) + + @staticmethod + def compute_accuracy(samples, predict_file): from syncode.utils.sql_spider_eval.evaluation import evaluate - gold_file = "syncode/utils/sql_spider_eval/evaluation_examples/gold_example.txt" - tables = "syncode/utils/sql_spider_eval/evaluation_examples/examples/tables.json" - databses = "syncode/utils/sql_spider_eval/databases" + + # Get current dir path + current_dir = os.path.dirname(os.path.realpath(__file__)) + + # Set paths + gold_file = f"{current_dir}/../utils/sql_spider_eval/evaluation_examples/gold_example.txt" + tables = f"{current_dir}/..//utils/sql_spider_eval/evaluation_examples/examples/tables.json" + databses = f"{current_dir}/..//utils/sql_spider_eval/databases" + scores, error_types = evaluate(predict_file, gold_file, databses, etype="all", table=tables, result_jsonl=samples) print(f"Scores: {scores['all']}\n Error types: {error_types}") - - if out_path is not None and debug_task_id is None: write_jsonl(out_path, samples) diff --git a/syncode/parsers/grammars/sql_grammar.lark b/syncode/parsers/grammars/sql_grammar.lark index 81ba167a..64f01f76 100644 --- a/syncode/parsers/grammars/sql_grammar.lark +++ b/syncode/parsers/grammars/sql_grammar.lark @@ -138,7 +138,7 @@ TYPENAME: "object"i | "date"i | "category"i | "string"i -AGGREGATION.8: ("sum("i | "avg("i | "min("i | "max("i | "count("i "distinct"i | "count("i) +AGGREGATION.8: ("SUM("i | "AVG("i | "MIN("i | "MAX("i | "COUNT("i "DISTINCT"i | "COUNT("i) alias: name -> alias_string _window_name: name limit_count: integer_ -> limit_count @@ -152,8 +152,8 @@ bool_parentheses: comparison_type comparison_type: equals | not_equals | greater_than | less_than | greater_than_or_equal | less_than_or_equal | between | in_expr | not_in_expr | subquery_in | is_null | is_not_null equals: expression_math "=" expression_math -is_null: expression_math "is"i "null"i -is_not_null: expression_math "is"i "not"i "null"i +is_null: expression_math "IS"i "NULL"i +is_not_null: expression_math "IS"i "NOT"i "NULL"i not_equals: expression_math ("<>" | "!=") expression_math greater_than: expression_math ">" expression_math less_than: expression_math "<" expression_math @@ -167,8 +167,8 @@ not_in_expr: expression_math "NOT"i "IN"i "(" [expression_math ","]* expression_ | number_expr -> number | /'([^'])+'|''/ -> string | timestamp_expression -> timestamp_expression -boolean: "true"i -> true - | "false"i -> false +boolean: "TRUE"i -> true + | "FALSE"i -> false ?number_expr: product ?product: NUMBER diff --git a/tests/test_grammar_sql.py b/tests/test_grammar_sql.py index beb244f1..1490e494 100644 --- a/tests/test_grammar_sql.py +++ b/tests/test_grammar_sql.py @@ -26,4 +26,11 @@ def test_sql_parser2(self): r = inc_parser.get_acceptable_next_terminals(partial_code) assert r.remainder == '' assert r.remainder_state == RemainderState.COMPLETE + + def test_sql_parser3(self): + inc_parser.reset() + partial_code = "SELECT stuid FROM has_pet WHERE pettype = 'cat' AND has_pet.stuid NOT" + r = inc_parser.get_acceptable_next_terminals(partial_code) + assert r.remainder == 'NOT' + assert r.remainder_state == RemainderState.MAYBE_COMPLETE \ No newline at end of file