diff --git a/src/snowflake/connector/pandas_tools.py b/src/snowflake/connector/pandas_tools.py index 4c67c7f23..e32cf17b2 100644 --- a/src/snowflake/connector/pandas_tools.py +++ b/src/snowflake/connector/pandas_tools.py @@ -292,13 +292,14 @@ def write_pandas( def make_pd_writer( - quote_identifiers: bool = True, + **kwargs, ) -> Callable[ [ pandas.io.sql.SQLTable, sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, Iterable, Iterable, + Any, ], None, ]: @@ -311,16 +312,20 @@ def make_pd_writer( sf_connector_version_df = pd.DataFrame([('snowflake-connector-python', '1.0')], columns=['NAME', 'NEWEST_VERSION']) sf_connector_version_df.to_sql('driver_versions', engine, index=False, method=make_pd_writer()) - # to use quote_identifiers=False, + # to use parallel=1, quote_identifiers=False, from functools import partial sf_connector_version_df.to_sql( - 'driver_versions', engine, index=False, method=make_pd_writer(quote_identifiers=False))) + 'driver_versions', engine, index=False, method=make_pd_writer(parallel=1, quote_identifiers=False))) - Args: - quote_identifiers: if True (default), the pd_writer will pass quote identifiers to Snowflake. - If False, the created pd_writer will not quote identifiers (and typically coerced to uppercase by Snowflake) + This function takes arguments used by 'pd_writer' (excluding 'table', 'conn', 'keys', and 'data_iter') + Please refer to 'pd_writer' for documentation. """ - return partial(pd_writer, quote_identifiers=quote_identifiers) + if any(arg in kwargs for arg in ("table", "conn", "keys", "data_iter")): + raise ProgrammingError( + "Arguments 'table', 'conn', 'keys', and 'data_iter' are not supported parameters for make_pd_writer." + ) + + return partial(pd_writer, **kwargs) def pd_writer( @@ -328,7 +333,7 @@ def pd_writer( conn: sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, keys: Iterable, data_iter: Iterable, - quote_identifiers: bool = True, + **kwargs, ) -> None: """This is a wrapper on top of write_pandas to make it compatible with to_sql method in pandas. @@ -339,16 +344,20 @@ def pd_writer( sf_connector_version_df = pd.DataFrame([('snowflake-connector-python', '1.0')], columns=['NAME', 'NEWEST_VERSION']) sf_connector_version_df.to_sql('driver_versions', engine, index=False, method=pd_writer) - # to use quote_identifiers=False, see `make_pd_writer` - Args: table: Pandas package's table object. conn: SQLAlchemy engine object to talk to Snowflake. keys: Column names that we are trying to insert. data_iter: Iterator over the rows. - quote_identifiers: if True (default), quote identifiers passed to Snowflake. If False, identifiers are not - quoted (and typically coerced to uppercase by Snowflake) + + More parameters can be provided to be used by 'write_pandas' (excluding 'conn', 'df', 'table_name', and 'schema'), + Please refer to 'write_pandas' for documentation on other available parameters. """ + if any(arg in kwargs for arg in ("conn", "df", "table_name", "schema")): + raise ProgrammingError( + "Arguments 'conn', 'df', 'table_name', and 'schema' are not supported parameters for pd_writer." + ) + sf_connection = conn.connection.connection df = pandas.DataFrame(data_iter, columns=keys) write_pandas( @@ -357,5 +366,5 @@ def pd_writer( # Note: Our sqlalchemy connector creates tables case insensitively table_name=table.name.upper(), schema=table.schema, - quote_identifiers=quote_identifiers, + **kwargs, )