Skip to content

Commit

Permalink
Allow maps with floating point keys in map_union_sum (facebookincubat…
Browse files Browse the repository at this point in the history
…or#7045)

Summary: Pull Request resolved: facebookincubator#7045

Reviewed By: xiaoxmeng

Differential Revision: D50271047

Pulled By: mbasmanova

fbshipit-source-id: 9bf6aef7d6be274ba213bb719c783832a01457fb
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Oct 13, 2023
1 parent 9992cae commit e2ee0ca
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
28 changes: 25 additions & 3 deletions velox/functions/prestosql/aggregates/MapUnionSumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,27 @@ std::unique_ptr<exec::Aggregate> createMapUnionSumAggregate(
}

exec::AggregateRegistrationResult registerMapUnionSum(const std::string& name) {
const std::vector<std::string> keyTypes = {
"tinyint",
"smallint",
"integer",
"bigint",
"real",
"double",
"varchar",
};
const std::vector<std::string> valueTypes = {
"tinyint",
"smallint",
"integer",
"bigint",
"double",
"real",
};

std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
for (auto keyType : {"tinyint", "smallint", "integer", "bigint", "varchar"}) {
for (auto valueType :
{"tinyint", "smallint", "integer", "bigint", "double", "real"}) {
for (auto keyType : keyTypes) {
for (auto valueType : valueTypes) {
auto mapType = fmt::format("map({},{})", keyType, valueType);
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
.returnType(mapType)
Expand Down Expand Up @@ -386,6 +403,11 @@ exec::AggregateRegistrationResult registerMapUnionSum(const std::string& name) {
case TypeKind::BIGINT:
return createMapUnionSumAggregate<int64_t>(
valueTypeKind, resultType);
case TypeKind::REAL:
return createMapUnionSumAggregate<float>(valueTypeKind, resultType);
case TypeKind::DOUBLE:
return createMapUnionSumAggregate<double>(
valueTypeKind, resultType);
case TypeKind::VARCHAR:
return createMapUnionSumAggregate<StringView>(
valueTypeKind, resultType);
Expand Down
25 changes: 25 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MapUnionSumTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,5 +378,30 @@ TEST_F(MapUnionSumTest, groupByVarcharKey) {
testAggregations({data}, {"c0"}, {"map_union_sum(c1)"}, {expected});
}

TEST_F(MapUnionSumTest, floatingPointKeys) {
auto data = makeRowVector({
makeFlatVector<int32_t>({1, 2, 1, 2, 1, 1, 2, 2}),
makeMapVectorFromJson<float, int64_t>({
"{1.1: 10, 1.2: 20, 1.3: 30}",
"{2.1: 10, 1.2: 20, 2.3: 30}",
"{3.1: 10, 1.2: 20, 2.3: 30}",
"{}",
"null",
"{4.1: 10, 4.2: 20, 2.3: 30}",
"{5.1: 10, 4.2: 20, 2.3: 30}",
"{6.1: 10, 6.2: 20, 6.3: 30}",
}),
});

auto expected = makeRowVector({
makeMapVectorFromJson<float, int64_t>({
"{1.1: 10, 1.2: 60, 1.3: 30, 2.1: 10, 2.3: 120, 3.1: 10, 4.1: 10, "
"4.2: 40, 5.1: 10, 6.1: 10, 6.2: 20, 6.3: 30}",
}),
});

testAggregations({data}, {}, {"map_union_sum(c1)"}, {expected});
}

} // namespace
} // namespace facebook::velox::aggregate::test
4 changes: 3 additions & 1 deletion velox/vector/tests/utils/VectorMaker.h
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,8 @@ class VectorMaker {
MAP(CppToType<K>::create(), CppToType<V>::create())) {
static_assert(
std::is_same_v<K, int8_t> || std::is_same_v<K, int16_t> ||
std::is_same_v<K, int32_t> || std::is_same_v<K, int64_t>);
std::is_same_v<K, int32_t> || std::is_same_v<K, int64_t> ||
std::is_same_v<K, float> || std::is_same_v<K, double>);

std::vector<std::optional<std::vector<std::pair<K, std::optional<V>>>>>
maps;
Expand All @@ -779,6 +780,7 @@ class VectorMaker {

folly::json::serialization_opts options;
options.convert_int_keys = true;
options.allow_non_string_keys = true;
folly::dynamic mapObject = folly::parseJson(jsonMap, options);
if (mapObject.isNull()) {
// Null map.
Expand Down

0 comments on commit e2ee0ca

Please sign in to comment.