From 53fcf27007ab151b32ac425b8b586a1970c9942c Mon Sep 17 00:00:00 2001 From: Simon Blanke Date: Sun, 7 Jan 2024 15:21:35 +0100 Subject: [PATCH] separate search_data and collect_data methods --- .../_base_objective_function.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/surfaces/mathematical_functions/_base_objective_function.py b/surfaces/mathematical_functions/_base_objective_function.py index b885fc1..37afd42 100644 --- a/surfaces/mathematical_functions/_base_objective_function.py +++ b/surfaces/mathematical_functions/_base_objective_function.py @@ -8,8 +8,7 @@ import pandas as pd from functools import reduce -from hyperactive import Hyperactive -from hyperactive.optimizers import GridSearchOptimizer +from gradient_free_optimizers import GridSearchOptimizer from .._base_test_function import BaseTestFunction @@ -41,7 +40,7 @@ def search_space(self, min=-5, max=5, step=0.1, value_types="array"): return search_space_ - def collect_data(self, if_exists="append"): + def search_data(self): self.search_space = self.search_space(value_types="list") para_names = list(self.search_space.keys()) @@ -53,26 +52,22 @@ def collect_data(self, if_exists="append"): search_space_size = reduce((lambda x, y: x * y), dim_sizes_list) while search_data_length < search_space_size: - hyper = Hyperactive(verbosity=["progress_bar"]) - hyper.add_search( - self.objective_function_dict, - self.search_space, - initialize={}, - n_iter=search_space_size, - optimizer=GridSearchOptimizer(direction="orthogonal"), - memory_warm_start=search_data, - ) - hyper.run() + print("\n search_space_size", search_space_size) + opt = GridSearchOptimizer(self.search_space, direction="orthogonal") + opt.search(self.objective_function_dict, n_iter=search_space_size) search_data = pd.concat( - [search_data, hyper.search_data(self.objective_function_dict)], + [search_data, opt.search_data], ignore_index=True, ) search_data = search_data.drop_duplicates(subset=para_names) search_data_length = len(search_data) + print("\n search_data_length", search_data_length, "\n") + return search_data - self.sql_data.save(self.__name__, search_data, if_exists) + def collect_data(self, if_exists="append"): + self.sql_data.save(self.__name__, self.search_data(), if_exists) def return_metric(self, loss): if self.metric == "score":