From 25002abbbe5ee886414d0b7aca3ce0d8c7dbc2ff Mon Sep 17 00:00:00 2001 From: Simon Blanke Date: Fri, 29 Dec 2023 19:05:29 +0100 Subject: [PATCH] add collect and load search-data methods for test-functions --- .../_base_objective_function.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/surfaces/test_functions/_base_objective_function.py b/surfaces/test_functions/_base_objective_function.py index 8150750..a2f2b83 100644 --- a/surfaces/test_functions/_base_objective_function.py +++ b/surfaces/test_functions/_base_objective_function.py @@ -11,6 +11,8 @@ from hyperactive import Hyperactive from hyperactive.optimizers import GridSearchOptimizer +from ..machine_learning.data_collector import SurfacesDataCollector + class ObjectiveFunction: def __init__(self, metric="score", input_type="dictionary", sleep=0): @@ -18,6 +20,8 @@ def __init__(self, metric="score", input_type="dictionary", sleep=0): self.input_type = input_type self.sleep = sleep + self.sql_data = SurfacesDataCollector() + def search_space(self, min=-5, max=5, step=0.1, value_typ="array"): search_space_ = {} @@ -31,8 +35,7 @@ def search_space(self, min=-5, max=5, step=0.1, value_typ="array"): return search_space_ - @property - def search_data(self): + def collect_data(self, if_exists="append"): para_names = list(self.search_space().keys()) search_data_cols = para_names + ["score"] search_data = pd.DataFrame([], columns=search_data_cols) @@ -44,7 +47,7 @@ def search_data(self): while search_data_length < search_space_size: hyper = Hyperactive(verbosity=["progress_bar"]) hyper.add_search( - self, + self.objective_function_dict, self.search_space(value_typ="list"), initialize={}, n_iter=search_space_size, @@ -54,12 +57,21 @@ def search_data(self): hyper.run() search_data = pd.concat( - [search_data, hyper.search_data(self)], ignore_index=True + [search_data, hyper.search_data(self.objective_function_dict)], + ignore_index=True, ) + search_data = search_data.drop_duplicates(subset=para_names) search_data_length = len(search_data) - return search_data + self.sql_data.save(self.__name__, search_data, if_exists) + + def load_search_data(self): + try: + dataframe = self.sql_data.load(self.__name__) + except: + print("Path 2 database: ", self.sql_data.path) + return dataframe def return_metric(self, loss): if self.metric == "score":