Skip to content

Commit

Permalink
add collect and load search-data methods for test-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBlanke committed Dec 29, 2023
1 parent 55ed026 commit 25002ab
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions surfaces/test_functions/_base_objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
from hyperactive import Hyperactive
from hyperactive.optimizers import GridSearchOptimizer

from ..machine_learning.data_collector import SurfacesDataCollector


class ObjectiveFunction:
def __init__(self, metric="score", input_type="dictionary", sleep=0):
self.metric = metric
self.input_type = input_type
self.sleep = sleep

self.sql_data = SurfacesDataCollector()

def search_space(self, min=-5, max=5, step=0.1, value_typ="array"):
search_space_ = {}

Expand All @@ -31,8 +35,7 @@ def search_space(self, min=-5, max=5, step=0.1, value_typ="array"):

return search_space_

@property
def search_data(self):
def collect_data(self, if_exists="append"):
para_names = list(self.search_space().keys())
search_data_cols = para_names + ["score"]
search_data = pd.DataFrame([], columns=search_data_cols)
Expand All @@ -44,7 +47,7 @@ def search_data(self):
while search_data_length < search_space_size:
hyper = Hyperactive(verbosity=["progress_bar"])
hyper.add_search(
self,
self.objective_function_dict,
self.search_space(value_typ="list"),
initialize={},
n_iter=search_space_size,
Expand All @@ -54,12 +57,21 @@ def search_data(self):
hyper.run()

search_data = pd.concat(
[search_data, hyper.search_data(self)], ignore_index=True
[search_data, hyper.search_data(self.objective_function_dict)],
ignore_index=True,
)

search_data = search_data.drop_duplicates(subset=para_names)
search_data_length = len(search_data)

return search_data
self.sql_data.save(self.__name__, search_data, if_exists)

def load_search_data(self):
try:
dataframe = self.sql_data.load(self.__name__)
except:
print("Path 2 database: ", self.sql_data.path)
return dataframe

def return_metric(self, loss):
if self.metric == "score":
Expand Down

0 comments on commit 25002ab

Please sign in to comment.