
Commit dae45e8

Fixed up randomness for environment and models.
In general we want the largest possible diversity between processes, to prevent learning from degenerating. Randomness is present in the environment through the random seed of the Atari emulator and the number of no-ops at the beginning of the game, and in the model through the sampling of discrete actions. This patch ensures there is a training-level random seed, which is saved to the args.txt file (even if it had to be generated). That seed is in turn used to create process-level random seeds, which seed both the environment and the model. The environment's random seed is also passed to the emulator.
1 parent: 7d6045e
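The seeding hierarchy the patch introduces can be summarized in a short sketch (a minimal, standalone illustration with made-up values; the real variable names appear in the diffs below):

    import numpy as np

    train_seed = 42                                          # training-level seed, saved to args.txt
    train_randstate = np.random.RandomState(train_seed)      # top of the hierarchy
    process_seeds = train_randstate.randint(0, 2 ** 16, 8)   # one seed per worker process

    # Each process then seeds both its emulator (ALE's 'random_seed' option)
    # and its model (the policy's sampling RandomState) from its own entry
    # in process_seeds, so the whole run is reproducible from one seed while
    # the processes stay decorrelated.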

File tree: 6 files changed, +64 −20 lines


a3c_ale.py

+29 −11

@@ -28,10 +28,10 @@
 
 class A3CFF(chainer.ChainList, a3c.A3CModel):
 
-    def __init__(self, n_actions):
+    def __init__(self, n_actions, seed):
         self.head = dqn_head.NIPSDQNHead()
         self.pi = policy.FCSoftmaxPolicy(
-            self.head.n_output_channels, n_actions)
+            self.head.n_output_channels, n_actions, seed)
         self.v = v_function.FCVFunction(self.head.n_output_channels)
         if sys.version_info < (3,0):
             super(A3CFF, self).__init__(self.head, self.pi, self.v)

@@ -46,10 +46,10 @@ def pi_and_v(self, state, keep_same_state=False):
 
 class A3CLSTM(chainer.ChainList, a3c.A3CModel):
 
-    def __init__(self, n_actions):
+    def __init__(self, n_actions, seed):
         self.head = dqn_head.NIPSDQNHead()
         self.pi = policy.FCSoftmaxPolicy(
-            self.head.n_output_channels, n_actions)
+            self.head.n_output_channels, n_actions, seed)
         self.v = v_function.FCVFunction(self.head.n_output_channels)
         self.lstm = L.LSTM(self.head.n_output_channels,
                            self.head.n_output_channels)

@@ -214,20 +214,32 @@ def main():
     parser.set_defaults(use_lstm=False)
     args = parser.parse_args()
 
-    if args.seed is not None:
-        random_seed.set_random_seed(args.seed)
+    if args.seed is None:
+        args.seed = np.random.randint(0, 2 ** 16)
+
+    # I suggest using train_randstate instead of np.random because it
+    # probably behaves better for async use.
+    train_randstate = np.random.RandomState(args.seed)
+
+    # Choose the random seeds before async execution, in order to ensure
+    # that we obtain different seeds for each process. This can be checked
+    # by making sure each emulator has a different seed, which works because
+    # each emulator is set to have the same random seed as its process (the
+    # ALE Python class); see ale.py for details.
+    process_seeds = train_randstate.randint(0, 2 ** 16, args.processes)
 
     args.outdir = prepare_output_dir(args, args.outdir)
 
     print('Output files are saved in {}'.format(args.outdir))
 
     n_actions = ale.ALE(args.rom).number_of_actions
 
-    def model_opt():
+    def model_opt(seed=args.seed):
         if args.use_lstm:
-            model = A3CLSTM(n_actions)
+            model = A3CLSTM(n_actions, seed=seed)
         else:
-            model = A3CFF(n_actions)
+            model = A3CFF(n_actions, seed=seed)
         opt = rmsprop_async.RMSpropAsync(lr=7e-4, eps=1e-1, alpha=0.99)
         opt.setup(model)
         opt.add_hook(chainer.optimizer.GradientClipping(40))

@@ -249,9 +261,15 @@ def model_opt():
         column_names = ('steps', 'elapsed', 'mean', 'median', 'stdev')
         print('\t'.join(column_names), file=f)
 
+    # Convert np.int64 to plain Python int for JSON serialization
+    process_seeds = [int(x) for x in process_seeds]
+
     def run_func(process_idx):
-        env = ale.ALE(args.rom, use_sdl=args.use_sdl)
-        model, opt = model_opt()
+        env = ale.ALE(args.rom,
+                      seed=process_seeds[process_idx],
+                      use_sdl=args.use_sdl)
+
+        model, opt = model_opt(seed=process_seeds[process_idx])
         async.set_shared_params(model, shared_params)
         async.set_shared_states(opt, shared_states)
ale.py

+16 −3

@@ -27,9 +27,22 @@ def __init__(self, rom_filename, seed=None, use_sdl=False, n_last_screens=4,
             assert seed >= 0 and seed < 2 ** 16, \
                 "ALE's random seed must be represented by unsigned int"
         else:
-            # Use numpy's random state
+            # Warning: starting ALE without an explicit random seed can lead
+            # to all processes sharing the same initial state. Please check
+            # args.txt if you are concerned about this.
             seed = np.random.randint(0, 2 ** 16)
-        ale.setInt(b'random_seed', seed)
+
+        # Remember our (per-process) random seed
+        self.seed = seed
+
+        # Initialize a random state for this process. If we always call
+        # self.randstate instead of np.random, the process should be
+        # deterministic.
+        self.randstate = np.random.RandomState(self.seed)
+
+        # Use the random seed for the ALE emulator too
+        ale.setInt(b'random_seed', self.seed)
+
         ale.setFloat(b'repeat_action_probability', 0.0)
         ale.setBool(b'color_averaging', False)
         if record_screen_dir is not None:

@@ -142,7 +155,7 @@ def initialize(self):
         self.ale.reset_game()
 
         if self.max_start_nullops > 0:
-            n_nullops = np.random.randint(0, self.max_start_nullops + 1)
+            n_nullops = self.randstate.randint(0, self.max_start_nullops + 1)
             for _ in range(n_nullops):
                 self.ale.act(0)
async.py

+1 −0

@@ -75,6 +75,7 @@ def run_async(n_process, run_func):
 
     processes = []
 
+    # It is not clear to me that this does what it should. --max
     def set_seed_and_run(process_idx, run_func):
         random_seed.set_random_seed(np.random.randint(0, 2 ** 32))
         run_func(process_idx)
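The doubt in that comment is well founded for fork-based multiprocessing: every child inherits a copy of the parent's global np.random state, so an unseeded draw in each child can return the same value and all processes end up seeded identically. A small sketch of the failure mode (illustrative only; run_async's actual process setup may differ):

    import multiprocessing as mp
    import numpy as np

    def report(idx):
        # With the fork start method, every child inherits the same global
        # np.random state, so this prints the same 'random' number in each.
        print(idx, np.random.randint(0, 2 ** 32))

    if __name__ == '__main__':
        ps = [mp.Process(target=report, args=(i,)) for i in range(4)]
        for p in ps:
            p.start()
        for p in ps:
            p.join()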

policy.py

+13 −2

@@ -4,6 +4,7 @@
 import chainer
 from chainer import functions as F
 from chainer import links as L
+import numpy as np
 
 import policy_output

@@ -26,19 +27,29 @@ def compute_logits(self, state):
         raise NotImplementedError
 
     def __call__(self, state):
-        return policy_output.SoftmaxPolicyOutput(self.compute_logits(state))
+        # The SoftmaxPolicyOutput is not persistent, so it cannot hold its
+        # own random state; rely instead on the policy's randstate, passed
+        # as a reference.
+        return policy_output.SoftmaxPolicyOutput(
+            self.compute_logits(state),
+            self.policy_randstate)
 
 
 class FCSoftmaxPolicy(chainer.ChainList, SoftmaxPolicy):
     """Softmax policy that consists of FC layers and rectifiers"""
 
-    def __init__(self, n_input_channels, n_actions,
+    def __init__(self, n_input_channels, n_actions, seed,
                  n_hidden_layers=0, n_hidden_channels=None):
         self.n_input_channels = n_input_channels
         self.n_actions = n_actions
         self.n_hidden_layers = n_hidden_layers
         self.n_hidden_channels = n_hidden_channels
 
+        # Keep a per-policy randstate; this should provide diversity in the
+        # face of similar environments.
+        self.model_seed = seed
+        self.policy_randstate = np.random.RandomState(seed)
+
         layers = []
         if n_hidden_layers > 0:
             layers.append(L.Linear(n_input_channels, n_hidden_channels))

policy_output.py

+4 −3

@@ -9,7 +9,7 @@ class PolicyOutput(object):
     pass
 
 
-def _sample_discrete_actions(batch_probs):
+def _sample_discrete_actions(batch_probs, randstate):
     """Sample a batch of actions from a batch of action probabilities.
 
     Args:

@@ -31,8 +31,9 @@ def _sample_discrete_actions(batch_probs):
 
 class SoftmaxPolicyOutput(PolicyOutput):
 
-    def __init__(self, logits):
+    def __init__(self, logits, randstate):
         self.logits = logits
+        self.policy_output_randstate = randstate
 
     @cached_property
     def most_probable_actions(self):

@@ -48,7 +49,7 @@ def log_probs(self):
 
     @cached_property
     def action_indices(self):
-        return _sample_discrete_actions(self.probs.data)
+        return _sample_discrete_actions(self.probs.data, self.policy_output_randstate)
 
     @cached_property
     def sampled_actions_log_probs(self):
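The body of _sample_discrete_actions is not shown in this diff; a plausible shape of the function after the patch, using the passed-in randstate instead of the global np.random (an assumption, not the repository's exact code):

    def _sample_discrete_actions(batch_probs, randstate):
        """Sample a batch of action indices from a batch of probabilities."""
        action_indices = []
        for probs in batch_probs:
            # Draw from this process's own stream so that action sampling
            # stays reproducible per process.
            action_indices.append(randstate.choice(len(probs), p=probs))
        return action_indices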

prepare_output_dir.py

+1 −1

@@ -36,7 +36,7 @@ def prepare_output_dir(args, user_specified_dir=None):
 
     # Save all the arguments
     with open(os.path.join(outdir, 'args.txt'), 'w') as f:
-        f.write(json.dumps(vars(args)))
+        f.write(json.dumps(vars(args)) + "\n\n")
 
     # Save `git status`
     with open(os.path.join(outdir, 'git-status.txt'), 'w') as f:
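Since the (possibly generated) seed now lands in args.txt, a finished run can be reproduced by reading it back; a hedged sketch (the path is illustrative, it is whatever prepare_output_dir chose):

    import json

    with open('results/args.txt') as f:  # path is illustrative
        saved_args = json.loads(f.readline())
    print('reproduce with --seed', saved_args['seed'])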
