Merge pull request #51 from binglihub/truncate_table
Truncate table: adds a truncate_table option so that SaveMode.Overwrite can truncate the existing table and keep its schema instead of dropping and recreating it.
binglihub authored Jul 9, 2018
2 parents f7653aa + 41a3cc5 commit b52021a
Showing 6 changed files with 319 additions and 155 deletions.
13 changes: 8 additions & 5 deletions src/it/scala/net/snowflake/spark/snowflake/IssueSuite.scala
@@ -22,16 +22,19 @@ class IssueSuite extends IntegrationSuiteBase {

test("csv delimiter character should not break rows"){
val st1 = new StructType(
Array(StructField("str", StringType, nullable = true))
Array(
StructField("str", StringType, nullable = false),
StructField("num", IntegerType, nullable = false)
)
)
val tt: String = "test_table_123"//s"tt_$randomSuffix"
try {
sparkSession.createDataFrame(
sparkSession.sparkContext.parallelize(
Seq(Row("\"\n\""),
Row("\"|\""),
Row("\",\""),
Row("\n")
Seq(Row("\"\n\"",123),
Row("\"|\"",223),
Row("\",\"",345),
Row("\n",423)
)
),
st1
217 changes: 217 additions & 0 deletions src/it/scala/net/snowflake/spark/snowflake/TruncateTableSuite.scala
@@ -0,0 +1,217 @@
package net.snowflake.spark.snowflake

import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SaveMode}

import scala.util.Random

class TruncateTableSuite extends IntegrationSuiteBase {
  val table = s"test_table_$randomSuffix"

  lazy val st1 = new StructType(
    Array(
      StructField("num1", LongType, nullable = false),
      StructField("num2", FloatType, nullable = false)
    )
  )

  lazy val df1 = sqlContext.createDataFrame(
    sc.parallelize(1 to 100).map[Row](
      _ => {
        val rand = new Random(System.nanoTime())
        Row(rand.nextLong(), rand.nextFloat())
      }
    ),
    st1
  )

  lazy val st2 = new StructType(
    Array(
      StructField("num1", IntegerType, nullable = false),
      StructField("num2", IntegerType, nullable = false)
    )
  )

  lazy val df2 = sqlContext.createDataFrame(
    sc.parallelize(1 to 100).map[Row](
      _ => {
        val rand = new Random(System.nanoTime())
        Row(rand.nextInt(), rand.nextInt())
      }
    ),
    st2
  )

  override def beforeAll(): Unit = {
    super.beforeAll()
  }

test("use truncate table with staging table") {

jdbcUpdate(s"drop table if exists $table")

//create one table
df2.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "on")
.mode(SaveMode.Overwrite)
.save()

//replace previous table and overwrite schema
df1.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "on")
.mode(SaveMode.Overwrite)
.save()

//truncate previous table and keep schema
df2.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "on")
.option("usestagingtable", "on")
.mode(SaveMode.Overwrite)
.save()

//check schema
assert(checkSchema1())

}

test("use truncate table without staging table") {

jdbcUpdate(s"drop table if exists $table")

//create table
df2.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "off")
.mode(SaveMode.Overwrite)
.save()

//replace previous table and overwrite schema
df1.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "off")
.mode(SaveMode.Overwrite)
.save()

//truncate table and keep schema
df2.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "on")
.option("usestagingtable", "off")
.mode(SaveMode.Overwrite)
.save()

//checker schema
assert(checkSchema1())

}

test("don't truncate table with staging table") {

jdbcUpdate(s"drop table if exists $table")

//create one table
df2.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "on")
.mode(SaveMode.Overwrite)
.save()

//replace previous table and overwrite schema
df1.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "on")
.mode(SaveMode.Overwrite)
.save()

//truncate previous table and overwrite schema
df2.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "on")
.mode(SaveMode.Overwrite)
.save()

//check schema
assert(checkSchema2())
}
test("don't truncate table without staging table") {

jdbcUpdate(s"drop table if exists $table")

//create one table
df2.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "off")
.mode(SaveMode.Overwrite)
.save()

//replace previous table and overwrite schema
df1.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "off")
.mode(SaveMode.Overwrite)
.save()

//truncate previous table and overwrite schema
df2.write.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", table)
.option("truncate_table", "off")
.option("usestagingtable", "off")
.mode(SaveMode.Overwrite)
.save()

//check schema
assert(checkSchema2())
}

  def checkSchema2(): Boolean = {
    val st = DefaultJDBCWrapper.resolveTable(conn, table)
    val st1 = new StructType(
      Array(
        StructField("NUM1", DecimalType(38, 0), nullable = false),
        StructField("NUM2", DecimalType(38, 0), nullable = false)
      )
    )
    st.equals(st1)
  }

  def checkSchema1(): Boolean = {
    val st = DefaultJDBCWrapper.resolveTable(conn, table)
    val st1 = new StructType(
      Array(
        StructField("NUM1", DecimalType(38, 0), nullable = false),
        StructField("NUM2", DoubleType, nullable = false)
      )
    )
    st.equals(st1)
  }

  override def afterAll(): Unit = {
    jdbcUpdate(s"drop table if exists $table")
    super.afterAll()
  }
}
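What the schema checks above encode (a reading of the tests, hedged where it leans on Snowflake's documented type mapping rather than on this diff): with SaveMode.Overwrite, truncate_table=on truncates and reloads the table, so the schema created by the first write survives; truncate_table=off drops and recreates it, so the incoming DataFrame's schema wins. In Scala-comment form:

// Assumed round-trip type mapping (Snowflake stores all integer types as
// NUMBER(38,0); its FLOAT is double precision), which checkSchema1/2 rely on:
//   LongType / IntegerType -> NUMBER(38,0) -> reads back as DecimalType(38, 0)
//   FloatType              -> FLOAT        -> reads back as DoubleType
// checkSchema1: df1's (Long, Float) schema was kept   -> (DecimalType(38, 0), DoubleType)
// checkSchema2: df2's (Int, Int) schema was rewritten -> (DecimalType(38, 0), DecimalType(38, 0))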
11 changes: 9 additions & 2 deletions src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -71,6 +71,8 @@ object Parameters {
   val PARAM_TRUNCATE_COLUMNS = knownParam("truncate_columns")
   val PARAM_PURGE = knownParam("purge")

+  val PARAM_TRUNCATE_TABLE = knownParam("truncate_table")
+
   val DEFAULT_S3_MAX_FILE_SIZE = (10 * 1000 * 1000).toString
   val MIN_S3_MAX_FILE_SIZE = 1000000

@@ -98,7 +100,7 @@ object Parameters {
     // Notes:
     // * tempdir, dbtable and url have no default and they *must* be provided
     "diststyle" -> "EVEN",
-    "usestagingtable" -> "true",
+    PARAM_USE_STAGING_TABLE -> "true",
     PARAM_PREACTIONS -> "",
     PARAM_POSTACTIONS -> "",
     PARAM_AUTO_PUSHDOWN -> "on"
@@ -441,7 +443,7 @@ object Parameters {
    * Defaults to true.
    */
   def useStagingTable: Boolean =
-    parameters(PARAM_USE_STAGING_TABLE).toBoolean
+    isTrue(parameters(PARAM_USE_STAGING_TABLE))

   /**
    * Extra options to append to the Snowflake COPY command (e.g. "MAXERROR 100").
@@ -496,6 +498,11 @@ object Parameters {
         yield
           new StorageCredentialsSharedAccessSignature(sas)
     }
+
+    /**
+      * Truncate the table when overwriting it.
+      * The existing table schema is kept.
+      */
+    def truncateTable: Boolean = isTrue(parameters(PARAM_TRUNCATE_TABLE))
   }
 }

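For context, a minimal usage sketch of the new flag (hedged: sfOptions is a placeholder for your connection options, and MY_TABLE is illustrative; both flags are parsed through isTrue, and the test suite above uses the on/off spellings):

df.write.format(SNOWFLAKE_SOURCE_NAME)
  .options(sfOptions) // placeholder: Snowflake connection options
  .option("dbtable", "MY_TABLE")
  .option("truncate_table", "on") // truncate and reload, keeping the existing schema
  .option("usestagingtable", "off")
  .mode(SaveMode.Overwrite)
  .save()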
22 changes: 14 additions & 8 deletions src/main/scala/net/snowflake/spark/snowflake/Utils.scala
@@ -304,26 +304,32 @@ object Utils {
""".stripMargin.trim
}

private [snowflake] def executePreActions(jdbcWrapper: JDBCWrapper,
conn: Connection,
params: MergedParameters) : Unit = {
private [snowflake] def executePreActions(
jdbcWrapper: JDBCWrapper,
conn: Connection,
params: MergedParameters,
table: Option[TableName]
) : Unit = {
// Execute preActions
params.preActions.foreach { action =>
if (action != null && !action.trim.isEmpty) {
val actionSql = if (action.contains("%s")) action.format(params.table.get) else action
val actionSql = if (action.contains("%s")) action.format(table.get) else action
log.info("Executing preAction: " + actionSql)
jdbcWrapper.executePreparedInterruptibly(conn.prepareStatement(actionSql))
}
}
}

private [snowflake] def executePostActions(jdbcWrapper: JDBCWrapper,
conn: Connection,
params: MergedParameters) : Unit = {
private [snowflake] def executePostActions(
jdbcWrapper: JDBCWrapper,
conn: Connection,
params: MergedParameters,
table: Option[TableName]
) : Unit = {
// Execute preActions
params.postActions.foreach { action =>
if (action != null && !action.trim.isEmpty) {
val actionSql = if (action.contains("%s")) action.format(params.table.get) else action
val actionSql = if (action.contains("%s")) action.format(table.get) else action
log.info("Executing postAction: " + actionSql)
jdbcWrapper.executePreparedInterruptibly(conn.prepareStatement(actionSql))
}
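Why the table is now passed explicitly: a %s placeholder in a pre/post action previously always resolved to params.table, whereas the caller now picks the target, presumably so the writer can aim actions at a staging table. A hedged call-site sketch (stagingTable is a hypothetical TableName, and the action string is illustrative):

// Suppose params.preActions contains "delete from %s where num1 < 0".
Utils.executePreActions(jdbcWrapper, conn, params, Some(stagingTable))
// This runs: delete from <stagingTable> where num1 < 0
// rather than unconditionally targeting params.table.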
@@ -40,7 +40,7 @@ private[io] trait DataUnloader {
     log.debug(Utils.sanitizeQueryText(prologueSql))
     jdbcWrapper.executeInterruptibly(conn, prologueSql)

-    Utils.executePreActions(jdbcWrapper, conn, params)
+    Utils.executePreActions(jdbcWrapper, conn, params, params.table)

     // Run the unload query
     log.debug(Utils.sanitizeQueryText(sql))
@@ -61,7 +61,7 @@
       val second = res.next()
       assert(!second)

-      Utils.executePostActions(jdbcWrapper, conn, params)
+      Utils.executePostActions(jdbcWrapper, conn, params, params.table)
       numRows
     } finally {
       SnowflakeTelemetry.send(jdbcWrapper.getTelemetry(conn))
