@@ -188,14 +188,15 @@ def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
188
188
for episode_state_key in episode_state_count_dict :
189
189
episode_state_count_dict = dict ()
190
190
191
- # Update the training state counts
192
- train_state_key = tuple (env_output ['frame' ].view (- 1 ).tolist ())
193
- if train_state_key in train_state_count_dict :
194
- train_state_count_dict [train_state_key ] += 1
195
- else :
196
- train_state_count_dict .update ({train_state_key : 1 })
197
- buffers ['train_state_count' ][index ][0 , ...] = \
198
- torch .tensor (1 / np .sqrt (train_state_count_dict .get (train_state_key )))
191
+ # Update the training state counts if you're doing count-based exploration
192
+ if flags .model == 'count' :
193
+ train_state_key = tuple (env_output ['frame' ].view (- 1 ).tolist ())
194
+ if train_state_key in train_state_count_dict :
195
+ train_state_count_dict [train_state_key ] += 1
196
+ else :
197
+ train_state_count_dict .update ({train_state_key : 1 })
198
+ buffers ['train_state_count' ][index ][0 , ...] = \
199
+ torch .tensor (1 / np .sqrt (train_state_count_dict .get (train_state_key )))
199
200
200
201
# Do new rollout
201
202
for t in range (flags .unroll_length ):
@@ -229,14 +230,15 @@ def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
229
230
if env_output ['done' ][0 ][0 ]:
230
231
episode_state_count_dict = dict ()
231
232
232
- # Update the training state counts
233
- train_state_key = tuple (env_output ['frame' ].view (- 1 ).tolist ())
234
- if train_state_key in train_state_count_dict :
235
- train_state_count_dict [train_state_key ] += 1
236
- else :
237
- train_state_count_dict .update ({train_state_key : 1 })
238
- buffers ['train_state_count' ][index ][t + 1 , ...] = \
239
- torch .tensor (1 / np .sqrt (train_state_count_dict .get (train_state_key )))
233
+ # Update the training state counts if you're doing count-based exploration
234
+ if flags .model == 'count' :
235
+ train_state_key = tuple (env_output ['frame' ].view (- 1 ).tolist ())
236
+ if train_state_key in train_state_count_dict :
237
+ train_state_count_dict [train_state_key ] += 1
238
+ else :
239
+ train_state_count_dict .update ({train_state_key : 1 })
240
+ buffers ['train_state_count' ][index ][t + 1 , ...] = \
241
+ torch .tensor (1 / np .sqrt (train_state_count_dict .get (train_state_key )))
240
242
241
243
timings .time ('write' )
242
244
full_queue .put (index )
0 commit comments