Skip to content

Commit

Permalink
separate search_data and collect_data methods
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBlanke committed Jan 7, 2024
1 parent ba1f132 commit 53fcf27
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions surfaces/mathematical_functions/_base_objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import pandas as pd
from functools import reduce

from hyperactive import Hyperactive
from hyperactive.optimizers import GridSearchOptimizer
from gradient_free_optimizers import GridSearchOptimizer

from .._base_test_function import BaseTestFunction

Expand Down Expand Up @@ -41,7 +40,7 @@ def search_space(self, min=-5, max=5, step=0.1, value_types="array"):

return search_space_

def collect_data(self, if_exists="append"):
def search_data(self):
self.search_space = self.search_space(value_types="list")

para_names = list(self.search_space.keys())
Expand All @@ -53,26 +52,22 @@ def collect_data(self, if_exists="append"):
search_space_size = reduce((lambda x, y: x * y), dim_sizes_list)

while search_data_length < search_space_size:
hyper = Hyperactive(verbosity=["progress_bar"])
hyper.add_search(
self.objective_function_dict,
self.search_space,
initialize={},
n_iter=search_space_size,
optimizer=GridSearchOptimizer(direction="orthogonal"),
memory_warm_start=search_data,
)
hyper.run()
print("\n search_space_size", search_space_size)
opt = GridSearchOptimizer(self.search_space, direction="orthogonal")
opt.search(self.objective_function_dict, n_iter=search_space_size)

search_data = pd.concat(
[search_data, hyper.search_data(self.objective_function_dict)],
[search_data, opt.search_data],
ignore_index=True,
)

search_data = search_data.drop_duplicates(subset=para_names)
search_data_length = len(search_data)
print("\n search_data_length", search_data_length, "\n")
return search_data

self.sql_data.save(self.__name__, search_data, if_exists)
def collect_data(self, if_exists="append"):
self.sql_data.save(self.__name__, self.search_data(), if_exists)

def return_metric(self, loss):
if self.metric == "score":
Expand Down

0 comments on commit 53fcf27

Please sign in to comment.