Skip to content

Commit c27404c

Browse files
mbasmanova authored and
facebook-github-bot committed
fix: Enable constant folding for lambda functions (facebookincubator#13642)
Summary: Pull Request resolved: facebookincubator#13642 ExprCompiler used to not constant fold lambda functions. For example, transform(array[1, 2, 3], x -> x * 2) would not be constant folded. This was due to Expr::isConstant returning false for such an expression. Expr::isConstant returned true iff the expression is deterministic and all inputs are ConstantExpr. A fix is to modify Expr::isConstant to return true iff the expression is deterministic and has no dependencies (distinctFields_ is empty). Also, added a convenience API tryEvaluateConstantExpression to complement the existing evaluateConstantExpression. The new API can be safely called on any expression without ensuring the expression is constant-foldable. Reviewed By: bikramSingh91, xiaoxmeng Differential Revision: D76004362 fbshipit-source-id: f3fc5dfcb8bcf93205fcc60b1709ddd09aa75ac3
1 parent 344f4ef commit c27404c

File tree

3 files changed

+109
-40
lines changed

3 files changed

+109
-40
lines changed

velox/expression/Expr.cpp

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,20 +1192,36 @@ void Expr::evalWithMemo(
11921192
VectorPtr base;
11931193
distinctFields_[0]->evalSpecialForm(rows, context, base);
11941194

1195+
// evalWithNulls may throw an exception. If this happens during constant
1196+
// folding, the exception is suppressed and the Expr object may be reused.
1197+
// Hence, it is important to update state in a way that ensures a "valid" state in
1198+
// case of exceptions.
1199+
//
1200+
// Also, note that the same expression running on same data may pass or may
1201+
// fail depending on whether it runs under TRY or not.
1202+
//
1203+
// An example expression that triggers these edge cases:
1204+
//
1205+
// try(coalesce(array_min_by(array[1, 2, 3], x -> x / 0), 0::INTEGER))
1206+
11951207
if (base.get() != baseOfDictionaryRawPtr_ ||
11961208
baseOfDictionaryWeakPtr_.expired()) {
11971209
baseOfDictionaryRepeats_ = 0;
1198-
baseOfDictionaryWeakPtr_ = base;
1199-
baseOfDictionaryRawPtr_ = base.get();
1210+
baseOfDictionaryWeakPtr_.reset();
1211+
baseOfDictionaryRawPtr_ = nullptr;
12001212
context.releaseVector(baseOfDictionary_);
12011213
context.releaseVector(dictionaryCache_);
1214+
12021215
evalWithNulls(rows, context, result);
1216+
baseOfDictionaryWeakPtr_ = base;
1217+
baseOfDictionaryRawPtr_ = base.get();
12031218
return;
12041219
}
1205-
++baseOfDictionaryRepeats_;
12061220

1207-
if (baseOfDictionaryRepeats_ == 1) {
1221+
if (baseOfDictionaryRepeats_ == 0) {
12081222
evalWithNulls(rows, context, result);
1223+
1224+
++baseOfDictionaryRepeats_;
12091225
baseOfDictionary_ = base;
12101226
dictionaryCache_ = result;
12111227
if (!cachedDictionaryIndices_) {
@@ -1217,6 +1233,8 @@ void Expr::evalWithMemo(
12171233
return;
12181234
}
12191235

1236+
++baseOfDictionaryRepeats_;
1237+
12201238
if (cachedDictionaryIndices_) {
12211239
LocalSelectivityVector cachedHolder(context, rows);
12221240
auto cached = cachedHolder.get();
@@ -1242,31 +1260,34 @@ void Expr::evalWithMemo(
12421260

12431261
evalWithNulls(*uncached, context, result);
12441262
context.deselectErrors(*uncached);
1245-
context.exprSet()->addToMemo(this);
1246-
auto newCacheSize = uncached->end();
1247-
1248-
// dictionaryCache_ is valid only for cachedDictionaryIndices_. Hence, a
1249-
// safe call to BaseVector::ensureWritable must include all the rows not
1250-
// covered by cachedDictionaryIndices_. If BaseVector::ensureWritable is
1251-
// called only for a subset of rows not covered by
1252-
// cachedDictionaryIndices_, it will attempt to copy rows that are not
1253-
// valid leading to a crash.
1254-
LocalSelectivityVector allUncached(context, dictionaryCache_->size());
1255-
allUncached.get()->setAll();
1256-
allUncached.get()->deselect(*cachedDictionaryIndices_);
1257-
context.ensureWritable(*allUncached.get(), type(), dictionaryCache_);
1258-
1259-
if (cachedDictionaryIndices_->size() < newCacheSize) {
1260-
cachedDictionaryIndices_->resize(newCacheSize, false);
1261-
}
12621263

1263-
cachedDictionaryIndices_->select(*uncached);
1264+
if (uncached->hasSelections()) {
1265+
context.exprSet()->addToMemo(this);
1266+
auto newCacheSize = uncached->end();
1267+
1268+
// dictionaryCache_ is valid only for cachedDictionaryIndices_. Hence, a
1269+
// safe call to BaseVector::ensureWritable must include all the rows not
1270+
// covered by cachedDictionaryIndices_. If BaseVector::ensureWritable is
1271+
// called only for a subset of rows not covered by
1272+
// cachedDictionaryIndices_, it will attempt to copy rows that are not
1273+
// valid leading to a crash.
1274+
LocalSelectivityVector allUncached(context, dictionaryCache_->size());
1275+
allUncached.get()->setAll();
1276+
allUncached.get()->deselect(*cachedDictionaryIndices_);
1277+
context.ensureWritable(*allUncached.get(), type(), dictionaryCache_);
1278+
1279+
if (cachedDictionaryIndices_->size() < newCacheSize) {
1280+
cachedDictionaryIndices_->resize(newCacheSize, false);
1281+
}
12641282

1265-
// Resize the dictionaryCache_ to accommodate all the necessary rows.
1266-
if (dictionaryCache_->size() < uncached->end()) {
1267-
dictionaryCache_->resize(uncached->end());
1283+
cachedDictionaryIndices_->select(*uncached);
1284+
1285+
// Resize the dictionaryCache_ to accommodate all the necessary rows.
1286+
if (dictionaryCache_->size() < uncached->end()) {
1287+
dictionaryCache_->resize(uncached->end());
1288+
}
1289+
dictionaryCache_->copy(result.get(), *uncached, nullptr);
12681290
}
1269-
dictionaryCache_->copy(result.get(), *uncached, nullptr);
12701291
}
12711292
context.releaseVector(base);
12721293
}
@@ -1660,12 +1681,8 @@ bool Expr::isConstant() const {
16601681
if (!isDeterministic()) {
16611682
return false;
16621683
}
1663-
for (auto& input : inputs_) {
1664-
if (!input->is<ConstantExpr>()) {
1665-
return false;
1666-
}
1667-
}
1668-
return true;
1684+
1685+
return distinctFields_.empty();
16691686
}
16701687

16711688
namespace {
@@ -2022,17 +2039,30 @@ core::ExecCtx* SimpleExpressionEvaluator::ensureExecCtx() {
20222039
VectorPtr evaluateConstantExpression(
20232040
const core::TypedExprPtr& expr,
20242041
memory::MemoryPool* pool) {
2042+
auto result = tryEvaluateConstantExpression(expr, pool);
2043+
VELOX_USER_CHECK_NOT_NULL(
2044+
result, "Expression is not constant-foldable: {}", expr->toString());
2045+
return result;
2046+
}
2047+
2048+
VectorPtr tryEvaluateConstantExpression(
2049+
const core::TypedExprPtr& expr,
2050+
memory::MemoryPool* pool) {
20252051
auto data = BaseVector::create<RowVector>(ROW({}), 1, pool);
20262052

20272053
auto queryCtx = velox::core::QueryCtx::create();
20282054
velox::core::ExecCtx execCtx{pool, queryCtx.get()};
20292055
velox::exec::ExprSet exprSet({expr}, &execCtx);
2030-
velox::exec::EvalCtx evalCtx(&execCtx, &exprSet, data.get());
20312056

2032-
velox::SelectivityVector singleRow(1);
2033-
std::vector<velox::VectorPtr> results(1);
2034-
exprSet.eval(singleRow, evalCtx, results);
2035-
return results.at(0);
2057+
if (exprSet.expr(0)->is<ConstantExpr>()) {
2058+
velox::exec::EvalCtx evalCtx(&execCtx, &exprSet, data.get());
2059+
velox::SelectivityVector singleRow(1);
2060+
std::vector<velox::VectorPtr> results(1);
2061+
exprSet.eval(singleRow, evalCtx, results);
2062+
return results.at(0);
2063+
}
2064+
2065+
return nullptr;
20362066
}
20372067

20382068
} // namespace facebook::velox::exec

velox/expression/Expr.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -836,12 +836,20 @@ std::unique_ptr<ExprSet> makeExprSetFromFlag(
836836
std::vector<core::TypedExprPtr>&& source,
837837
core::ExecCtx* execCtx);
838838

839-
/// Evaluates an expression that doesn't depend on any inputs and returns the
840-
/// result as single-row vector.
839+
/// Evaluates a deterministic expression that doesn't depend on any inputs and
840+
/// returns the result as a single-row vector. Throws if the expression is
841+
/// non-deterministic or has dependencies.
841842
VectorPtr evaluateConstantExpression(
842843
const core::TypedExprPtr& expr,
843844
memory::MemoryPool* pool);
844845

846+
/// Evaluates a deterministic expression that doesn't depend on any inputs and
847+
/// returns the result as a single-row vector. Returns nullptr if the expression
848+
/// is non-deterministic or has dependencies.
849+
VectorPtr tryEvaluateConstantExpression(
850+
const core::TypedExprPtr& expr,
851+
memory::MemoryPool* pool);
852+
845853
/// Returns a string representation of the expression trees annotated with
846854
/// runtime statistics. Expected to be called after calling ExprSet::eval one or
847855
/// more times. If called before ExprSet::eval runtime statistics will be all

velox/expression/tests/ExprTest.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4986,7 +4986,7 @@ TEST_F(ExprTest, disabledeferredLazyLoading) {
49864986

49874987
TEST_F(ExprTest, evaluateConstantExpression) {
49884988
auto eval = [&](const std::string& sql) {
4989-
auto expr = parseExpression(sql, ROW({}));
4989+
auto expr = parseExpression(sql, ROW({"a"}, {BIGINT()}));
49904990
return exec::evaluateConstantExpression(expr, pool());
49914991
};
49924992

@@ -4995,6 +4995,37 @@ TEST_F(ExprTest, evaluateConstantExpression) {
49954995
assertEqualVectors(
49964996
eval("transform(array[1, 2, 3], x -> (x * 2))"),
49974997
makeArrayVectorFromJson<int64_t>({"[2, 4, 6]"}));
4998+
4999+
assertEqualVectors(
5000+
eval("transform(array[1, 2, 3], x -> (x * (3 - 1)))"),
5001+
makeArrayVectorFromJson<int64_t>({"[2, 4, 6]"}));
5002+
5003+
assertEqualVectors(
5004+
eval("transform(array[1, 2, 3], x -> 2)"),
5005+
makeArrayVectorFromJson<int64_t>({"[2, 2, 2]"}));
5006+
5007+
assertEqualVectors(
5008+
eval(
5009+
"try(coalesce(array_min_by(array[1, 2, 3], x -> x / 0), 0::INTEGER))"),
5010+
makeNullConstant(TypeKind::INTEGER, 1));
5011+
5012+
auto tryEval = [&](const std::string& sql) {
5013+
auto expr = parseExpression(sql, ROW({"a"}, {BIGINT()}));
5014+
return exec::tryEvaluateConstantExpression(expr, pool());
5015+
};
5016+
5017+
VELOX_ASSERT_THROW(eval("a + 1"), "Expression is not constant-foldable");
5018+
ASSERT_TRUE(tryEval("a + 1") == nullptr);
5019+
5020+
VELOX_ASSERT_THROW(
5021+
eval("rand() + 1.0"), "Expression is not constant-foldable");
5022+
ASSERT_TRUE(tryEval("rand() + 1.0") == nullptr);
5023+
5024+
VELOX_ASSERT_THROW(
5025+
eval("transform(array[1, 2, 3], x -> (x * 2) + a)"),
5026+
"Expression is not constant-foldable");
5027+
ASSERT_TRUE(
5028+
tryEval("transform(array[1, 2, 3], x -> (x * 2) + a)") == nullptr);
49985029
}
49995030

50005031
TEST_F(ExprTest, peelingOnDeterministicFunctionInNonDeterministicExpr) {

0 commit comments

Comments
 (0)