-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from Giggfitnesse/main
Fixed the bugs #27
- Loading branch information
Showing 9 changed files with 212 additions and 71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
""" | ||
# Author: Wenjie Li <[email protected]> | ||
# License: MIT | ||
import pdb | ||
|
||
import numpy as np | ||
from PyXAB.algos.Algo import Algorithm | ||
|
@@ -41,6 +42,8 @@ def __init__( | |
raise ValueError("Partition of the parameter space is not given.") | ||
if algo is None: | ||
raise ValueError("Algorithm for GPO is not given") | ||
if algo.__name__ != 'T_HOO' and algo.__name__ != 'HCT' and algo.__name__ != 'VHCT': | ||
raise NotImplementedError('GPO has not yet included implementations for this algorithm') | ||
|
||
self.rounds = rounds | ||
self.rhomax = rhomax | ||
|
@@ -52,7 +55,7 @@ def __init__( | |
|
||
# The big-N in the algorithm | ||
self.N = np.ceil( | ||
0.5 * self.Dmax * np.log(self.rounds / 2) / np.log(self.rounds / 2) | ||
0.5 * self.Dmax * np.log((self.rounds / 2) / np.log(self.rounds / 2)) | ||
) | ||
|
||
# phase number | ||
|
@@ -87,11 +90,17 @@ def pull(self, time): | |
else: | ||
if self.counter == 0: | ||
rho = self.rhomax ** (2 * self.N / (2 * self.phase + 1)) | ||
# TODO: for algorithms that do not need nu or rho | ||
self.curr_algo = self.algo( | ||
nu=self.numax, rho=rho, domain=self.domain, partition=self.partition | ||
) | ||
|
||
if self.algo.__name__ == 'T_HOO': | ||
self.curr_algo = self.algo( | ||
nu=self.numax, rho=rho, rounds=self.rounds, domain=self.domain, partition=self.partition | ||
) | ||
elif self.algo.__name__ == 'HCT' or self.algo.__name__ == 'VHCT': | ||
self.curr_algo = self.algo( | ||
nu=self.numax, rho=rho, domain=self.domain, partition=self.partition | ||
) | ||
else: | ||
# TODO: add more algorithms that do not need nu or rho | ||
raise NotImplementedError('GPO has not yet included implementations for this algorithm') | ||
if self.counter < self.half_phase_length: | ||
point = self.curr_algo.pull(time) | ||
self.goodx = point | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,4 +90,4 @@ def get_last_point(self): | |
------- | ||
""" | ||
self.algorithm.get_last_point() | ||
return self.algorithm.get_last_point() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
""" | ||
# Author: Wenjie Li <[email protected]> | ||
# License: MIT | ||
import pdb | ||
|
||
import numpy as np | ||
from PyXAB.algos.Algo import Algorithm | ||
|
@@ -41,6 +42,8 @@ def __init__( | |
raise ValueError("Partition of the parameter space is not given.") | ||
if algo is None: | ||
raise ValueError("Algorithm for POO is not given") | ||
if algo.__name__ != 'T_HOO' and algo.__name__ != 'HCT' and algo.__name__ != 'VHCT': | ||
raise NotImplementedError('POO has not yet included implementations for this algorithm') | ||
|
||
self.rounds = rounds | ||
self.rhomax = rhomax | ||
|
@@ -65,6 +68,7 @@ def __init__( | |
# The cross-validation list | ||
self.V_algo = [] | ||
self.V_reward = [] | ||
self.Times = [] | ||
|
||
def pull(self, time): | ||
""" | ||
|
@@ -82,33 +86,28 @@ def pull(self, time): | |
""" | ||
|
||
if self.N <= 0.5 * self.Dmax * np.log(self.n / np.log(self.n)): | ||
|
||
if self.counter == 0: | ||
rho = self.rhomax ** (2 * self.N / (2 * self.phase + 1)) | ||
self.curr_algo = self.algo( | ||
nu=self.numax, rho=rho, domain=self.domain, partition=self.partition | ||
) | ||
if self.algo.__name__ == 'T_HOO': | ||
self.curr_algo = self.algo( | ||
nu=self.numax, rho=rho, rounds=self.rounds, domain=self.domain, partition=self.partition | ||
) | ||
elif self.algo.__name__ == 'HCT' or self.algo.__name__ == 'VHCT': | ||
self.curr_algo = self.algo( | ||
nu=self.numax, rho=rho, domain=self.domain, partition=self.partition | ||
) | ||
else: | ||
# TODO: add more algorithms that do not need nu or rho | ||
raise NotImplementedError('POO has not yet included implementations for this algorithm') | ||
self.V_algo.append(self.curr_algo) | ||
self.V_reward.append(0) | ||
point = self.curr_algo.pull(time) | ||
|
||
if self.counter >= np.ceil(self.N / self.n): | ||
self.counter = 0 | ||
self.phase += 1 | ||
self.Times.append(0) | ||
point = self.V_algo[-1].pull(time) | ||
|
||
# Refresh, change n and N | ||
if self.phase >= self.N: | ||
self.n = 2 * self.n | ||
self.N = 2 * self.N | ||
self.phase = 0 | ||
self.counter = 0 | ||
self.algo_counter = 0 | ||
else: | ||
algo = self.V_algo[self.algo_counter] | ||
point = algo.pull(time) | ||
self.algo_counter += 1 | ||
if self.algo_counter == len(self.V_algo): | ||
self.algo_counter = 0 | ||
self.n = self.n + self.N | ||
|
||
return point | ||
|
||
|
@@ -129,17 +128,33 @@ def receive_reward(self, time, reward): | |
""" | ||
if self.N <= 0.5 * self.Dmax * np.log(self.n / np.log(self.n)): | ||
self.curr_algo.receive_reward(time, reward) | ||
self.V_reward[-1] = (self.V_reward[-1] * (self.counter) + reward) / ( | ||
self.V_algo[-1].receive_reward(time, reward) | ||
self.V_reward[-1] = (self.V_reward[-1] * self.counter + reward) / ( | ||
self.counter + 1 | ||
) | ||
self.Times[-1] += 1 | ||
self.counter += 1 | ||
if self.counter >= np.ceil(self.n / self.N): | ||
self.counter = 0 | ||
self.phase += 1 | ||
|
||
# Refresh, change n and N | ||
if self.phase >= self.N: | ||
self.n = 2 * self.n | ||
self.N = 2 * self.N | ||
self.phase = 0 | ||
self.counter = 0 | ||
self.algo_counter = 0 | ||
else: | ||
self.V_algo[self.algo_counter].receive_reward(time, reward) | ||
self.V_reward[self.algo_counter] = ( | ||
self.V_reward[self.algo_counter] * np.ceil(self.N / self.n) + reward | ||
) / (np.ceil(self.N / self.n) + 1) | ||
self.V_reward[self.algo_counter] * np.ceil(self.n / self.N) + reward | ||
) / (np.ceil(self.n / self.N) + 1) | ||
self.Times[self.algo_counter] += 1 | ||
self.algo_counter += 1 | ||
if self.algo_counter == len(self.V_algo): | ||
self.algo_counter = 0 | ||
self.n = self.n + self.N | ||
|
||
def get_last_point(self): | ||
""" | ||
|
@@ -150,9 +165,6 @@ def get_last_point(self): | |
""" | ||
V_reward = np.array(self.V_reward) | ||
|
||
max_param = np.argmax(V_reward) | ||
|
||
point = self.V_algo[max_param].pull(time=self.rounds) | ||
|
||
point = self.V_algo[max_param].pull(time=0) | ||
return point |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,4 +92,4 @@ def get_last_point(self): | |
------- | ||
""" | ||
self.algorithm.get_last_point() | ||
return self.algorithm.get_last_point() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.