Skip to content

Commit

Permalink
refactor base-classes for test-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBlanke committed Apr 27, 2024
1 parent 6a916f6 commit 0e0fddb
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 55 deletions.
46 changes: 42 additions & 4 deletions src/surfaces/test_functions/_base_test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

import time

import numpy as np
import pandas as pd

from ..data_collector import SurfacesDataCollector


class BaseTestFunction:
explanation = """ """
Expand All @@ -15,9 +20,20 @@ class BaseTestFunction:
objective_function: callable
pure_objective_function: callable

def __init__(self):
def __init__(self, metric, sleep=0, evaluate_from_data=False):
self.sleep = sleep
self.metric = metric

self.create_objective_function()

if evaluate_from_data:
self.sdc = SurfacesDataCollector()
self.objective_function = self.objective_function_loaded
else:
self.objective_function = self.pure_objective_function

self.objective_function.__name__ = self.__name__

def create_objective_function(self):
e_msg = "'create_objective_function'-method is not implemented"
raise NotImplementedError(e_msg)
Expand All @@ -26,9 +42,6 @@ def search_space(self):
e_msg = "'search_space'-method is not implemented"
raise NotImplementedError(e_msg)

def load_search_data(self):
return self.sql_data.load(self.__name__)

def return_metric(self, loss):
if self.metric == "score":
return -loss
Expand All @@ -52,3 +65,28 @@ def objective_function(self, *input):
metric = self.objective_function_np(*input)

return self.return_metric(metric)

def objective_function_loaded(self, params):
try:
parameter_d = params.para_dict
except AttributeError:
parameter_d = params

for para_names, dim_value in parameter_d.items():
try:
parameter_d[para_names] = dim_value.__name__
except AttributeError:
pass

search_data = self.sdc.load(self.__name__)
if search_data is None:
msg = "Search Data is empty"
raise TypeError(msg)

params_df = pd.DataFrame(parameter_d, index=[0])

para_df_row = search_data[
np.all(search_data[self.para_names].values == params_df.values, axis=1)
]
score = para_df_row["score"].values[0]
return score
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,9 @@
# License: MIT License


import numpy as np
import pandas as pd
from functools import reduce

from hyperactive import Hyperactive
from hyperactive.optimizers import GridSearchOptimizer
from .._base_test_function import BaseTestFunction


class MachineLearningFunction(BaseTestFunction):
def __init__(self, metric=None, sleep=0, load_search_data=False):
super().__init__()

self.metric = metric
self.sleep = sleep

self.create_objective_function()

if load_search_data:
self.objective_function = self.objective_function_loaded
else:
self.objective_function = self.pure_objective_function

# self.objective_function.__func__.__name__ = self.__name__

def objective_function_loaded(self, params):
try:
parameter_d = params.para_dict
except AttributeError:
parameter_d = params

for para_names, dim_value in parameter_d.items():
try:
parameter_d[para_names] = dim_value.__name__
except AttributeError:
pass

search_data = self.load_search_data()
if search_data is None:
msg = "Search Data is empty"
raise TypeError(msg)

params_df = pd.DataFrame(parameter_d, index=[0])

para_df_row = search_data[
np.all(search_data[self.para_names].values == params_df.values, axis=1)
]
score = para_df_row["score"].values[0]
return score
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@


import numpy as np
import pandas as pd
from functools import reduce

from gradient_free_optimizers import GridSearchOptimizer

from .._base_test_function import BaseTestFunction

Expand All @@ -19,8 +15,15 @@ class MathematicalFunction(BaseTestFunction):
formula = r" "
global_minimum = r" "

def __init__(self, metric="score", input_type="dictionary", sleep=0):
super().__init__()
def __init__(
self,
*args,
metric="loss",
input_type="dictionary",
sleep=0,
**kwargs,
):
super().__init__(*args, metric, **kwargs)

self.metric = metric
self.input_type = input_type
Expand Down

0 comments on commit 0e0fddb

Please sign in to comment.