Skip to content

Commit

Permalink
Create UDFs with function instance
Browse files Browse the repository at this point in the history
  • Loading branch information
gary-peng committed Nov 26, 2023
1 parent 0ed076c commit 6828b53
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 23 deletions.
49 changes: 28 additions & 21 deletions evadb/functions/simple_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
import pandas as pd
import importlib
import inspect
import pickle
from pathlib import Path

from evadb.catalog.catalog_type import NdArrayType
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward, setup
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.configuration.constants import EvaDB_ROOT_DIR

class SimpleUDF(AbstractFunction):
@setup(cacheable=False, function_type="SimpleUDF", batchable=False)
Expand Down Expand Up @@ -60,27 +62,32 @@ def _forward(row: pd.Series) -> np.ndarray:
return ret

def set_udf(self, classname:str, filepath: str):
try:
abs_path = Path(filepath).resolve()
spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
except ImportError as e:
# ImportError in the case when we are able to find the file but not able to load the module
err_msg = f"ImportError : Couldn't load function from {filepath} : {str(e)}. Not able to load the code provided in the file {abs_path}. Please ensure that the file contains the implementation code for the function."
raise ImportError(err_msg)
except FileNotFoundError as e:
# FileNotFoundError in the case when we are not able to find the file at all at the path.
err_msg = f"FileNotFoundError : Couldn't load function from {filepath} : {str(e)}. This might be because the function implementation file does not exist. Please ensure the file exists at {abs_path}"
raise FileNotFoundError(err_msg)
except Exception as e:
# Default exception, we don't know what exactly went wrong so we just output the error message
err_msg = f"Couldn't load function from {filepath} : {str(e)}."
raise RuntimeError(err_msg)

# Try to load the specified class by name
if classname and hasattr(module, classname):
self.udf = getattr(module, classname)
print("lmaoaa: ", filepath)
if f"{EvaDB_ROOT_DIR}/simple_udfs/" in filepath:
f = open(f"{EvaDB_ROOT_DIR}/simple_udfs/Func_SimpleUDF", 'rb')
self.udf = pickle.load(f)
else:
try:
abs_path = Path(filepath).resolve()
spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
except ImportError as e:
# ImportError in the case when we are able to find the file but not able to load the module
err_msg = f"ImportError : Couldn't load function from {filepath} : {str(e)}. Not able to load the code provided in the file {abs_path}. Please ensure that the file contains the implementation code for the function."
raise ImportError(err_msg)
except FileNotFoundError as e:
# FileNotFoundError in the case when we are not able to find the file at all at the path.
err_msg = f"FileNotFoundError : Couldn't load function from {filepath} : {str(e)}. This might be because the function implementation file does not exist. Please ensure the file exists at {abs_path}"
raise FileNotFoundError(err_msg)
except Exception as e:
# Default exception, we don't know what exactly went wrong so we just output the error message
err_msg = f"Couldn't load function from {filepath} : {str(e)}."
raise RuntimeError(err_msg)

# Try to load the specified class by name
if classname and hasattr(module, classname):
self.udf = getattr(module, classname)

self.signature = inspect.signature(self.udf)

Expand Down
27 changes: 26 additions & 1 deletion evadb/interfaces/relational/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import multiprocessing

import pandas
import pickle

from evadb.configuration.constants import EvaDB_DATABASE_DIR
from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_ROOT_DIR
from evadb.database import EvaDBDatabase, init_evadb_instance
from evadb.expression.tuple_value_expression import TupleValueExpression
from evadb.functions.function_bootstrap_queries import init_builtin_functions
Expand Down Expand Up @@ -413,6 +414,30 @@ def create_function(
function_name, if_not_exists, impl_path, type, **kwargs
)
return EvaDBQuery(self._evadb, stmt)

def create_simple_function(
self,
function_name: str,
function: callable,
if_not_exists: bool = True,
) -> "EvaDBQuery":
"""
Create a function in the database by passing in a function instance.
Args:
function_name (str): Name of the function to be created.
if_not_exists (bool): If True, do not raise an error if the function already exist. If False, raise an error.
function (callable): The function instance
Returns:
EvaDBQuery: The EvaDBQuery object representing the function created.
"""
impl_path = f"{EvaDB_ROOT_DIR}/simple_udfs/{function_name}"
f = open(impl_path, 'ab')
pickle.dump(function, f)
f.close()

return self.create_function(function_name, if_not_exists, impl_path)

def create_table(
self, table_name: str, if_not_exists: bool = True, columns: str = None, **kwargs
Expand Down
22 changes: 21 additions & 1 deletion test/integration_tests/long/test_simple_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from evadb.interfaces.relational.db import connect
from evadb.server.command_handler import execute_query_fetch_all

def Func_SimpleUDF(cls, x:int)->int:
return x + 10

@pytest.mark.notparallel
class SimpleFunctionTests(unittest.TestCase):
Expand All @@ -34,6 +36,7 @@ def setUp(self):
def tearDown(self):
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS test_table;")
execute_query_fetch_all(self.evadb, "DROP FUNCTION IF EXISTS My_SimpleUDF;")
execute_query_fetch_all(self.evadb, "DROP FUNCTION IF EXISTS Func_SimpleUDF;")

def test_from_file(self):
cursor = self.conn.cursor()
Expand All @@ -50,4 +53,21 @@ def test_from_file(self):
result = cursor.query("SELECT My_SimpleUDF(val) FROM test_table;").df()
expected = pd.DataFrame({'output': [6]})

self.assertTrue(expected.equals(result))
self.assertTrue(expected.equals(result))

def test_from_function(self):
cursor = self.conn.cursor()

execute_query_fetch_all(self.evadb, "CREATE TABLE IF NOT EXISTS test_table (val INTEGER);")
cursor.insert("test_table", "(val)", "(1)").df()

cursor.create_simple_function(
"Func_SimpleUDF",
Func_SimpleUDF,
True,
).df()

result = cursor.query("SELECT Func_SimpleUDF(val) FROM test_table;").df()
expected = pd.DataFrame({'output': [11]})

self.assertTrue(expected.equals(result))

0 comments on commit 6828b53

Please sign in to comment.