diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index a235bc3c6d..5b8cd75408 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -92,7 +92,9 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement): outputs.append(column) else: inputs.append(column) - elif string_comparison_case_insensitive(node.function_type, "sklearn"): + elif string_comparison_case_insensitive( + node.function_type, "sklearn" + ) or string_comparison_case_insensitive(node.function_type, "XGBoost"): assert ( "predict" in arg_map ), f"Creating {node.function_type} functions expects 'predict' metadata." diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index e14b9cde7d..66c549c8ab 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -24,7 +24,6 @@ import numpy as np import pandas as pd -from sklearn.metrics import mean_squared_error from evadb.catalog.catalog_utils import get_metadata_properties from evadb.catalog.models.function_catalog import FunctionCatalogEntry