Skip to content

Commit

Permalink
Merge pull request #29 from Giggfitnesse/main
Browse files Browse the repository at this point in the history
Fixed the bugs #27
  • Loading branch information
WilliamLwj authored Mar 13, 2023
2 parents 24c2c25 + 90cc82d commit 5790c09
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 71 deletions.
21 changes: 15 additions & 6 deletions PyXAB/algos/GPO.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
# Author: Wenjie Li <[email protected]>
# License: MIT
import pdb

import numpy as np
from PyXAB.algos.Algo import Algorithm
Expand Down Expand Up @@ -41,6 +42,8 @@ def __init__(
raise ValueError("Partition of the parameter space is not given.")
if algo is None:
raise ValueError("Algorithm for GPO is not given")
if algo.__name__ != 'T_HOO' and algo.__name__ != 'HCT' and algo.__name__ != 'VHCT':
raise NotImplementedError('GPO has not yet included implementations for this algorithm')

self.rounds = rounds
self.rhomax = rhomax
Expand All @@ -52,7 +55,7 @@ def __init__(

# The big-N in the algorithm
self.N = np.ceil(
0.5 * self.Dmax * np.log(self.rounds / 2) / np.log(self.rounds / 2)
0.5 * self.Dmax * np.log((self.rounds / 2) / np.log(self.rounds / 2))
)

# phase number
Expand Down Expand Up @@ -87,11 +90,17 @@ def pull(self, time):
else:
if self.counter == 0:
rho = self.rhomax ** (2 * self.N / (2 * self.phase + 1))
# TODO: for algorithms that do not need nu or rho
self.curr_algo = self.algo(
nu=self.numax, rho=rho, domain=self.domain, partition=self.partition
)

if self.algo.__name__ == 'T_HOO':
self.curr_algo = self.algo(
nu=self.numax, rho=rho, rounds=self.rounds, domain=self.domain, partition=self.partition
)
elif self.algo.__name__ == 'HCT' or self.algo.__name__ == 'VHCT':
self.curr_algo = self.algo(
nu=self.numax, rho=rho, domain=self.domain, partition=self.partition
)
else:
# TODO: add more algorithms that do not need nu or rho
raise NotImplementedError('GPO has not yet included implementations for this algorithm')
if self.counter < self.half_phase_length:
point = self.curr_algo.pull(time)
self.goodx = point
Expand Down
2 changes: 1 addition & 1 deletion PyXAB/algos/PCT.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def get_last_point(self):
-------
"""
self.algorithm.get_last_point()
return self.algorithm.get_last_point()
66 changes: 39 additions & 27 deletions PyXAB/algos/POO.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
# Author: Wenjie Li <[email protected]>
# License: MIT
import pdb

import numpy as np
from PyXAB.algos.Algo import Algorithm
Expand Down Expand Up @@ -41,6 +42,8 @@ def __init__(
raise ValueError("Partition of the parameter space is not given.")
if algo is None:
raise ValueError("Algorithm for POO is not given")
if algo.__name__ != 'T_HOO' and algo.__name__ != 'HCT' and algo.__name__ != 'VHCT':
raise NotImplementedError('POO has not yet included implementations for this algorithm')

self.rounds = rounds
self.rhomax = rhomax
Expand All @@ -65,6 +68,7 @@ def __init__(
# The cross-validation list
self.V_algo = []
self.V_reward = []
self.Times = []

def pull(self, time):
"""
Expand All @@ -82,33 +86,28 @@ def pull(self, time):
"""

if self.N <= 0.5 * self.Dmax * np.log(self.n / np.log(self.n)):

if self.counter == 0:
rho = self.rhomax ** (2 * self.N / (2 * self.phase + 1))
self.curr_algo = self.algo(
nu=self.numax, rho=rho, domain=self.domain, partition=self.partition
)
if self.algo.__name__ == 'T_HOO':
self.curr_algo = self.algo(
nu=self.numax, rho=rho, rounds=self.rounds, domain=self.domain, partition=self.partition
)
elif self.algo.__name__ == 'HCT' or self.algo.__name__ == 'VHCT':
self.curr_algo = self.algo(
nu=self.numax, rho=rho, domain=self.domain, partition=self.partition
)
else:
# TODO: add more algorithms that do not need nu or rho
raise NotImplementedError('POO has not yet included implementations for this algorithm')
self.V_algo.append(self.curr_algo)
self.V_reward.append(0)
point = self.curr_algo.pull(time)

if self.counter >= np.ceil(self.N / self.n):
self.counter = 0
self.phase += 1
self.Times.append(0)
point = self.V_algo[-1].pull(time)

# Refresh, change n and N
if self.phase >= self.N:
self.n = 2 * self.n
self.N = 2 * self.N
self.phase = 0
self.counter = 0
self.algo_counter = 0
else:
algo = self.V_algo[self.algo_counter]
point = algo.pull(time)
self.algo_counter += 1
if self.algo_counter == len(self.V_algo):
self.algo_counter = 0
self.n = self.n + self.N

return point

Expand All @@ -129,17 +128,33 @@ def receive_reward(self, time, reward):
"""
if self.N <= 0.5 * self.Dmax * np.log(self.n / np.log(self.n)):
self.curr_algo.receive_reward(time, reward)
self.V_reward[-1] = (self.V_reward[-1] * (self.counter) + reward) / (
self.V_algo[-1].receive_reward(time, reward)
self.V_reward[-1] = (self.V_reward[-1] * self.counter + reward) / (
self.counter + 1
)
self.Times[-1] += 1
self.counter += 1
if self.counter >= np.ceil(self.n / self.N):
self.counter = 0
self.phase += 1

# Refresh, change n and N
if self.phase >= self.N:
self.n = 2 * self.n
self.N = 2 * self.N
self.phase = 0
self.counter = 0
self.algo_counter = 0
else:
self.V_algo[self.algo_counter].receive_reward(time, reward)
self.V_reward[self.algo_counter] = (
self.V_reward[self.algo_counter] * np.ceil(self.N / self.n) + reward
) / (np.ceil(self.N / self.n) + 1)
self.V_reward[self.algo_counter] * np.ceil(self.n / self.N) + reward
) / (np.ceil(self.n / self.N) + 1)
self.Times[self.algo_counter] += 1
self.algo_counter += 1
if self.algo_counter == len(self.V_algo):
self.algo_counter = 0
self.n = self.n + self.N

def get_last_point(self):
"""
Expand All @@ -150,9 +165,6 @@ def get_last_point(self):
"""
V_reward = np.array(self.V_reward)

max_param = np.argmax(V_reward)

point = self.V_algo[max_param].pull(time=self.rounds)

point = self.V_algo[max_param].pull(time=0)
return point
2 changes: 1 addition & 1 deletion PyXAB/algos/VPCT.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,4 @@ def get_last_point(self):
-------
"""
self.algorithm.get_last_point()
return self.algorithm.get_last_point()
52 changes: 47 additions & 5 deletions PyXAB/tests/test_algos/test_GPO.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,81 @@
from PyXAB.synthetic_obj import *

from PyXAB.algos.HCT import HCT
from PyXAB.algos.HOO import T_HOO
from PyXAB.algos.VHCT import VHCT
from PyXAB.algos.POO import POO
from PyXAB.algos.StoSOO import StoSOO
from PyXAB.algos.GPO import GPO
from PyXAB.partition.BinaryPartition import BinaryPartition
from PyXAB.utils.plot import compare_regret
import numpy as np
import pytest


def test_GPO_1():
def test_GPO_value_error_1():
# no domain
algo = HCT
partition = BinaryPartition
with pytest.raises(ValueError):
GPO(partition=partition, algo=algo)


def test_GPO_2():
def test_GPO_value_error_2():
# no partition
algo = HCT
domain = [[-5, 5], [-5, 5]]
with pytest.raises(ValueError):
GPO(domain=domain, algo=algo)


def test_GPO_3():
def test_GPO_value_error_3():
# no algorithm
domain = [[-5, 5], [-5, 5]]
partition = BinaryPartition
with pytest.raises(ValueError):
GPO(domain=domain, partition=partition)


def test_GPO_4():

def test_GPO_not_implemented_error():
# wrong algorithm
T = 100
Target = Garland.Garland()
domain = [[0, 1]]
partition = BinaryPartition
with pytest.raises(NotImplementedError):
GPO(rounds=T, domain=domain, partition=partition, algo=StoSOO)


def test_GPO_HOO_Garland():

T = 1000
Target = Garland.Garland()
domain = [[0, 1]]
partition = BinaryPartition

algo = GPO(rounds=T, domain=domain, partition=partition, algo=T_HOO)
GPO_regret_list = []
GPO_regret = 0

for t in range(1, T + 1):
point = algo.pull(t)
reward = Target.f(point) + np.random.uniform(-0.1, 0.1)
algo.receive_reward(t, reward)
inst_regret = Target.fmax - Target.f(point)
GPO_regret += inst_regret
GPO_regret_list.append(GPO_regret)
p = algo.get_last_point()
inst_regret = Target.fmax - Target.f(p)

# regret_dic = {"GPO": np.array(GPO_regret_list),}
# compare_regret(regret_dic)

def test_GPO_HCT_Garland():

T = 1000
Target = Garland.Garland()
domain = [[0, 1]]
partition = BinaryPartition

algo = GPO(rounds=T, domain=domain, partition=partition, algo=HCT)
GPO_regret_list = []
Expand All @@ -51,3 +92,4 @@ def test_GPO_4():
# plot the regret
# regret_dic = {"GPO": np.array(GPO_regret_list)}
# compare_regret(regret_dic)

31 changes: 27 additions & 4 deletions PyXAB/tests/test_algos/test_PCT.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,43 @@
import pytest


def test_PCT_1():
def test_PCT_value_error_1():
partition = BinaryPartition
with pytest.raises(ValueError):
algo = PCT(partition=partition)


def test_PCT_2():
def test_PCT_value_error_2():
domain = [[-5, 5], [-5, 5]]
with pytest.raises(ValueError):
algo = PCT(domain=domain)

def test_PCT_Garland():
T = 1000
target = Garland.Garland()
domain = [[0, 1]]
partition = BinaryPartition
algo = PCT(rounds=T, domain=domain, partition=partition)

cumulative_regret = 0
cumulative_regret_list = []

## uniform noise

for t in range(1, T + 1):
point = algo.pull(t)
reward = target.f(point) + np.random.uniform(-0.1, 0.1)
algo.receive_reward(t, reward)
inst_regret = target.fmax - target.f(point)
cumulative_regret += inst_regret
cumulative_regret_list.append(cumulative_regret)

# plot_regret(np.array(cumulative_regret_list))

print('PCT: ', algo.get_last_point())

def test_PCT_3():
T = 100
def test_PCT_Himmelblau():
T = 1000
target = Himmelblau.Himmelblau()
domain = [[-5, 5], [-5, 5]]
partition = BinaryPartition
Expand Down
Loading

0 comments on commit 5790c09

Please sign in to comment.