
Support UDF in plan generator #3040

Draft — wants to merge 78 commits into base: main
3a5a44b
save
pengpeng-lu Jul 28, 2023
53416bf
Revert "save"
pengpeng-lu Jul 28, 2023
3f88bb4
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Jul 28, 2023
773e0dd
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Aug 11, 2023
fa472d7
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Sep 18, 2023
4c32432
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Sep 25, 2023
ccd8e8d
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Oct 17, 2023
403a4a2
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Nov 30, 2023
cada59a
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Nov 30, 2023
3b02019
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Jan 16, 2024
138a7fc
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Jan 17, 2024
f8b3658
Merge branch 'main' of github.com:pengpeng-lu/fdb-record-layer
pengpeng-lu Feb 13, 2024
b4edc0a
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Feb 20, 2024
c42c0ce
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Feb 21, 2024
a570479
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Feb 27, 2024
f4e5cac
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Mar 4, 2024
c10ec5b
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Apr 2, 2024
0d9cc50
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Apr 12, 2024
057fee0
Merge remote-tracking branch 'upstream/main'
pengpeng-lu May 8, 2024
37c8641
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Jun 7, 2024
a52d15a
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Jun 11, 2024
74c24fb
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Jul 11, 2024
2f33668
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Jul 17, 2024
01c4d01
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Aug 7, 2024
8c70b2a
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Aug 13, 2024
5d59fa2
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Aug 15, 2024
7e84176
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Aug 26, 2024
0d54f1e
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Aug 29, 2024
543a79c
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Sep 5, 2024
8fd8208
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Sep 11, 2024
f5b3314
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Sep 13, 2024
9956a9f
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Sep 19, 2024
5476bb4
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Nov 20, 2024
f9fd6a9
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Nov 20, 2024
651816c
save
pengpeng-lu Nov 20, 2024
2c886c8
save
pengpeng-lu Dec 5, 2024
38a8004
save
pengpeng-lu Dec 11, 2024
4f7a23f
save
pengpeng-lu Dec 12, 2024
25eb6cf
save
pengpeng-lu Dec 12, 2024
6ebd19b
reformat
pengpeng-lu Dec 13, 2024
f8088f5
replace
pengpeng-lu Dec 16, 2024
157ab0c
put keyexpression in separate proto
pengpeng-lu Dec 18, 2024
0daa4ce
checkstyle
pengpeng-lu Dec 18, 2024
9a325bf
checksylte
pengpeng-lu Dec 19, 2024
bffaca4
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Dec 19, 2024
b3589d2
Merge branch 'main' into udf
pengpeng-lu Dec 19, 2024
46eaaf1
checkstyle
pengpeng-lu Dec 19, 2024
523859e
save
pengpeng-lu Jan 9, 2025
18312ed
Merge remote-tracking branch 'upstream/main'
pengpeng-lu Jan 9, 2025
0a70d6e
merge main
pengpeng-lu Jan 9, 2025
d058c09
MacroFunction
pengpeng-lu Jan 10, 2025
67fad06
save
pengpeng-lu Jan 10, 2025
3adcce3
save
pengpeng-lu Nov 18, 2024
74602e7
first clean
pengpeng-lu Dec 5, 2024
fa28a85
save
pengpeng-lu Dec 11, 2024
b0e5b37
store value
pengpeng-lu Dec 12, 2024
91322a2
project work
pengpeng-lu Dec 12, 2024
eee22fc
reformat
pengpeng-lu Dec 13, 2024
41112ed
fix parser
pengpeng-lu Dec 20, 2024
df4235e
clean
pengpeng-lu Dec 20, 2024
c343ed3
more test
pengpeng-lu Dec 20, 2024
6aaae1d
save
pengpeng-lu Jan 10, 2025
4f2b546
move code
pengpeng-lu Jan 10, 2025
e11b315
add test
pengpeng-lu Jan 11, 2025
29b9cb2
pmd
pengpeng-lu Jan 13, 2025
52b3ba4
comments
pengpeng-lu Jan 14, 2025
1cc89a6
save
pengpeng-lu Jan 14, 2025
91e6117
change package
pengpeng-lu Jan 14, 2025
262f03b
fix test compile error
pengpeng-lu Jan 15, 2025
4015220
Merge branch 'udf' into udo
pengpeng-lu Jan 15, 2025
e8c4dfc
change
pengpeng-lu Jan 15, 2025
8076737
save
pengpeng-lu Feb 3, 2025
cea7e31
clean
pengpeng-lu Feb 3, 2025
a71c70b
save
pengpeng-lu Feb 4, 2025
ed5e452
merge main
pengpeng-lu Mar 19, 2025
7f030c9
checkstyle
pengpeng-lu Mar 20, 2025
e3bbc51
clean
pengpeng-lu Mar 20, 2025
f5256d9
nit
pengpeng-lu Mar 20, 2025
@@ -197,6 +197,11 @@ public Map<String, RecordType> getRecordTypes() {
return recordTypes;
}

@Nonnull
public Map<String, UserDefinedFunction> getUserDefinedFunctionMap() {
return userDefinedFunctionMap;
}

@Nonnull
public RecordType getRecordType(@Nonnull String name) {
RecordType recordType = recordTypes.get(name);
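The new `getUserDefinedFunctionMap()` accessor hands back the internal map directly, so callers could mutate the metadata's registry. Whether that is intended is a design choice; a minimal sketch of the defensive alternative (toy class and field names, not from this PR) returns a read-only view:

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class FunctionRegistry {
    // Internal, mutable registry keyed by function name.
    private final Map<String, String> functions = new HashMap<>();

    public void register(String name, String definition) {
        functions.put(name, definition);
    }

    // Expose a read-only view so callers cannot mutate internal state.
    public Map<String, String> getFunctionMap() {
        return Collections.unmodifiableMap(functions);
    }
}
```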
@@ -228,9 +228,10 @@ private void loadProtoExceptRecords(@Nonnull RecordMetaDataProto.MetaData metaDa
typeBuilder.setRecordTypeKey(LiteralKeyExpression.fromProtoValue(typeProto.getExplicitKey()));
}
}
PlanSerializationContext serializationContext = new PlanSerializationContext(DefaultPlanSerializationRegistry.INSTANCE,
PlanHashable.CURRENT_FOR_CONTINUATION);
for (RecordMetaDataProto.PUserDefinedFunction function: metaDataProto.getUserDefinedFunctionsList()) {
UserDefinedFunction func = (UserDefinedFunction)PlanSerialization.dispatchFromProtoContainer(new PlanSerializationContext(DefaultPlanSerializationRegistry.INSTANCE,
PlanHashable.CURRENT_FOR_CONTINUATION), function);
UserDefinedFunction func = (UserDefinedFunction)PlanSerialization.dispatchFromProtoContainer(serializationContext, function);
userDefinedFunctionMap.put(func.getFunctionName(), func);
}
if (metaDataProto.hasSplitLongRecords()) {
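The change in this hunk hoists the `PlanSerializationContext` out of the per-function loop so it is constructed once rather than once per UDF. The same refactor in isolation, with a stand-in context type (not the Record Layer API) that counts its own constructions:

```java
import java.util.List;

public class HoistExample {
    static int constructions = 0;

    // Stand-in for a loop-invariant serialization context.
    static final class Context {
        Context() {
            constructions++;
        }

        String deserialize(String proto) {
            return proto.toUpperCase();
        }
    }

    static int process(List<String> protos) {
        // Built once, outside the loop, instead of once per element.
        Context ctx = new Context();
        int handled = 0;
        for (String p : protos) {
            ctx.deserialize(p);
            handled++;
        }
        return handled;
    }
}
```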
14 changes: 12 additions & 2 deletions fdb-relational-core/src/main/antlr/RelationalParser.g4
@@ -88,7 +88,7 @@ utilityStatement

templateClause
:
CREATE ( structDefinition | tableDefinition | enumDefinition | indexDefinition )
CREATE ( structDefinition | tableDefinition | enumDefinition | indexDefinition | functionDefinition)
;

createStatement
@@ -154,6 +154,10 @@ indexDefinition
: (UNIQUE)? INDEX indexName=uid AS queryTerm indexAttributes?
;

functionDefinition
: FUNCTION functionName=uid LEFT_ROUND_BRACKET paramName=uid inputTypeName=columnType RIGHT_ROUND_BRACKET RETURNS columnType AS fullId
Contributor: Doesn't this only declare exactly one parameter? We certainly want to support more than one parameter. Maybe I am misreading this.

Contributor: Can we make this clearer as to what this function definition actually does? How about explicitly saying MACRO FUNCTION. I don't want us to have a clash when we actually implement user-defined and SQL-bodied functions, which will probably be a bit different.

;
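As written, the rule admits exactly one `(paramName columnType)` pair, which matches the reviewer's reading above. For illustration, a statement this production would accept might look like the following (surface details such as the type names and the expression body are assumptions, not taken from the PR):

```sql
CREATE FUNCTION get_city(addr Address) RETURNS string AS addr.city
```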

indexAttributes
: WITH ATTRIBUTES indexAttribute (COMMA indexAttribute)*
;
@@ -560,6 +564,11 @@ uid
| DOUBLE_QUOTE_ID
;

userDefinedFunctionName
: ID
| DOUBLE_QUOTE_ID
;

// done
simpleId
: ID
@@ -789,6 +798,7 @@ functionCall
: aggregateWindowedFunction #aggregateFunctionCall // done (supported)
| specificFunction #specificFunctionCall //
| scalarFunctionName '(' functionArgs? ')' #scalarFunctionCall // done (unsupported)
| userDefinedFunctionName '(' functionArgs? ')' #userDefinedFunctionCall
;

specificFunction
@@ -899,7 +909,7 @@ levelInWeightListElement
;

aggregateWindowedFunction
: functionName=(AVG | MAX | MIN | SUM | MAX_EVER | MIN_EVER )
: functionName=(AVG | MAX | MIN | SUM | MAX_EVER | MIN_EVER)
'(' aggregator=(ALL | DISTINCT)? functionArg ')' overClause?
| functionName=BITMAP_CONSTRUCT_AGG '(' functionArg ')'
| functionName=COUNT '(' (starArg='*' | aggregator=ALL? functionArg | aggregator=DISTINCT functionArgs) ')' overClause?
@@ -26,6 +26,7 @@
import com.apple.foundationdb.record.RecordMetaDataProto;
import com.apple.foundationdb.record.metadata.Key;
import com.apple.foundationdb.record.query.combinatorics.TopologicalSort;
import com.apple.foundationdb.record.query.plan.cascades.UserDefinedFunction;
import com.apple.foundationdb.relational.api.exceptions.ErrorCode;
import com.apple.foundationdb.relational.api.exceptions.RelationalException;
import com.apple.foundationdb.relational.api.metadata.DataType;
@@ -50,6 +51,7 @@
import javax.annotation.Nonnull;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
@@ -68,6 +70,9 @@ public final class RecordLayerSchemaTemplate implements SchemaTemplate {
@Nonnull
private final Set<RecordLayerTable> tables;

@Nonnull
private final Map<String, UserDefinedFunction> userDefinedFunctionMap;

private final int version;

private final boolean enableLongRows;
@@ -85,11 +90,13 @@ public final class RecordLayerSchemaTemplate implements SchemaTemplate {

private RecordLayerSchemaTemplate(@Nonnull final String name,
@Nonnull final Set<RecordLayerTable> tables,
@Nonnull final Map<String, UserDefinedFunction> userDefinedFunctionMap,
int version,
boolean enableLongRows,
boolean storeRowVersions) {
this.name = name;
this.tables = tables;
this.userDefinedFunctionMap = userDefinedFunctionMap;
this.version = version;
this.enableLongRows = enableLongRows;
this.storeRowVersions = storeRowVersions;
@@ -100,13 +107,15 @@ private RecordLayerSchemaTemplate(@Nonnull final String name,

private RecordLayerSchemaTemplate(@Nonnull final String name,
@Nonnull final Set<RecordLayerTable> tables,
@Nonnull final Map<String, UserDefinedFunction> userDefinedFunctionMap,
int version,
boolean enableLongRows,
boolean storeRowVersions,
@Nonnull final RecordMetaData cachedMetadata) {
this.name = name;
this.version = version;
this.tables = tables;
this.userDefinedFunctionMap = userDefinedFunctionMap;
this.enableLongRows = enableLongRows;
this.storeRowVersions = storeRowVersions;
this.metaDataSupplier = Suppliers.memoize(() -> cachedMetadata);
@@ -147,6 +156,11 @@ public RecordLayerSchema generateSchema(@Nonnull String databaseId, @Nonnull Str
return new RecordLayerSchema(schemaName, databaseId, this);
}

@Nonnull
public Collection<UserDefinedFunction> getAllUserDefinedFunctions() {
return userDefinedFunctionMap.values();
}

@Nonnull
public Descriptors.Descriptor getDescriptor(@Nonnull final String tableName) {
return toRecordMetadata().getRecordType(tableName).getDescriptor();
@@ -309,12 +323,15 @@ public static final class Builder {

private final Map<String, RecordLayerTable> tables;

private final Map<String, UserDefinedFunction> functionMap;

private final Map<String, DataType.Named> auxiliaryTypes; // for quick lookup

private RecordMetaData cachedMetadata;

private Builder() {
tables = new LinkedHashMap<>();
functionMap = new HashMap<>();
auxiliaryTypes = new LinkedHashMap<>();
// enable long rows is TRUE by default
enableLongRows = true;
@@ -402,6 +419,18 @@ public Builder addAuxiliaryTypes(@Nonnull Collection<DataType.Named> auxiliaryTy
return this;
}

@Nonnull
public Builder addUserDefinedFunction(@Nonnull UserDefinedFunction userDefinedFunction) {
functionMap.put(userDefinedFunction.getFunctionName(), userDefinedFunction);
return this;
}

@Nonnull
public Builder addUserDefinedFunctions(@Nonnull Collection<UserDefinedFunction> functions) {
functions.forEach(this::addUserDefinedFunction);
return this;
}

@Nonnull
Builder setCachedMetadata(@Nonnull final RecordMetaData metadata) {
this.cachedMetadata = metadata;
@@ -459,11 +488,10 @@ public RecordLayerSchemaTemplate build() {
if (needsResolution) {
resolveTypes();
}

if (cachedMetadata != null) {
return new RecordLayerSchemaTemplate(name, new LinkedHashSet<>(tables.values()), version, enableLongRows, storeRowVersions, cachedMetadata);
return new RecordLayerSchemaTemplate(name, new LinkedHashSet<>(tables.values()), functionMap, version, enableLongRows, storeRowVersions, cachedMetadata);
} else {
return new RecordLayerSchemaTemplate(name, new LinkedHashSet<>(tables.values()), version, enableLongRows, storeRowVersions);
return new RecordLayerSchemaTemplate(name, new LinkedHashSet<>(tables.values()), functionMap, version, enableLongRows, storeRowVersions);
}
}

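The Builder now threads `functionMap` through both private constructors and both `build()` branches. A reduced sketch of the same builder shape (toy types, not the Relational API) shows the key behavior: functions are keyed by name, so re-adding a function under the same name replaces the previous entry, mirroring `Map.put`:

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class TemplateBuilderSketch {
    public static final class Template {
        public final Map<String, String> functions;

        Template(Map<String, String> functions) {
            this.functions = functions;
        }
    }

    public static final class Builder {
        private final Map<String, String> functionMap = new LinkedHashMap<>();

        // Keyed by name: re-adding the same name replaces the earlier body.
        public Builder addFunction(String name, String body) {
            functionMap.put(name, body);
            return this;
        }

        public Builder addFunctions(Map<String, String> fns) {
            fns.forEach(this::addFunction);
            return this;
        }

        public Template build() {
            // Immutable snapshot, so the built template cannot change later.
            return new Template(Map.copyOf(functionMap));
        }
    }
}
```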
@@ -61,6 +61,9 @@ public RecordLayerSchemaTemplate.Builder getSchemaTemplate(@Nonnull final String
.setEnableLongRows(recordMetaData.isSplitLongRecords())
.setName(schemaTemplateName)
.setIntermingleTables(!recordMetaData.primaryKeyHasRecordTypePrefix());
for (final var u: recordMetaData.getUserDefinedFunctionMap().values()) {
schemaTemplateBuilder.addUserDefinedFunction(u);
}
final var nameToTableBuilder = new HashMap<String, RecordLayerTable.Builder>();
for (final var registeredType : registeredTypes) {
switch (registeredType.getType()) {
@@ -87,6 +87,7 @@ public void visit(@Nonnull com.apple.foundationdb.relational.api.metadata.Index
@Override
public void visit(@Nonnull SchemaTemplate schemaTemplate) {
Assert.thatUnchecked(schemaTemplate instanceof RecordLayerSchemaTemplate);
getBuilder().addUserDefinedFunctions(((RecordLayerSchemaTemplate) schemaTemplate).getAllUserDefinedFunctions());
getBuilder().setSplitLongRecords(schemaTemplate.isEnableLongRows());
getBuilder().setStoreRecordVersions(schemaTemplate.isStoreRowVersions());
getBuilder().setVersion(schemaTemplate.getVersion());
@@ -128,6 +128,13 @@ public boolean prefixedWith(@Nonnull Identifier identifier) {
return true;
}

@Nonnull
public List<String> removePrefix(@Nonnull Identifier prefix) {
Contributor: What is this used for?

// assume the identifier has the prefix, should call prefixedWith(prefix) to check before calling this method
final var fullName = fullyQualifiedName();
return fullName.subList(prefix.fullyQualifiedName().size(), fullName.size());
}

public boolean qualifiedWith(@Nonnull Identifier identifier) {
final var identifierFullName = identifier.fullyQualifiedName();
final var fullName = fullyQualifiedName();
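`removePrefix` relies on the caller having checked `prefixedWith(prefix)` first; `subList` then simply drops the leading components. The core mechanics in isolation, with plain string lists standing in for `Identifier`:

```java
import java.util.List;

public class PrefixExample {
    // Assumes the caller already verified that fullName starts with prefix,
    // mirroring the prefixedWith(...) precondition noted in the PR.
    static List<String> removePrefix(List<String> fullName, List<String> prefix) {
        // subList returns a view from prefix.size() to the end.
        return fullName.subList(prefix.size(), fullName.size());
    }
}
```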
@@ -28,6 +28,7 @@
import com.apple.foundationdb.record.query.plan.cascades.CorrelationIdentifier;
import com.apple.foundationdb.record.query.plan.cascades.IndexAccessHint;
import com.apple.foundationdb.record.query.plan.cascades.Quantifier;
import com.apple.foundationdb.record.query.plan.cascades.UserDefinedFunction;
import com.apple.foundationdb.record.query.plan.cascades.typing.Type;
import com.apple.foundationdb.record.query.plan.cascades.typing.Typed;
import com.apple.foundationdb.record.query.plan.cascades.values.AggregateValue;
@@ -38,6 +39,7 @@
import com.apple.foundationdb.record.query.plan.cascades.values.IndexableAggregateValue;
import com.apple.foundationdb.record.query.plan.cascades.values.LiteralValue;
import com.apple.foundationdb.record.query.plan.cascades.values.NotValue;
import com.apple.foundationdb.record.query.plan.cascades.values.QuantifiedObjectValue;
import com.apple.foundationdb.record.query.plan.cascades.values.RecordConstructorValue;
import com.apple.foundationdb.record.query.plan.cascades.values.RelOpValue;
import com.apple.foundationdb.record.query.plan.cascades.values.StreamableAggregateValue;
@@ -51,6 +53,7 @@
import com.apple.foundationdb.relational.api.metadata.Table;
import com.apple.foundationdb.relational.generated.RelationalParser;
import com.apple.foundationdb.relational.recordlayer.metadata.DataTypeUtils;
import com.apple.foundationdb.relational.recordlayer.metadata.RecordLayerSchemaTemplate;
import com.apple.foundationdb.relational.recordlayer.query.functions.SqlFunctionCatalog;
import com.apple.foundationdb.relational.recordlayer.query.functions.SqlFunctionCatalogImpl;
import com.apple.foundationdb.relational.recordlayer.query.visitors.QueryVisitor;
@@ -100,6 +103,10 @@ public SemanticAnalyzer(@Nonnull SchemaTemplate metadataCatalog,
@Nonnull SqlFunctionCatalog functionCatalog) {
this.metadataCatalog = metadataCatalog;
this.functionCatalog = functionCatalog;
// add UDFs to functionCatalog
for (UserDefinedFunction function: ((RecordLayerSchemaTemplate) metadataCatalog).getAllUserDefinedFunctions()) {
(this.functionCatalog).addUdfFunction(function);
}
}

/**
@@ -473,6 +480,46 @@ public Optional<Expression> lookupNestedField(@Nonnull Identifier requestedIdent
return Optional.of(nestedAttribute);
}

@Nonnull
public Optional<Value> lookupNestedField(@Nonnull Identifier requestedIdentifier,
@Nonnull Identifier paramId,
@Nonnull QuantifiedObjectValue existingValue,
@Nonnull DataType targetDataType) {
Assert.thatUnchecked(requestedIdentifier.prefixedWith(paramId), "Invalid function definition");

// x -> x
if (requestedIdentifier.fullyQualifiedName().size() == paramId.fullyQualifiedName().size()) {
Assert.thatUnchecked(existingValue.getResultType().equals(DataTypeUtils.toRecordLayerType(targetDataType)), ErrorCode.DATATYPE_MISMATCH, "Result data types don't match!");
return Optional.of(existingValue);
}
// find nested field path
final var remainingPath = requestedIdentifier.removePrefix(paramId);
final ImmutableList.Builder<FieldValue.Accessor> accessors = ImmutableList.builder();
DataType existingDataType = DataTypeUtils.toRelationalType(existingValue.getResultType());
for (String s : remainingPath) {
if (existingDataType.getCode() != DataType.Code.STRUCT) {
return Optional.empty();
}
final var fields = ((DataType.StructType) existingDataType).getFields();
var found = false;
for (int j = 0; j < fields.size(); j++) {
if (fields.get(j).getName().equals(s)) {
accessors.add(new FieldValue.Accessor(fields.get(j).getName(), j));
existingDataType = fields.get(j).getType();
found = true;
break;
}
}
if (!found) {
return Optional.empty();
}
}
final var fieldPath = FieldValue.resolveFieldPath(existingValue.getResultType(), accessors.build());
final var fieldValue = FieldValue.ofFieldsAndFuseIfPossible(existingValue, fieldPath);
Assert.thatUnchecked(fieldValue.getResultType().equals(DataTypeUtils.toRecordLayerType(targetDataType)), ErrorCode.DATATYPE_MISMATCH, "Result data types don't match!");
return Optional.of(fieldValue);
}

@Nonnull
public DataType lookupType(@Nonnull Identifier typeIdentifier, boolean isNullable, boolean isRepeated,
@Nonnull Function<String, Optional<DataType>> dataTypeProvider) {
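The new `lookupNestedField` overload walks the remaining path one struct level at a time, recording a (name, ordinal) accessor per step and returning `Optional.empty()` when a component is missing or the current type is not a struct. The traversal skeleton with a toy struct type (not the Record Layer type system):

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class NestedLookup {
    // Toy struct: field name -> either a nested Struct or a leaf type name.
    public static final class Struct {
        final Map<String, Object> fields = new LinkedHashMap<>();

        Struct with(String name, Object type) {
            fields.put(name, type);
            return this;
        }
    }

    // Follow the path; empty() if a step is missing or we hit a leaf early.
    static Optional<Object> resolve(Struct root, List<String> path) {
        Object current = root;
        for (String step : path) {
            if (!(current instanceof Struct)) {
                return Optional.empty();
            }
            Object next = ((Struct) current).fields.get(step);
            if (next == null) {
                return Optional.empty();
            }
            current = next;
        }
        return Optional.of(current);
    }
}
```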
@@ -20,17 +20,19 @@

package com.apple.foundationdb.relational.recordlayer.query.functions;

import com.apple.foundationdb.record.query.plan.cascades.BuiltInFunction;
import com.apple.foundationdb.record.query.plan.cascades.typing.Typed;
import com.apple.foundationdb.record.query.plan.cascades.CatalogedFunction;
import com.apple.foundationdb.record.query.plan.cascades.UserDefinedFunction;
import com.apple.foundationdb.relational.recordlayer.query.Expression;

import javax.annotation.Nonnull;

public interface SqlFunctionCatalog {
@Nonnull
BuiltInFunction<? extends Typed> lookUpFunction(@Nonnull String name, @Nonnull Expression... expressions);
CatalogedFunction lookUpFunction(@Nonnull String name, @Nonnull Expression... expressions);

boolean containsFunction(@Nonnull String name);

boolean isUdfFunction(@Nonnull String name);

void addUdfFunction(@Nonnull UserDefinedFunction function);
}
@@ -23,6 +23,8 @@
import com.apple.foundationdb.annotation.API;

import com.apple.foundationdb.record.query.plan.cascades.BuiltInFunction;
import com.apple.foundationdb.record.query.plan.cascades.CatalogedFunction;
import com.apple.foundationdb.record.query.plan.cascades.UserDefinedFunction;
import com.apple.foundationdb.record.query.plan.cascades.typing.Typed;
import com.apple.foundationdb.record.query.plan.cascades.values.FunctionCatalog;
import com.apple.foundationdb.record.query.plan.cascades.values.RecordConstructorValue;
@@ -33,7 +35,9 @@
import com.google.common.collect.ImmutableMap;

import javax.annotation.Nonnull;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.StreamSupport;
@@ -52,26 +56,39 @@ public final class SqlFunctionCatalogImpl implements SqlFunctionCatalog {
@Nonnull
private final ImmutableMap<String, Function<Integer, BuiltInFunction<? extends Typed>>> synonyms;

@Nonnull
private final Map<String, UserDefinedFunction> userDefinedFunctionMap = new HashMap<>();


private SqlFunctionCatalogImpl() {
this.synonyms = createSynonyms();
}

@Nonnull
@Override
public BuiltInFunction<? extends Typed> lookUpFunction(@Nonnull final String name, @Nonnull final Expression... expressions) {
return Assert.notNullUnchecked(Objects.requireNonNull(synonyms.get(name.toLowerCase(Locale.ROOT))).apply(expressions.length));
public CatalogedFunction lookUpFunction(@Nonnull final String name, @Nonnull final Expression... expressions) {
if (synonyms.get(name.toLowerCase(Locale.ROOT)) != null) {
return Assert.notNullUnchecked(Objects.requireNonNull(synonyms.get(name.toLowerCase(Locale.ROOT))).apply(expressions.length));
} else {
return userDefinedFunctionMap.get(name.toLowerCase(Locale.ROOT));
}
}

@Override
public boolean containsFunction(@Nonnull String name) {
return synonyms.containsKey(name.toLowerCase(Locale.ROOT));
return synonyms.containsKey(name.toLowerCase(Locale.ROOT)) || userDefinedFunctionMap.containsKey(name.toLowerCase(Locale.ROOT));
}

@Override
public boolean isUdfFunction(@Nonnull final String name) {
return "java_call".equals(name.trim().toLowerCase(Locale.ROOT));
}

@Override
public void addUdfFunction(@Nonnull final UserDefinedFunction function) {
userDefinedFunctionMap.put(function.getFunctionName(), function);
}

@Nonnull
private static ImmutableMap<String, Function<Integer, BuiltInFunction<? extends Typed>>> createSynonyms() {
return ImmutableMap.<String, Function<Integer, BuiltInFunction<? extends Typed>>>builder()
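`lookUpFunction` checks the built-in synonyms first and only then the UDF map, lowercasing the probe each time; note, though, that `addUdfFunction` stores `getFunctionName()` unnormalized, so a UDF registered under a mixed-case name could never be found. A sketch that normalizes on both the insert and lookup paths (assuming function names are meant to be case-insensitive, as the built-in side already is; toy types, not the catalog API):

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

public class FunctionLookup {
    private final Map<String, String> builtIns = new HashMap<>();
    private final Map<String, String> udfs = new HashMap<>();

    public void addBuiltIn(String name, String impl) {
        builtIns.put(name.toLowerCase(Locale.ROOT), impl);
    }

    // Normalize on insert as well as on lookup, so case never causes a miss.
    public void addUdf(String name, String impl) {
        udfs.put(name.toLowerCase(Locale.ROOT), impl);
    }

    // Built-ins win over UDFs with the same name; null when neither exists.
    public String lookUp(String name) {
        String key = name.toLowerCase(Locale.ROOT);
        String builtIn = builtIns.get(key);
        return builtIn != null ? builtIn : udfs.get(key);
    }

    public boolean contains(String name) {
        String key = name.toLowerCase(Locale.ROOT);
        return builtIns.containsKey(key) || udfs.containsKey(key);
    }
}
```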