Skip to content

Commit c4f7549

Browse files
committed
fixed duckdb tests
1 parent d1b655b commit c4f7549

File tree

3 files changed

+85
-52
lines changed

3 files changed

+85
-52
lines changed

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DuckDb.kt

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import java.util.Properties
5858
import kotlin.collections.toList
5959
import kotlin.reflect.KTypeProjection
6060
import kotlin.reflect.full.createType
61+
import kotlin.reflect.full.withNullability
6162
import kotlin.time.Instant
6263
import kotlin.time.toKotlinInstant
6364
import kotlin.uuid.Uuid
@@ -159,15 +160,14 @@ public object DuckDb : DbType("duckdb") {
159160
val (key, value) = parseMapTypes(sqlTypeName)
160161

161162
val parsedKeyType = parseDuckDbType(key, false)
162-
val parsedValueType =
163-
parseDuckDbType(value, true).cast<Any, Any, Any>()
163+
val parsedValueType = parseDuckDbType(value, true).castToAny()
164164

165165
val targetMapType = Map::class.createType(
166166
listOf(
167167
KTypeProjection.invariant(parsedKeyType.targetSchema.type),
168168
KTypeProjection.invariant(parsedValueType.targetSchema.type),
169169
),
170-
)
170+
).withNullability(isNullable)
171171

172172
typeInformationWithPreprocessingForValueColumnOf<Map<String, Any?>, Map<String, Any?>>(
173173
targetColumnType = targetMapType,
@@ -186,8 +186,12 @@ public object DuckDb : DbType("duckdb") {
186186
parseDuckDbType(listType, true).castToAny()
187187

188188
val targetListType = List::class.createType(
189-
listOf(KTypeProjection.invariant(parsedListType.targetSchema.type)),
190-
)
189+
listOf(
190+
KTypeProjection.invariant(
191+
parsedListType.targetSchema.type,
192+
),
193+
),
194+
).withNullability(isNullable)
191195

192196
// todo maybe List<DataRow> should become FrameColumn
193197
typeInformationWithPreprocessingFor<SqlArray, List<Any?>>(

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/commonTestScenarios.kt

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
package org.jetbrains.kotlinx.dataframe.io
22

3+
import io.kotest.assertions.Actual
4+
import io.kotest.assertions.AssertionFailedError
5+
import io.kotest.assertions.Exceptions
6+
import io.kotest.assertions.Expected
7+
import io.kotest.assertions.failure
8+
import io.kotest.assertions.print.printed
39
import io.kotest.assertions.withClue
410
import io.kotest.matchers.shouldBe
511
import org.intellij.lang.annotations.Language
@@ -144,17 +150,11 @@ internal fun inferNullability(connection: Connection) {
144150
*/
145151
@Suppress("INVISIBLE_REFERENCE")
146152
fun AnyFrame.assertInferredTypesMatchSchema() {
147-
withClue({
148-
"""
149-
|Inferred schema must be <: Provided schema
150-
|
151-
|Inferred Schema:
152-
|${inferType().schema().toString().lines().joinToString("\n|")}
153-
|
154-
|Provided Schema:
155-
|${schema().toString().lines().joinToString("\n|")}
156-
""".trimMargin()
157-
}) {
158-
schema().compare(inferType().schema()).isSuperOrMatches() shouldBe true
153+
if (!schema().compare(inferType().schema()).isSuperOrMatches()) {
154+
throw failure(
155+
expected = Expected(inferType().schema().toString().lines().sorted().joinToString("\n").printed()),
156+
actual = Actual(schema().toString().lines().sorted().joinToString("\n").printed()),
157+
prependMessage = "Inferred schema must be <: Provided schema",
158+
)
159159
}
160160
}

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/local/duckDbTest.kt

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
package org.jetbrains.kotlinx.dataframe.io.local
44

5+
import io.kotest.assertions.withClue
56
import io.kotest.matchers.shouldBe
7+
import kotlinx.datetime.LocalDate
8+
import kotlinx.datetime.LocalTime
69
import org.duckdb.DuckDBConnection
710
import org.duckdb.DuckDBResultSet
811
import org.duckdb.JsonNode
@@ -36,11 +39,12 @@ import java.nio.file.Files
3639
import java.sql.Blob
3740
import java.sql.DriverManager
3841
import java.sql.Timestamp
39-
import java.time.LocalDate
40-
import java.time.LocalTime
41-
import java.time.OffsetDateTime
4242
import java.util.UUID
4343
import kotlin.io.path.createTempDirectory
44+
import kotlin.time.Instant
45+
import kotlin.time.toKotlinInstant
46+
import kotlin.uuid.Uuid
47+
import java.time.OffsetDateTime as JavaOffsetDateTime
4448

4549
private const val URL = "jdbc:duckdb:"
4650

@@ -58,10 +62,10 @@ class DuckDbTest {
5862
) {
5963
companion object {
6064
val expected = listOf(
61-
Person(1, "John Doe", 30, 50000.0, LocalDate.of(2020, 1, 15)),
62-
Person(2, "Jane Smith", 28, 55000.0, LocalDate.of(2021, 3, 20)),
63-
Person(3, "Bob Johnson", 35, 65000.0, LocalDate.of(2019, 11, 10)),
64-
Person(4, "Alice Brown", 32, 60000.0, LocalDate.of(2020, 7, 1)),
65+
Person(1, "John Doe", 30, 50000.0, LocalDate(2020, 1, 15)),
66+
Person(2, "Jane Smith", 28, 55000.0, LocalDate(2021, 3, 20)),
67+
Person(3, "Bob Johnson", 35, 65000.0, LocalDate(2019, 11, 10)),
68+
Person(4, "Alice Brown", 32, 60000.0, LocalDate(2020, 7, 1)),
6569
).toDataFrame()
6670
}
6771
}
@@ -91,7 +95,7 @@ class DuckDbTest {
9195
@ColumnName("date_col")
9296
val dateCol: LocalDate,
9397
@ColumnName("datetime_col")
94-
val datetimeCol: Timestamp,
98+
val datetimeCol: Instant,
9599
@ColumnName("decimal_col")
96100
val decimalCol: BigDecimal,
97101
@ColumnName("double_col")
@@ -151,11 +155,11 @@ class DuckDbTest {
151155
@ColumnName("time_col")
152156
val timeCol: LocalTime,
153157
@ColumnName("timestamp_col")
154-
val timestampCol: Timestamp,
158+
val timestampCol: Instant,
155159
@ColumnName("timestamptz_col")
156-
val timestamptzCol: OffsetDateTime,
160+
val timestamptzCol: JavaOffsetDateTime,
157161
@ColumnName("timestampwtz_col")
158-
val timestampwtzCol: OffsetDateTime,
162+
val timestampwtzCol: JavaOffsetDateTime,
159163
@ColumnName("tinyint_col")
160164
val tinyintCol: Byte,
161165
@ColumnName("ubigint_col")
@@ -179,7 +183,7 @@ class DuckDbTest {
179183
@ColumnName("utinyint_col")
180184
val utinyintCol: Short,
181185
@ColumnName("uuid_col")
182-
val uuidCol: UUID,
186+
val uuidCol: Uuid,
183187
@ColumnName("varbinary_col")
184188
val varbinaryCol: Blob,
185189
@ColumnName("varchar_col")
@@ -199,7 +203,7 @@ class DuckDbTest {
199203
byteaCol = DuckDBResultSet.DuckDBBlobResult(ByteBuffer.wrap("DEADBEEF".toByteArray())),
200204
charCol = "test",
201205
dateCol = LocalDate.parse("2025-06-19"),
202-
datetimeCol = Timestamp.valueOf("2025-06-19 12:34:56"),
206+
datetimeCol = Timestamp.valueOf("2025-06-19 12:34:56").toInstant().toKotlinInstant(),
203207
decimalCol = BigDecimal("123.45"),
204208
doubleCol = 3.14159,
205209
enumCol = "female",
@@ -229,9 +233,9 @@ class DuckDbTest {
229233
stringCol = "test string",
230234
textCol = "test text",
231235
timeCol = LocalTime.parse("12:34:56"),
232-
timestampCol = Timestamp.valueOf("2025-06-19 12:34:56"),
233-
timestamptzCol = OffsetDateTime.parse("2025-06-19T12:34:56+02:00"),
234-
timestampwtzCol = OffsetDateTime.parse("2025-06-19T12:34:56+02:00"),
236+
timestampCol = Timestamp.valueOf("2025-06-19 12:34:56").toInstant().toKotlinInstant(),
237+
timestamptzCol = JavaOffsetDateTime.parse("2025-06-19T12:34:56+02:00"),
238+
timestampwtzCol = JavaOffsetDateTime.parse("2025-06-19T12:34:56+02:00"),
235239
tinyintCol = 127,
236240
ubigintCol = BigInteger("18446744073709551615"),
237241
uhugeintCol = BigInteger("340282366920938463463374607431768211455"),
@@ -243,7 +247,7 @@ class DuckDbTest {
243247
uintCol = 4294967295L,
244248
usmallintCol = 65535,
245249
utinyintCol = 255,
246-
uuidCol = UUID.fromString("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
250+
uuidCol = Uuid.parse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
247251
varbinaryCol = DuckDBResultSet.DuckDBBlobResult(ByteBuffer.wrap("DEADBEEF".toByteArray())),
248252
varcharCol = "test string",
249253
),
@@ -254,21 +258,21 @@ class DuckDbTest {
254258
@DataSchema
255259
data class NestedTypes(
256260
@ColumnName("ijstruct_col")
257-
val ijstructCol: java.sql.Struct,
261+
val ijstructCol: java.sql.Struct, // TODO
258262
@ColumnName("intarray_col")
259-
val intarrayCol: java.sql.Array,
263+
val intarrayCol: List<Int?>,
260264
@ColumnName("intlist_col")
261-
val intlistCol: java.sql.Array,
265+
val intlistCol: List<Int?>,
262266
@ColumnName("intstringmap_col")
263267
val intstringmapCol: Map<Int, String?>,
264268
@ColumnName("intstrinstinggmap_col")
265269
val intstrinstinggmapCol: Map<Int, Map<String, String?>?>,
266270
@ColumnName("stringarray_col")
267-
val stringarrayCol: java.sql.Array,
271+
val stringarrayCol: List<String?>,
268272
@ColumnName("stringlist_col")
269-
val stringlistCol: java.sql.Array,
273+
val stringlistCol: List<String?>,
270274
@ColumnName("stringlistlist_col")
271-
val stringlistlistCol: java.sql.Array,
275+
val stringlistlistCol: List<List<String?>?>,
272276
@ColumnName("union_col")
273277
val unionCol: Any,
274278
)
@@ -310,7 +314,19 @@ class DuckDbTest {
310314
subset = DataFrame.readSqlQuery(connection, """SELECT test_table.name, test_table.age FROM test_table""")
311315
}
312316

313-
schema.compare(Person.expected.schema()).isSuperOrMatches() shouldBe true
317+
withClue({
318+
"""
319+
|Read schema must be <: expected schema
320+
|
321+
|Read Schema:
322+
|${schema.toString().lines().joinToString("\n|")}
323+
|
324+
|expected Schema:
325+
|${Person.expected.schema().toString().lines().joinToString("\n|")}
326+
""".trimMargin()
327+
}) {
328+
schema.compare(Person.expected.schema()).isSuperOrMatches() shouldBe true
329+
}
314330

315331
df.cast<Person>(verify = true) shouldBe Person.expected
316332
df.assertInferredTypesMatchSchema()
@@ -545,10 +561,24 @@ class DuckDbTest {
545561
df = DataFrame.readSqlTable(connection, "table1").reorderColumnsByName()
546562
}
547563

548-
schema.compare(GeneralPurposeTypes.expected.schema()).isSuperOrMatches() shouldBe true
564+
// schema.toString().lines().sorted().joinToString("\n") shouldBe
565+
// GeneralPurposeTypes.expected.schema().toString().lines().sorted().joinToString("\n")
566+
withClue({
567+
"""
568+
|Read schema must be <: expected schema
569+
|
570+
|Read Schema:
571+
|${schema.toString().lines().joinToString("\n|")}
572+
|
573+
|expected Schema:
574+
|${GeneralPurposeTypes.expected.schema().toString().lines().joinToString("\n|")}
575+
""".trimMargin()
576+
}) {
577+
schema.compare(GeneralPurposeTypes.expected.schema()).isSuperOrMatches() shouldBe true
578+
}
549579

550580
// on some systems OffsetDateTime's get converted to UTC sometimes, let's compare them as Instant instead
551-
fun AnyFrame.fixOffsetDateTime() = convert { colsOf<OffsetDateTime>() }.with { it.toInstant() }
581+
fun AnyFrame.fixOffsetDateTime() = convert { colsOf<JavaOffsetDateTime>() }.with { it.toInstant() }
552582

553583
df.cast<GeneralPurposeTypes>(verify = true).fixOffsetDateTime() shouldBe
554584
GeneralPurposeTypes.expected.fixOffsetDateTime()
@@ -606,19 +636,18 @@ class DuckDbTest {
606636
df as DataFrame<NestedTypes>
607637

608638
df.single().let {
609-
it[{ "intarray_col"<java.sql.Array>() }].array shouldBe arrayOf(1, 2, null)
610-
it[{ "stringarray_col"<java.sql.Array>() }].array shouldBe arrayOf("a", "ab", "abc")
611-
it[{ "intlist_col"<java.sql.Array>() }].array shouldBe arrayOf(1, 2, 3)
612-
it[{ "stringlist_col"<java.sql.Array>() }].array shouldBe arrayOf("a", "ab", "abc")
613-
(it[{ "stringlistlist_col"<java.sql.Array>() }].array as Array<*>)
614-
.map { (it as java.sql.Array?)?.array } shouldBe listOf(arrayOf("a", "ab"), arrayOf("abc"), null)
615-
it[{ "intstringmap_col"<Map<Int, String?>>() }] shouldBe mapOf(1 to "value1", 200 to "value2")
616-
it[{ "intstrinstinggmap_col"<Map<Int, Map<String, String?>>>() }] shouldBe mapOf(
639+
it["intarray_col"] shouldBe listOf(1, 2, null)
640+
it["stringarray_col"] shouldBe listOf("a", "ab", "abc")
641+
it["intlist_col"] shouldBe listOf(1, 2, 3)
642+
it["stringlist_col"] shouldBe listOf("a", "ab", "abc")
643+
it["stringlistlist_col"] shouldBe listOf(listOf("a", "ab"), listOf("abc"), null)
644+
it["intstringmap_col"] shouldBe mapOf(1 to "value1", 200 to "value2")
645+
it["intstrinstinggmap_col"] shouldBe mapOf(
617646
1 to mapOf("value1" to "a", "value2" to "b"),
618647
200 to mapOf("value1" to "c", "value2" to "d"),
619648
)
620649
it[{ "ijstruct_col"<java.sql.Struct>() }].attributes shouldBe arrayOf<Any>(42, "answer")
621-
it[{ "union_col"<Any>() }] shouldBe 2
650+
it["union_col"] shouldBe 2
622651
}
623652
}
624653

0 commit comments

Comments
 (0)