Skip to content

Commit

Permalink
Merge branch 'master' into statistics-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jolanrensen committed Nov 8, 2024
2 parents 63ee929 + 79bd076 commit 945fb49
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 122 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import org.jetbrains.kotlinx.dataframe.api.ParserOptions
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataColumn
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
import org.jetbrains.kotlinx.dataframe.api.getColumnsWithPaths
import org.jetbrains.kotlinx.dataframe.api.convert
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
import org.jetbrains.kotlinx.dataframe.api.isFrameColumn
import org.jetbrains.kotlinx.dataframe.api.isSubtypeOf
import org.jetbrains.kotlinx.dataframe.api.toColumn
import org.jetbrains.kotlinx.dataframe.api.map
import org.jetbrains.kotlinx.dataframe.api.parse
import org.jetbrains.kotlinx.dataframe.api.to
import org.jetbrains.kotlinx.dataframe.api.tryParse
import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion
import org.jetbrains.kotlinx.dataframe.columns.size
Expand Down Expand Up @@ -531,17 +532,16 @@ internal fun <T> DataColumn<String?>.parse(parser: StringParser<T>, options: Par
)
}

internal fun <T> DataFrame<T>.parseImpl(options: ParserOptions?, columns: ColumnsSelector<T, Any?>): DataFrame<T> {
val convertedCols = getColumnsWithPaths(columns).map { col ->
internal fun <T> DataFrame<T>.parseImpl(options: ParserOptions?, columns: ColumnsSelector<T, Any?>): DataFrame<T> =
convert(columns).to { col ->
when {
// when a frame column is requested to be parsed,
// parse each value/frame column at any depth inside each DataFrame in the frame column
col.isFrameColumn() ->
col.values.map {
it.parseImpl(options) {
colsAtAnyDepth { !it.isColumnGroup() }
}
}.toColumn(col.name)
col.isFrameColumn() -> col.map {
it.parseImpl(options) {
colsAtAnyDepth { !it.isColumnGroup() }
}
}

// when a column group is requested to be parsed,
// parse each column in the group
Expand All @@ -552,11 +552,8 @@ internal fun <T> DataFrame<T>.parseImpl(options: ParserOptions?, columns: Column

// Base case, parse the column if it's a `String?` column
col.isSubtypeOf<String?>() ->
col.cast<String?>().tryParse(options)
col.cast<String?>().tryParseImpl(options)

else -> col
}.let { ColumnToInsert(col.path, it) }
}
}

return emptyDataFrame<T>().insertImpl(convertedCols)
}
68 changes: 37 additions & 31 deletions dataframe-jdbc/api/dataframe-jdbc.api

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dataframe-jdbc/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies {
testImplementation(libs.h2db)
testImplementation(libs.mssql)
testImplementation(libs.junit)
testImplementation(libs.sl4j)
testImplementation(libs.sl4jsimple)
testImplementation(libs.jts)
testImplementation(libs.kotestAssertions) {
exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import kotlin.reflect.KType
*
* NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types.
*/
public class H2(public val dialect: DbType = MySql) : DbType("h2") {
public open class H2(public val dialect: DbType = MySql) : DbType("h2") {
init {
require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" }
}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class JdbcTest {
val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
dataSchema.columns.size shouldBe 2
dataSchema.columns["characterCol"]!!.type shouldBe typeOf<String?>()

connection.createStatement().execute("DROP TABLE EmptyTestTable")
}

@Test
Expand Down Expand Up @@ -299,6 +301,8 @@ class JdbcTest {
schema.columns["realCol"]!!.type shouldBe typeOf<Float?>()
schema.columns["doublePrecisionCol"]!!.type shouldBe typeOf<Double?>()
schema.columns["decFloatCol"]!!.type shouldBe typeOf<BigDecimal?>()

connection.createStatement().execute("DROP TABLE $tableName")
}

@Test
Expand Down Expand Up @@ -441,7 +445,7 @@ class JdbcTest {

rs.beforeFirst()

val dataSchema1 = DataFrame.getSchemaForResultSet(rs, connection)
val dataSchema1 = DataFrame.getSchemaForResultSet(rs, H2(MySql))
dataSchema1.columns.size shouldBe 3
dataSchema1.columns["name"]!!.type shouldBe typeOf<String?>()
}
Expand Down Expand Up @@ -493,7 +497,7 @@ class JdbcTest {

rs.beforeFirst()

val dataSchema1 = rs.getDataFrameSchema(connection)
val dataSchema1 = rs.getDataFrameSchema(H2(MySql))
dataSchema1.columns.size shouldBe 3
dataSchema1.columns["name"]!!.type shouldBe typeOf<String?>()
}
Expand Down Expand Up @@ -613,6 +617,7 @@ class JdbcTest {
"""

DataFrame.readSqlQuery(connection, selectFromWeirdTableSQL).rowsCount() shouldBe 0
connection.createStatement().execute("DROP TABLE \"ALTER\"")
}

@Test
Expand Down Expand Up @@ -967,4 +972,127 @@ class JdbcTest {
}
exception.message shouldBe "H2 database could not be specified with H2 dialect!"
}

/**
 * Helper singleton created for API testing purposes.
 *
 * It subclasses [H2] with the [MySql] dialect; the "custom database" tests in this class
 * pass it explicitly as the `dbType` argument of the read/schema functions.
 */
object CustomDB : H2(MySql)

/**
 * Reads the `Customer` table and its schema with [CustomDB] passed explicitly as `dbType`,
 * first through the shared `connection` and then through a [DbConnectionConfig] built from `URL`.
 */
@Test
fun `read from table from custom database`() {
    val customerTable = "Customer"

    // Content expectations shared by both ways of reading the table.
    fun verifyCustomers(customers: DataFrame<Customer>) {
        customers.rowsCount() shouldBe 4
        customers
            .filter { row -> row[Customer::age] != null && row[Customer::age]!! > 30 }
            .rowsCount() shouldBe 2
        customers[0][1] shouldBe "John"
    }

    // Pass 1: through the existing connection.
    verifyCustomers(
        DataFrame.readSqlTable(connection, customerTable, dbType = CustomDB).cast<Customer>(),
    )

    DataFrame.getSchemaForSqlTable(connection, customerTable, dbType = CustomDB).let { schema ->
        schema.columns.size shouldBe 3
        schema.columns["name"]!!.type shouldBe typeOf<String?>()
    }

    // Pass 2: through a connection config.
    val config = DbConnectionConfig(url = URL)
    verifyCustomers(
        DataFrame.readSqlTable(config, customerTable, dbType = CustomDB).cast<Customer>(),
    )

    DataFrame.getSchemaForSqlTable(config, customerTable, dbType = CustomDB).let { schema ->
        schema.columns.size shouldBe 3
        schema.columns["name"]!!.type shouldBe typeOf<String?>()
    }
}

/**
 * Runs an aggregating join query and fetches its schema with [CustomDB] passed explicitly as
 * `dbType`, first through the shared `connection` and then through a [DbConnectionConfig].
 */
@Test
fun `read from query from custom database`() {
    @Language("SQL")
    val salesPerCustomerQuery =
        """
        SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount
        FROM Sale s
        INNER JOIN Customer c ON s.customerId = c.id
        WHERE c.age > 35
        GROUP BY s.customerId, c.name
        """.trimIndent()

    // Content expectations shared by both ways of running the query.
    fun verifySales(sales: DataFrame<CustomerSales>) {
        sales.rowsCount() shouldBe 2
        sales.filter { row -> row[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1
        sales[0][0] shouldBe "John"
    }

    // Pass 1: through the existing connection.
    verifySales(
        DataFrame.readSqlQuery(connection, salesPerCustomerQuery, dbType = CustomDB).cast<CustomerSales>(),
    )

    DataFrame.getSchemaForSqlQuery(connection, salesPerCustomerQuery, dbType = CustomDB).let { schema ->
        schema.columns.size shouldBe 2
        schema.columns["name"]!!.type shouldBe typeOf<String?>()
    }

    // Pass 2: through a connection config.
    val config = DbConnectionConfig(url = URL)
    verifySales(
        DataFrame.readSqlQuery(config, salesPerCustomerQuery, dbType = CustomDB).cast<CustomerSales>(),
    )

    DataFrame.getSchemaForSqlQuery(config, salesPerCustomerQuery, dbType = CustomDB).let { schema ->
        schema.columns.size shouldBe 2
        schema.columns["name"]!!.type shouldBe typeOf<String?>()
    }
}

/**
 * Reads every table and every table schema with [CustomDB] passed explicitly as `dbType`,
 * first through the shared `connection` and then through a [DbConnectionConfig] built from `URL`.
 *
 * The positional lookups (`[0]`, `[1]`) rely on the returned maps listing `Customer` before `Sale`.
 */
@Test
fun `read from all tables from custom database`() {
    // Pass 1: data frames read through the existing connection.
    val dataFrameMap = DataFrame.readAllSqlTables(connection, dbType = CustomDB)
    dataFrameMap.containsKey("Customer") shouldBe true
    dataFrameMap.containsKey("Sale") shouldBe true

    val dataframes = dataFrameMap.values.toList()

    val customerDf = dataframes[0].cast<Customer>()

    customerDf.rowsCount() shouldBe 4
    customerDf.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2
    customerDf[0][1] shouldBe "John"

    val saleDf = dataframes[1].cast<Sale>()

    saleDf.rowsCount() shouldBe 4
    saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3
    // compareTo (not equals) so that differing BigDecimal scales do not matter
    (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

    // Pass 1: schemas read through the existing connection.
    val dataFrameSchemaMap = DataFrame.getSchemaForAllSqlTables(connection, dbType = CustomDB)
    dataFrameSchemaMap.containsKey("Customer") shouldBe true
    dataFrameSchemaMap.containsKey("Sale") shouldBe true

    val dataSchemas = dataFrameSchemaMap.values.toList()

    val customerDataSchema = dataSchemas[0]
    customerDataSchema.columns.size shouldBe 3
    customerDataSchema.columns["name"]!!.type shouldBe typeOf<String?>()

    val saleDataSchema = dataSchemas[1]
    saleDataSchema.columns.size shouldBe 3
    // TODO: fix nullability
    saleDataSchema.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()

    // Pass 2: data frames read through a connection config.
    val dbConfig = DbConnectionConfig(url = URL)
    val dataframes2 = DataFrame.readAllSqlTables(dbConfig, dbType = CustomDB).values.toList()

    val customerDf2 = dataframes2[0].cast<Customer>()

    customerDf2.rowsCount() shouldBe 4
    customerDf2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2
    customerDf2[0][1] shouldBe "John"

    val saleDf2 = dataframes2[1].cast<Sale>()

    saleDf2.rowsCount() shouldBe 4
    saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3
    // Fixed: check the frame read via dbConfig (saleDf2); previously this re-checked saleDf from pass 1.
    (saleDf2[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

    // Pass 2: schemas read through a connection config.
    val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig, dbType = CustomDB).values.toList()

    val customerDataSchema1 = dataSchemas1[0]
    customerDataSchema1.columns.size shouldBe 3
    customerDataSchema1.columns["name"]!!.type shouldBe typeOf<String?>()

    val saleDataSchema1 = dataSchemas1[1]
    saleDataSchema1.columns.size shouldBe 3
    saleDataSchema1.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
}
}
3 changes: 2 additions & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ kotlinpoet = { group = "com.squareup", name = "kotlinpoet", version.ref = "kotli
swagger = { group = "io.swagger.parser.v3", name = "swagger-parser", version.ref = "openapi" }

kotlinLogging = { group = "io.github.oshai", name = "kotlin-logging", version.ref = "kotlinLogging" }
sl4j = { group = "org.slf4j", name = "slf4j-simple", version.ref = "sl4j" }
sl4j = { group = "org.slf4j", name = "slf4j-api", version.ref = "sl4j" }
sl4jsimple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "sl4j" }
android-gradle-api = { group = "com.android.tools.build", name = "gradle-api", version.ref = "android-gradle-api" }
android-gradle = { group = "com.android.tools.build", name = "gradle", version.ref = "android-gradle-api" }
kotlin-gradle-plugin = { group = "org.jetbrains.kotlin", name = "kotlin-gradle-plugin" }
Expand Down

0 comments on commit 945fb49

Please sign in to comment.