Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SQL grammar issues #107

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions syncode/evaluation/sql_eval.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Optional
from tqdm import tqdm
from mxeval.data import write_jsonl
Expand Down Expand Up @@ -41,11 +42,21 @@ def run_eval(syncode, out_path: Optional[str], debug_task_id: Optional[int] = No
pbar.close()

# Run evaluation script
SQLEval.compute_accuracy(samples, predict_file)

if out_path is not None and debug_task_id is None: write_jsonl(out_path, samples)

@staticmethod
def compute_accuracy(samples, predict_file):
from syncode.utils.sql_spider_eval.evaluation import evaluate
gold_file = "syncode/utils/sql_spider_eval/evaluation_examples/gold_example.txt"
tables = "syncode/utils/sql_spider_eval/evaluation_examples/examples/tables.json"
databses = "syncode/utils/sql_spider_eval/databases"

# Get current dir path
current_dir = os.path.dirname(os.path.realpath(__file__))

# Set paths
gold_file = f"{current_dir}/../utils/sql_spider_eval/evaluation_examples/gold_example.txt"
tables = f"{current_dir}/..//utils/sql_spider_eval/evaluation_examples/examples/tables.json"
databses = f"{current_dir}/..//utils/sql_spider_eval/databases"

scores, error_types = evaluate(predict_file, gold_file, databses, etype="all", table=tables, result_jsonl=samples)
print(f"Scores: {scores['all']}\n Error types: {error_types}")

if out_path is not None and debug_task_id is None: write_jsonl(out_path, samples)
10 changes: 5 additions & 5 deletions syncode/parsers/grammars/sql_grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ TYPENAME: "object"i
| "date"i
| "category"i
| "string"i
AGGREGATION.8: ("sum("i | "avg("i | "min("i | "max("i | "count("i "distinct"i | "count("i)
AGGREGATION.8: ("SUM("i | "AVG("i | "MIN("i | "MAX("i | "COUNT("i "DISTINCT"i | "COUNT("i)
alias: name -> alias_string
_window_name: name
limit_count: integer_ -> limit_count
Expand All @@ -152,8 +152,8 @@ bool_parentheses: comparison_type
comparison_type: equals | not_equals | greater_than | less_than | greater_than_or_equal
| less_than_or_equal | between | in_expr | not_in_expr | subquery_in | is_null | is_not_null
equals: expression_math "=" expression_math
is_null: expression_math "is"i "null"i
is_not_null: expression_math "is"i "not"i "null"i
is_null: expression_math "IS"i "NULL"i
is_not_null: expression_math "IS"i "NOT"i "NULL"i
not_equals: expression_math ("<>" | "!=") expression_math
greater_than: expression_math ">" expression_math
less_than: expression_math "<" expression_math
Expand All @@ -167,8 +167,8 @@ not_in_expr: expression_math "NOT"i "IN"i "(" [expression_math ","]* expression_
| number_expr -> number
| /'([^'])+'|''/ -> string
| timestamp_expression -> timestamp_expression
boolean: "true"i -> true
| "false"i -> false
boolean: "TRUE"i -> true
| "FALSE"i -> false
?number_expr: product

?product: NUMBER
Expand Down
7 changes: 7 additions & 0 deletions tests/test_grammar_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,11 @@ def test_sql_parser2(self):
r = inc_parser.get_acceptable_next_terminals(partial_code)
assert r.remainder == ''
assert r.remainder_state == RemainderState.COMPLETE

def test_sql_parser3(self):
inc_parser.reset()
partial_code = "SELECT stuid FROM has_pet WHERE pettype = 'cat' AND has_pet.stuid NOT"
r = inc_parser.get_acceptable_next_terminals(partial_code)
assert r.remainder == 'NOT'
assert r.remainder_state == RemainderState.MAYBE_COMPLETE

Loading