
Commit 4bc62f1

Committed Oct 25, 2016
💄
1 parent fc3f834 commit 4bc62f1

1 file changed, +20 −10 lines changed
 

my_atari_ram_policy.py

+20 −10
@@ -153,13 +153,17 @@ def get_parameters_flat(self):
 
     def set_parameters_flat(self, th):
         self.sess.run(tf.initialize_all_variables())
+        # Get shape of parameters from the class.
         n_in = self.n_in
         n_hid = self.n_hid
         n_actions = self.n_actions
+        # Grab and reshape weight matrices.
         W_01 = th[:n_hid*n_in].reshape(n_in,n_hid)
         W_12 = th[n_hid*n_in:n_hid*n_in+n_hid*n_actions].reshape(n_hid,n_actions)
+        # Pull the biases off the end of th.
         b_01 = th[-n_hid-n_actions:-n_actions]
         b_12 = th[-n_actions:]
+        # Assign the variables the values passed through th.
         self.sess.run(tf.assign(self.W_01, W_01))
         self.sess.run(tf.assign(self.W_12, W_12))
         self.sess.run(tf.assign(self.b_01, b_01))
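For reference, the slicing in set_parameters_flat assumes the flat vector th is laid out as [W_01.ravel(), W_12.ravel(), b_01, b_12]. A minimal NumPy sketch of that packing and unpacking, with illustrative placeholder sizes (none of the numbers below come from the commit):

# Sketch only: demonstrates the flat-vector layout assumed by set_parameters_flat.
import numpy as np

n_in, n_hid, n_actions = 128, 64, 6  # hypothetical sizes, not from the repo

# Pack: weight matrices first, then biases, in the order the slices expect.
W_01 = np.random.randn(n_in, n_hid)
W_12 = np.random.randn(n_hid, n_actions)
b_01 = np.zeros(n_hid)
b_12 = np.zeros(n_actions)
th = np.concatenate([W_01.ravel(), W_12.ravel(), b_01, b_12])

# Unpack with the same slices the method uses.
W_01_back = th[:n_hid*n_in].reshape(n_in, n_hid)
W_12_back = th[n_hid*n_in:n_hid*n_in + n_hid*n_actions].reshape(n_hid, n_actions)
b_01_back = th[-n_hid-n_actions:-n_actions]
b_12_back = th[-n_actions:]

assert np.allclose(W_01, W_01_back) and np.allclose(W_12, W_12_back)
assert np.allclose(b_01, b_01_back) and np.allclose(b_12, b_12_back)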
@@ -221,28 +225,34 @@ def test_AtariRAMPolicy():
     th_new = policy.get_parameters_flat()
     assert not np.any(th_new - theta)
 
-    def fpen(th): #, probs_na, obs, a_n, q_n):
+    # We define a new function in test_AtariRAMPolicy() so we can hand in the
+    # data needed and leave the parameters of the model as the only input.
+    def fpen(th):
         thprev = policy.get_parameters_flat()
         policy.set_parameters_flat(th)
-        surr, kl = policy.compute_surr_kl(*poar_train)#probs_na, obs, a_n, q_n)
+        surr, kl = policy.compute_surr_kl(*poar_train)
         out = penalty_coeff * kl - surr
         policy.set_parameters_flat(thprev)
         return out
 
-    print(fpen(theta))#, probs_na, obs, a_n, q_n))
-    def fgradpen(th): #, probs_na, obs, a_n, q_n):
+    # Quick check it works.
+    fpen(theta)
+
+    # Do the same thing for the gradient of fpen.
+    def fgradpen(th):
         thprev = policy.get_parameters_flat()
         policy.set_parameters_flat(th)
-        out = - policy.compute_grad_lagrangian(penalty_coeff, *poar_train) #probs_na, obs, a_n, q_n)
+        out = - policy.compute_grad_lagrangian(penalty_coeff, *poar_train)
         policy.set_parameters_flat(thprev)
         return out
-    print(fgradpen(theta)) #, probs_na, obs, a_n, q_n).shape)
 
-    # opt.check_grad(fpen, fgradpen, theta)
-    # eps = np.sqrt(np.finfo(float).eps)
-    # opt.approx_fprime(theta, fpen, eps)
+    # Testing fgradpen()
+    fgradpen(theta)
+
+    # Test out our functions in the context of lbfgs-b minimization with scipy.
     res = opt.fmin_l_bfgs_b(fpen, theta, fprime=fgradpen, maxiter=20)
-    # res = opt.fmin_cg(fpen, theta, maxiter=20, fprime=fgradpen)
+
+    print(res)
 
 if __name__ == "__main__":
     test_AtariRAMPolicy()
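The fpen/fgradpen pair added here follows the standard scipy.optimize calling convention: fmin_l_bfgs_b evaluates the scalar objective through the first callable and its gradient through fprime, both as functions of the flat parameter vector, and returns an (x_opt, f_opt, info) tuple. A self-contained toy version of that wiring on a quadratic objective (nothing below comes from the repository):

# Sketch only: the objective/gradient pattern used with opt.fmin_l_bfgs_b above.
import numpy as np
import scipy.optimize as opt

target = np.arange(5.0)

def f(th):
    # Scalar objective of the flat parameter vector (stands in for fpen).
    return 0.5 * np.sum((th - target) ** 2)

def fgrad(th):
    # Gradient with respect to th (stands in for fgradpen).
    return th - target

theta0 = np.zeros(5)
res = opt.fmin_l_bfgs_b(f, theta0, fprime=fgrad, maxiter=20)
print(res)  # (x_opt, f_opt, info_dict); x_opt should be close to target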
