Skip to content

Commit

Permalink
sklearn converted to numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
americast committed Nov 10, 2023
1 parent 1d7c30e commit 823e3fe
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error

from evadb.catalog.catalog_utils import get_metadata_properties
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
Expand Down Expand Up @@ -57,6 +56,10 @@
from evadb.utils.logging_manager import logger


def root_mean_squared_error(y_true, y_pred):
return np.sqrt(np.mean(np.square(y_pred - y_true)))


# From https://stackoverflow.com/a/34333710
@contextlib.contextmanager
def set_env(**environ):
Expand Down Expand Up @@ -602,12 +605,11 @@ def get_optuna_config(trial):
crossvalidation_df.unique_id == uid
]
rmses.append(
mean_squared_error(
root_mean_squared_error(
crossvalidation_df_here.y,
crossvalidation_df_here[
arg_map["model"] + "-median"
],
squared=False,
)
/ np.mean(crossvalidation_df_here.y)
)
Expand Down Expand Up @@ -643,10 +645,9 @@ def get_optuna_config(trial):
crossvalidation_df.unique_id == uid
]
rmses.append(
mean_squared_error(
root_mean_squared_error(
crossvalidation_df_here.y,
crossvalidation_df_here[arg_map["model"]],
squared=False,
)
/ np.mean(crossvalidation_df_here.y)
)
Expand Down

0 comments on commit 823e3fe

Please sign in to comment.