From 4424e657431475114fc46cb86889ac08d2ce62a9 Mon Sep 17 00:00:00 2001 From: Rongxin Liu Date: Fri, 29 Sep 2023 16:07:48 -0400 Subject: [PATCH] added support for 'CREATE VIEW' statement --- setup.py | 2 +- src/cs50/sql.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 1a8ef3a..bd33a73 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.5" + version="9.2.6" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 24690e3..8d07327 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -53,7 +53,7 @@ def __init__(self, url, **kwargs): import sqlalchemy import sqlalchemy.orm import threading - + # Temporary fix for missing sqlite3 module on the buildpack stack try: import sqlite3 @@ -149,15 +149,15 @@ def execute(self, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") - # Infer command from (unflattened) statement - for token in statements[0]: - if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - token_value = token.value.upper() - if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: - command = token_value - break - else: - command = None + # Infer command from flattened statement to a single string separated by spaces + full_statement = ' '.join(str(token) for token in statements[0].tokens if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]) + full_statement = full_statement.upper() + + # set of possible commands + commands = {"BEGIN", "CREATE VIEW", "DELETE", "INSERT", "SELECT", "START", "UPDATE"} + + # check if the full_statement starts with any command + command = next((cmd for cmd in commands if full_statement.startswith(cmd)), None) # Flatten statement tokens = list(statements[0].flatten()) @@ -393,6 +393,10 @@ def teardown_appcontext(exception): elif command in ["DELETE", "UPDATE"]: ret = result.rowcount + # If CREATE VIEW, return True + elif command == "CREATE VIEW": + ret = True + # If constraint violated except sqlalchemy.exc.IntegrityError as e: self._logger.error(termcolor.colored(_statement, "red"))