Skip to content

Commit

Permalink
Fix array_intersect for array<array<T>>
Browse files Browse the repository at this point in the history
  • Loading branch information
kewang1024 committed Oct 29, 2024
1 parent 0d04e97 commit d3051d6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ public static Block intersect(
@SqlType("array<T>")
public static String arrayIntersectArray()
{
return "RETURN reduce(input, null, (s, x) -> IF((s IS NULL), x, array_intersect(s, x)), (s) -> s)";
return "RETURN reduce(input, IF((cardinality(input) = 0), ARRAY[], input[1]), (s, x) -> array_intersect(s, x), (s) -> s)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ public void testDuplicates()
public void testSqlFunctions()
{
assertFunction("array_intersect(ARRAY[ARRAY[1, 3, 5], ARRAY[2, 3, 5], ARRAY[3, 3, 3, 6]])", new ArrayType(INTEGER), ImmutableList.of(3));
assertFunction("array_intersect(ARRAY[null, ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), null);
assertFunction("array_intersect(ARRAY[ARRAY[], null, ARRAY[1, 2, 3]])", new ArrayType(INTEGER), null);
assertFunction("array_intersect(ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of());
assertFunction("array_intersect(null)", new ArrayType(UNKNOWN), null);
assertFunction("array_intersect(ARRAY[ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), ImmutableList.of());
assertFunction("array_intersect(ARRAY[ARRAY[1, 2, 3], null])", new ArrayType(INTEGER), null);
assertFunction("array_intersect(ARRAY[ARRAY[DOUBLE'1.1', DOUBLE'2.2', DOUBLE'3.3'], ARRAY[DOUBLE'1.1', DOUBLE'3.4'], ARRAY[DOUBLE'1.0', DOUBLE'1.1', DOUBLE'1.2']])", new ArrayType(DOUBLE), ImmutableList.of(1.1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ public abstract class AbstractTestNanQueries
public static final String SIMPLE_DOUBLE_ARRAY_COLUMN = "simple_double_array";
public static final String SIMPLE_REAL_ARRAY_COLUMN = "simple_real_array";

public static final String ARRAY_TABLE_NAME_NO_NULL = "array_nans_table_no_null";
public static final String SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL = "simple_double_array_no_null";
public static final String SIMPLE_REAL_ARRAY_COLUMN_NO_NULL = "simple_real_array_no_null";

public static final String MAP_TABLE_NAME = "map_nans_table";
public static final String DOUBLE_MAP_COLUMN = "double_map";
public static final String REAL_MAP_COLUMN = "real_map";
Expand Down Expand Up @@ -96,7 +100,19 @@ public void setup()
"(ARRAY[DOUBLE '0', DOUBLE '1', nan(), DOUBLE '-1', nan(), DOUBLE '1', DOUBLE '1', DOUBLE'0'], ARRAY [REAL '0', REAL '1', CAST(nan() AS REAL), REAL '-1', CAST(nan() AS REAL), REAL '1', REAL '1', REAL '0'])) " +
"AS t (" + SIMPLE_DOUBLE_ARRAY_COLUMN + ", " + SIMPLE_REAL_ARRAY_COLUMN + ")";

@Language("SQL") String createArrayTableNoNullQuery = "" +
"CREATE TABLE " + ARRAY_TABLE_NAME_NO_NULL + " AS " +
"SELECT * FROM (VALUES " +
"(ARRAY[nan(), DOUBLE '0', DOUBLE '1', DOUBLE '-1'], ARRAY[cast(nan() AS REAL), REAL '0', REAL '1', REAL '-1']), " +
"(ARRAY[ DOUBLE '0', nan(), DOUBLE '1', DOUBLE '-1'], ARRAY[REAL '0', CAST(nan() AS REAL), REAL '1', REAL '-1']), " +
"(ARRAY[ DOUBLE '0', DOUBLE '1', DOUBLE '-1', nan()], ARRAY[REAL '0', REAL '1', REAL '-1', CAST(nan() AS REAL)]), " +
"(ARRAY[null, nan(), DOUBLE '200'], ARRAY[null, CAST(nan() AS REAL), REAL '200']), " +
"(ARRAY[nan(), nan()], ARRAY[CAST(nan() AS REAL), CAST(nan() AS REAL)]), " +
"(ARRAY[DOUBLE '0', DOUBLE '1', nan(), DOUBLE '-1', nan(), DOUBLE '1', DOUBLE '1', DOUBLE'0'], ARRAY [REAL '0', REAL '1', CAST(nan() AS REAL), REAL '-1', CAST(nan() AS REAL), REAL '1', REAL '1', REAL '0'])) " +
"AS t (" + SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL + ", " + SIMPLE_REAL_ARRAY_COLUMN_NO_NULL + ")";

assertUpdate(createArrayTableQuery, 7);
assertUpdate(createArrayTableNoNullQuery, 6);

@Language("SQL") String createMapTableQuery = "" +
"CREATE TABLE " + MAP_TABLE_NAME + " AS " +
Expand Down Expand Up @@ -728,6 +744,9 @@ public void testDoubleArrayIntersect2()
// Test the array of arrays function signature
assertQueryWithSameQueryRunner(
format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_DOUBLE_ARRAY_COLUMN, ARRAY_TABLE_NAME),
"SELECT NULL");
assertQueryWithSameQueryRunner(
format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_DOUBLE_ARRAY_COLUMN_NO_NULL, ARRAY_TABLE_NAME_NO_NULL),
"SELECT * FROM (VALUES (ARRAY[nan()]))");
}

Expand All @@ -737,6 +756,9 @@ public void testRealArrayIntersect2()
// Test the array of arrays function signature
assertQueryWithSameQueryRunner(
format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_REAL_ARRAY_COLUMN, ARRAY_TABLE_NAME),
"SELECT NULL");
assertQueryWithSameQueryRunner(
format("SELECT array_sort(array_intersect(array_agg(%s))) FROM %s", SIMPLE_REAL_ARRAY_COLUMN_NO_NULL, ARRAY_TABLE_NAME_NO_NULL),
"SELECT * FROM (VALUES (ARRAY[CAST(nan() AS REAL)]))");
}

Expand Down

0 comments on commit d3051d6

Please sign in to comment.