Skip to content
This repository was archived by the owner on Mar 8, 2022. It is now read-only.

Commit c2f2543

Browse files
committed
Only keep track of training state counts if using count-based exploration.
1 parent 5cfcc72 commit c2f2543

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

src/utils.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,15 @@ def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
188188
for episode_state_key in episode_state_count_dict:
189189
episode_state_count_dict = dict()
190190

191-
# Update the training state counts
192-
train_state_key = tuple(env_output['frame'].view(-1).tolist())
193-
if train_state_key in train_state_count_dict:
194-
train_state_count_dict[train_state_key] += 1
195-
else:
196-
train_state_count_dict.update({train_state_key: 1})
197-
buffers['train_state_count'][index][0, ...] = \
198-
torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key)))
191+
# Update the training state counts if you're doing count-based exploration
192+
if flags.model == 'count':
193+
train_state_key = tuple(env_output['frame'].view(-1).tolist())
194+
if train_state_key in train_state_count_dict:
195+
train_state_count_dict[train_state_key] += 1
196+
else:
197+
train_state_count_dict.update({train_state_key: 1})
198+
buffers['train_state_count'][index][0, ...] = \
199+
torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key)))
199200

200201
# Do new rollout
201202
for t in range(flags.unroll_length):
@@ -229,14 +230,15 @@ def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
229230
if env_output['done'][0][0]:
230231
episode_state_count_dict = dict()
231232

232-
# Update the training state counts
233-
train_state_key = tuple(env_output['frame'].view(-1).tolist())
234-
if train_state_key in train_state_count_dict:
235-
train_state_count_dict[train_state_key] += 1
236-
else:
237-
train_state_count_dict.update({train_state_key: 1})
238-
buffers['train_state_count'][index][t + 1, ...] = \
239-
torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key)))
233+
# Update the training state counts if you're doing count-based exploration
234+
if flags.model == 'count':
235+
train_state_key = tuple(env_output['frame'].view(-1).tolist())
236+
if train_state_key in train_state_count_dict:
237+
train_state_count_dict[train_state_key] += 1
238+
else:
239+
train_state_count_dict.update({train_state_key: 1})
240+
buffers['train_state_count'][index][t + 1, ...] = \
241+
torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key)))
240242

241243
timings.time('write')
242244
full_queue.put(index)

0 commit comments

Comments
 (0)