Commit 950b896

setup.py to run examples accessing the package globally
1 parent f44f017 commit 950b896

10 files changed (+82 -47 lines)

examples/blackjack.py

Lines changed: 9 additions & 2 deletions
@@ -1,5 +1,7 @@
 import random
 
+from rl.solvers import alpha_mc
+
 VALUES = ['A','2','3','4','5','6','7','8','9','10','J','Q','K']
 SUITS = ['♠','♥','♦','♣']
 CARDS = [(value,suit) for value in VALUES for suit in SUITS]
@@ -59,7 +61,6 @@ def black_jack_transition(state, action):
         new_state = (player_sum, usable_ace, dealer_showing)
         return (new_state, 0.), False
 
-
     dealer_cards = [dealer_showing]
     dealer_sum = count(dealer_cards)
     if action == 'stand':
@@ -79,4 +80,10 @@ def black_jack_transition(state, action):
     elif dealer_sum < player_sum:
         return (state, 1.), True
     elif dealer_sum == player_sum:
-        return (state, 0.), True
+        return (state, 0.), True
+
+
+vqpi, samples = alpha_mc(states, actions, black_jack_transition, gamma=0.9,
+                         use_N=True, n_episodes=1E4, first_visit=False)
+
+
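
The new call at the bottom of the example relies on states and actions being defined elsewhere in the script; they are not part of this diff. A hypothetical sketch of what they could look like, inferred only from the transition function above (the 'hit' action name and the exact state ranges are assumptions, not taken from this commit):

    # Illustrative only: blackjack states as (player_sum, usable_ace, dealer_showing)
    # tuples, matching the tuples black_jack_transition builds, plus the two actions
    # the function branches on ('stand' appears in the diff, 'hit' is assumed).
    states = [(player_sum, usable_ace, dealer_showing)
              for player_sum in range(4, 22)
              for usable_ace in (True, False)
              for dealer_showing in range(1, 11)]
    actions = ['hit', 'stand']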

examples/gridworld.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 import numpy as np
 
-from src.mdp import MDP, TabularReward
+from rl.mdp import MDP, TabularReward
 
 GRID_SIZE = 5 # 5x5 gridworld
 

rl/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from .model_free import ModelFree, ModelFreePolicy, EpsilonSoftPolicy
+from .solvers import (
+    tdn,
+    alpha_mc,
+    off_policy_mc
+)
+
+
+__all__ = [
+    'utils',
+    'ModelFree',
+    'ModelFreePolicy',
+    'EpsilonSoftPolicy',
+    'tdn',
+    'alpha_mc',
+    'off_policy_mc',
+]
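
With this __init__.py in place, everything re-exported in __all__ can be imported straight from the package root once the package is installed. A quick sanity check using only the names listed above:

    # Import the public API from the package root rather than from submodules.
    from rl import alpha_mc, off_policy_mc, tdn, ModelFree, ModelFreePolicy, EpsilonSoftPolicy

    print(alpha_mc.__module__)   # 'rl.solvers'
    print(ModelFree.__module__)  # 'rl.model_free'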

src/armed_bandits.py renamed to rl/armed_bandits.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 import numpy as np
 import numpy.random as rnd
 
-from utils import Policy, RewardGenerator
+from rl.utils import Policy, RewardGenerator
 
 
 GAUSSIAN = [RewardGenerator('normal', rnd.random(), rnd.random()) for _ in range(10)]

src/mdp.py renamed to rl/mdp.py

Lines changed: 8 additions & 8 deletions
@@ -14,7 +14,7 @@
 
 import numpy as np
 
-from utils import (
+from rl.utils import (
     State,
     Action,
     Policy,
@@ -143,7 +143,7 @@ def π(self, state: int):
 
     def __call__(self, state: int) -> np.ndarray:
        '''
-        collapses the policy to a single action, i.e. a sample from the
+        Collapses the policy to a single action, i.e. a sample from the
        random variable that represents the policy.
        '''
        return np.random.choice(self.pi_sa[state], p=self.pi_sa[state])
@@ -220,9 +220,9 @@ def vq_pi(
        method: str = 'iter_n'
        ) -> np.ndarray:
        '''
-        Individual state value functions and action-value functions
-        vpi and qpi cannot be calculated for bigger problems. That
-        constraint will give rise to parametrizations via DL.
+        Individual state value functions and action-value functions
+        vpi and qpi cannot be calculated for bigger problems. That
+        constraint will give rise to parametrizations via DL.
        '''
        policy = policy if policy else self.policy
        solver = self.VQ_PI_SOLVERS.get(method)
@@ -237,9 +237,9 @@ def optimize_policy(
        policy: MarkovPolicy = None
        ) -> MarkovPolicy:
        '''
-        Optimal policy is the policy that maximizes the expected
-        discounted return. It is the policy that maximizes the
-        value function for each possible state.
+        Optimal policy is the policy that maximizes the expected
+        discounted return. It is the policy that maximizes the
+        value function for each possible state.
        '''
        policy = policy if policy else self.policy
        solver = self.OPTIMAL_POLICY_SOLVERS.get(method)
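
For reference, the reindented docstrings appeal to the standard definitions (standard RL notation, not notation introduced by this commit):

    v_\pi(s) = \mathbb{E}_\pi\left[\sum_{k=0}^{\infty} \gamma^{k} R_{t+k+1} \,\middle|\, S_t = s\right],
    \qquad \pi^{*} \ \text{satisfies} \ v_{\pi^{*}}(s) \ge v_{\pi}(s) \ \text{for every state } s \text{ and every policy } \pi.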

src/model_free.py renamed to rl/model_free.py

Lines changed: 9 additions & 7 deletions
@@ -14,13 +14,13 @@
 
 import numpy as np
 
-from utils import Policy, State, Action, MAX_ITER, MAX_STEPS
-from solvers import (
-    first_visit_monte_carlo,
-    every_visit_monte_carlo,
-    off_policy_first_visit,
-    off_policy_every_visit,
-    tdn
+from rl.utils import (
+    Policy,
+    State,
+    Action,
+    StateAction,
+    MAX_ITER,
+    MAX_STEPS
 )
 
 EpisodeStep = NewType(
@@ -83,6 +83,8 @@ def __init__(
         self.policy = policy
         self.states = State(states)
         self.actions = Action(actions)
+        self.stateaction = StateAction(
+            [(s,a) for s,a in zip(states, actions)])
         self.transition = transition
         self.gamma = gamma
         self.policy = policy if policy else ModelFreePolicy(
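
A note on the new index: zip(states, actions) pairs the i-th state with the i-th action, so stateaction holds element-wise pairs rather than the full state-action cross product. A small illustration with toy values (not from the repo):

    states = ['s0', 's1', 's2']
    actions = ['a0', 'a1']

    pairs_zip = [(s, a) for s, a in zip(states, actions)]      # [('s0', 'a0'), ('s1', 'a1')]
    pairs_product = [(s, a) for s in states for a in actions]  # all 6 combinations
    print(pairs_zip)
    print(pairs_product)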

src/solvers.py renamed to rl/solvers.py

Lines changed: 19 additions & 19 deletions
@@ -10,12 +10,12 @@
 import numpy as np
 from numpy.linalg import norm as lnorm
 
-from model_free import (
+from rl.model_free import (
     ModelFree,
     ModelFreePolicy,
     EpsilonSoftPolicy
 )
-from utils import (
+from rl.utils import (
     Policy,
     _typecheck_all,
     _get_sample_step,
@@ -31,8 +31,9 @@
 )
 
 
-def get_sample(v, q, π, n_episode, optimize):
-    _idx, _v, _q = n_episode, Vpi(v.copy()), Qpi(q.copy())
+def get_sample(MF, v, q, π, n_episode, optimize):
+    _idx = n_episode
+    _v, _q = Vpi(v.copy(), MF.states), Qpi(q.copy(), MF.stateaction)
     _pi = None
     if optimize:
         _pi = π.pi.copy()
@@ -118,10 +119,9 @@ def value_iteration(MDP, policy: Policy = None, tol: float = TOL,
         policy.update_policy(qᵢ)
 
 
-
 def alpha_mc(states: Sequence[Any], actions: Sequence[Any], transition: Transition,
     gamma: float=0.9, alpha: float=0.05, use_N :bool=False, first_visit: bool=True,
-    exploring_starts: bool=False, n_episodes: int=MAX_ITER, max_steps: int=MAX_STEPS,
+    exploring_starts: bool=True, n_episodes: int=MAX_ITER, max_steps: int=MAX_STEPS,
     samples: int=1000, optimize: bool=False, policy: ModelFreePolicy=None,
     eps: float=None) -> Tuple[VQPi, Samples]:
     '''α-MC state and action-value function estimation, policy optimization
@@ -176,9 +176,9 @@ def alpha_mc(states: Sequence[Any], actions: Sequence[Any], transition: Transiti
     '''
     if not policy and eps:
         _check_ranges(values=[eps], ranges=[(0,1)])
-        policy = EpsilonSoftPolicy(states, actions, eps=eps)
+        policy = EpsilonSoftPolicy(actions, states, eps=eps)
     elif not policy:
-        policy = ModelFreePolicy(states, actions)
+        policy = ModelFreePolicy(actions, states)
 
     _typecheck_all(tabular_idxs=[states, actions],transition=transition,
         constants=[gamma, alpha, n_episodes, max_steps, samples],
@@ -264,16 +264,16 @@ def _visit_monte_carlo(MF, first_visit, exploring_starts, use_N, alpha,
         n_episode += 1
 
         if sample_step and n_episode % sample_step == 0:
-            samples.append(get_sample(MF, v, q, π, n_episode))
-
-    return v, q
+            samples.append(get_sample(MF, v, q, π, n_episode, optimize))
+
+    return v, q, samples
 
 
 def off_policy_mc(states: Sequence[Any], actions: Sequence[Any], transition: Transition,
     gamma: float=0.9, first_visit: bool=True, ordinary: bool=True,
     n_episodes: int=MAX_ITER, max_steps: int=MAX_STEPS, samples: int=1000,
     optimize: bool=False, policy: ModelFreePolicy=None, eps: float=None,
-    b :ModelFreePolicy=None) -> Tuple[VQPi, Samples]:
+    b: ModelFreePolicy=None) -> Tuple[VQPi, Samples]:
     '''Off-policy Monte Carlo state and action value function estimation, policy
 
     Off policy Monte Carlo method for estimating state and action-value functtions
@@ -326,11 +326,11 @@ def off_policy_mc(states: Sequence[Any], actions: Sequence[Any], transition: Tra
     if not policy and eps:
         _typecheck_all(constants=[eps])
         _check_ranges(values=[eps], ranges=[(0,1)])
-        policy = EpsilonSoftPolicy(states, actions, eps=eps)
+        policy = EpsilonSoftPolicy(actions, states, eps=eps)
     elif not policy:
-        policy = ModelFreePolicy(states, actions)
+        policy = ModelFreePolicy(actions, states)
     elif not b:
-        b = ModelFreePolicy(states, actions)
+        b = ModelFreePolicy(actions, states)
 
     _typecheck_all(tabular_idxs=[states, actions],transition=transition,
         constants=[gamma, n_episodes, max_steps, samples],
@@ -409,7 +409,7 @@ def _off_policy_monte_carlo(MF, off_policy, max_episodes, max_steps, first_visit
         if sample_step and n_episode % sample_step == 0:
             samples.append(get_sample(MF, v, q, π, n_episode))
 
-    return v, q
+    return v, q, samples
 
 
 
@@ -485,9 +485,9 @@ def tdn(states: Sequence[Any], actions: Sequence[Any], transition: Transition,
     if not policy and eps:
         _typecheck_all(constants=[eps])
         _check_ranges(values=[eps], ranges=[(0,1)])
-        policy = EpsilonSoftPolicy(states, actions, eps=eps)
+        policy = EpsilonSoftPolicy(actions, states, eps=eps)
     elif not policy:
-        policy = ModelFreePolicy(states, actions)
+        policy = ModelFreePolicy(actions, states)
 
     _typecheck_all(tabular_idxs=[states,actions], tansition=transition,
         constants=[gamma, n, alpha, n_episodes, samples, max_steps],
@@ -541,7 +541,7 @@ def _tdn(MF, n, alpha, n_episodes, max_steps, optimize, sample_step):
         n_episode += 1
 
         if sample_step and n_episode % sample_step == 0:
-            samples.append(get_sample(MF, v, q, π, n_episode))
+            samples.append(get_sample(MF, v, q, π, n_episode, optimize))
 
     return v, q, samples
 
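
Tying the solver-side changes together (samples are now returned alongside v and q, get_sample takes the ModelFree instance and the optimize flag, and the policy constructors are called as (actions, states)), here is a rough usage sketch. The toy states, actions, and transition below are invented purely for illustration and assume the (next_state, reward), done contract that black_jack_transition follows; whether this toy chain runs cleanly also depends on parts of ModelFree not shown in this diff:

    from rl.solvers import alpha_mc

    # Toy two-state problem, invented purely for illustration.
    states = ['left', 'right']
    actions = ['stay', 'move']

    def toy_transition(state, action):
        # Same return contract as black_jack_transition: (next_state, reward), done
        if state == 'left' and action == 'move':
            return ('right', 1.0), True
        return (state, 0.0), False

    # Every-visit MC with 1/N step sizes, mirroring the updated blackjack example.
    vqpi, samples = alpha_mc(states, actions, toy_transition, gamma=0.9,
                             use_N=True, n_episodes=1E3, first_visit=False)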

src/utils.py renamed to rl/utils.py

Lines changed: 16 additions & 9 deletions
@@ -16,6 +16,7 @@
 TOL = 1E-6
 MEAN_ITERS = int(1E4)
 
+
 class Policy(ABC):
     def __init__(self):
         pass
@@ -61,21 +62,27 @@ class Action(_TabularIndexer):
     pass
 
 
+class StateAction(_TabularIndexer):
+    pass
+
+
 class _TabularValues:
     def __init__(self, values: np.ndarray, idx: _TabularIndexer):
         self.v = values
         self.idx = idx
         self.idx_val = {k:v for k,v in zip(idx.index.keys(), values)}
 
     def values(self):
-        return self.idx_val
+        return self.v
 
 
 class Vpi(_TabularValues):
-    pass
+    def __str__(self):
+        return f'Vpi({self.v[:10]}...)'
 
 class Qpi(_TabularValues):
-    pass
+    def __str__(self):
+        return f'Vpi({self.v[:10]}...)'
 
 
 
@@ -114,13 +121,13 @@ def _typecheck_transition(transition):
     if not isinstance(transition, Callable):
         raise TypeError(
             f"transition must be a Callable, not {type(transition)}")
-
-    if len(transition.__code__.co_varnames) != 2:
-        raise TypeError(
-            "transition must have 2 arguments, not ",
-            len(transition.__code__.co_varnames))
-
 
+    #check that transition function has just two positional arguments
+    if transition.__code__.co_argcount != 2:
+        raise TypeError(
+            f"transition must have two positional arguments,"
+            f" not {transition.__code__.co_argcount}")
+
 def _typecheck_constants(*args):
     for arg in args:
         if not isinstance(arg, (float, int)):
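
The switch from co_varnames to co_argcount matters because co_varnames counts local variables as well as parameters, so any transition function that defined a local would have failed the old check. A short demonstration with an illustrative function:

    def transition(state, action):
        next_state = state          # a local variable
        return (next_state, 0.0), True

    print(transition.__code__.co_argcount)        # 2 -> positional parameters only
    print(len(transition.__code__.co_varnames))   # 3 -> parameters plus the local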

setup.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from setuptools import setup, find_packages
+setup(name = 'rl', packages = find_packages())
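
This two-line setup.py is what lets the examples resolve the rl package globally: after installing it (for example with an editable install, pip install -e ., from the repository root), the imports above work from any directory. A quick check once installed:

    import rl
    print(rl.__all__)   # the names exported by rl/__init__.py above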

src/__init__.py

Whitespace-only changes.
