From f54fce1862da1f7de422957c49f77cd483b26f51 Mon Sep 17 00:00:00 2001
From: Bing Li
Date: Mon, 9 Jul 2018 14:40:40 -0700
Subject: [PATCH] Add continue_on_error parameter

---
 .../spark/snowflake/IssueSuite.scala      |  4 +-
 .../spark/snowflake/OnErrorSuite.scala    | 69 +++++++++++++++++++
 .../spark/snowflake/Parameters.scala      |  8 +++
 .../spark/snowflake/io/StageWriter.scala  | 18 ++++-
 4 files changed, 95 insertions(+), 4 deletions(-)
 create mode 100644 src/it/scala/net/snowflake/spark/snowflake/OnErrorSuite.scala

diff --git a/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
index 78e7d6f8..bf6640c8 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
@@ -27,7 +27,7 @@ class IssueSuite extends IntegrationSuiteBase {
         StructField("num", IntegerType, nullable = false)
       )
     )
-    val tt: String = "test_table_123"//s"tt_$randomSuffix"
+    val tt: String = s"tt_$randomSuffix"
     try {
       sparkSession.createDataFrame(
         sparkSession.sparkContext.parallelize(
@@ -52,8 +52,6 @@ class IssueSuite extends IntegrationSuiteBase {
         .option("dbtable", tt)
         .load()
 
-      //loadDf.show()
-      //print(s"-------------> size: ${loadDf.collect().length}")
       assert(loadDf.collect().length == 4)
 
     } finally {
diff --git a/src/it/scala/net/snowflake/spark/snowflake/OnErrorSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/OnErrorSuite.scala
new file mode 100644
index 00000000..808c2996
--- /dev/null
+++ b/src/it/scala/net/snowflake/spark/snowflake/OnErrorSuite.scala
@@ -0,0 +1,69 @@
+package net.snowflake.spark.snowflake
+
+import net.snowflake.client.jdbc.SnowflakeSQLException
+import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+class OnErrorSuite extends IntegrationSuiteBase {
+  lazy val table = s"spark_test_table_$randomSuffix"
+
+  lazy val schema = new StructType(
+    Array(
+      StructField("var", StringType, nullable = false)
+    )
+  )
+
+
+  lazy val df = sqlContext.createDataFrame(
+    sc.parallelize(
+      Seq(
+        Row("{\"dsadas\nadsa\":12311}"), // invalid JSON key
+        Row("{\"abc\":334}")
+      )
+    ),
+    schema
+  )
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    jdbcUpdate(s"create or replace table $table(var variant)")
+  }
+
+  override def afterAll(): Unit = {
+    jdbcUpdate(s"drop table $table")
+    super.afterAll()
+  }
+
+  test("continue_on_error off") {
+
+    assertThrows[SnowflakeSQLException] {
+      df.write
+        .format(SNOWFLAKE_SOURCE_NAME)
+        .options(connectorOptionsNoTable)
+        .option("dbtable", table)
+        .mode(SaveMode.Append)
+        .save()
+    }
+  }
+
+  test("continue_on_error on") {
+    df.write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option("continue_on_error", "on")
+      .option("dbtable", table)
+      .mode(SaveMode.Append)
+      .save()
+
+    val result = sqlContext.read
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      .option("dbtable", table)
+      .load()
+
+    assert(result.collect().length == 1)
+  }
+
+
+}
diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
index 18b40af5..f9148869 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -72,6 +72,7 @@ object Parameters {
   val PARAM_PURGE = knownParam("purge")
   val PARAM_TRUNCATE_TABLE = knownParam("truncate_table")
+  val PARAM_CONTINUE_ON_ERROR = knownParam("continue_on_error")
 
   val DEFAULT_S3_MAX_FILE_SIZE = (10 * 1000 * 1000).toString
   val MIN_S3_MAX_FILE_SIZE = 1000000
 
@@ -101,6 +102,7 @@ object Parameters {
     // * tempdir, dbtable and url have no default and they *must* be provided
     "diststyle" -> "EVEN",
     PARAM_USE_STAGING_TABLE -> "true",
+    PARAM_CONTINUE_ON_ERROR -> "off",
     PARAM_PREACTIONS -> "",
     PARAM_POSTACTIONS -> "",
     PARAM_AUTO_PUSHDOWN -> "on"
@@ -503,6 +505,12 @@ object Parameters {
       * Keep the table schema
       */
     def truncateTable: Boolean = isTrue(parameters(PARAM_TRUNCATE_TABLE))
+
+    /**
+      * Set the ON_ERROR option to CONTINUE in the COPY command
+      * TODO: add data validation on the Spark side instead of relying on the COPY command
+      */
+    def continueOnError: Boolean = isTrue(parameters(PARAM_CONTINUE_ON_ERROR))
   }
 
 }
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
index bc4145f8..1bd9fcdc 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -183,7 +183,20 @@
 
     //copy
     log.debug(Utils.sanitizeQueryText(copyStatement))
-    jdbcWrapper.executeInterruptibly(conn, copyStatement)
+    // TODO: handle the on_error logic on the Spark side instead of in the COPY command
+    // executeQueryInterruptibly returns the COPY result set so skipped rows can be counted
+
+    // report the number of skipped rows
+    val resultSet = jdbcWrapper.executeQueryInterruptibly(conn, copyStatement)
+    if (params.continueOnError) {
+      var rowSkipped: Long = 0L
+      while (resultSet.next()) {
+        rowSkipped +=
+          resultSet.getLong("rows_parsed") -
+            resultSet.getLong("rows_loaded")
+      }
+      log.error(s"ON_ERROR: Continue -> Skipped $rowSkipped rows")
+    }
     Utils.setLastCopyLoad(copyStatement)
 
     //post actions
@@ -301,6 +314,8 @@
 
     val purge = if (params.purge()) "PURGE = TRUE" else ""
 
+    val onError = if (params.continueOnError) "ON_ERROR = CONTINUE" else ""
+
     /** TODO(etduwx): Refactor this to be a collection of different options, and use a mapper
       * function to individually set each file_format and copy option. */
 
@@ -320,6 +335,7 @@
         |  )
         |  $truncateCol
        |  $purge
+        |  $onError
       """.stripMargin.trim
   }
 
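
For reviewers, a minimal usage sketch of the new option as seen from the DataFrame writer API. This is a sketch, not part of the connector code: the sfOptions values, the target_table name, and the ContinueOnErrorExample wrapper are illustrative placeholders, and it assumes the target table already exists with a VARIANT column (as in OnErrorSuite) so that the malformed record is the one skipped.

// Sketch: writing a DataFrame with the new continue_on_error option enabled.
// All connection values and the table name below are placeholders.
import org.apache.spark.sql.{SaveMode, SparkSession}

object ContinueOnErrorExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("ContinueOnErrorExample").getOrCreate()
    import spark.implicits._

    val sfOptions = Map(
      "sfURL"       -> "account.snowflakecomputing.com",
      "sfUser"      -> "user",
      "sfPassword"  -> "password",
      "sfDatabase"  -> "database",
      "sfSchema"    -> "schema",
      "sfWarehouse" -> "warehouse"
    )

    // One valid JSON document and one with an invalid key (embedded newline),
    // mirroring the rows used in OnErrorSuite.
    val df = Seq("{\"abc\":334}", "{\"bad\nkey\":1}").toDF("var")

    df.write
      .format("net.snowflake.spark.snowflake")
      .options(sfOptions)
      .option("dbtable", "target_table")     // assumed to exist with a VARIANT column
      .option("continue_on_error", "on")     // default is "off"
      .mode(SaveMode.Append)
      .save()

    spark.stop()
  }
}

With continue_on_error set to "on", StageWriter issues the COPY command with ON_ERROR = CONTINUE and logs the number of skipped rows as rows_parsed minus rows_loaded; with the default "off", the first malformed record fails the whole write with a SnowflakeSQLException, which is exactly what the two OnErrorSuite tests assert.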