Skip to content

Commit 8f3f032

Browse files
committed
Fix suggest_move() in randompolicymixin & greedypolicymixin
1 parent f1b171c commit 8f3f032

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
parser.add_argument('--model_type', dest='model', default='full',
2222
help='choose residual block architecture {original,elu,full}')
2323
parser.add_argument('--optimizer', dest='opt', default='adam')
24-
parser.add_argument('--gtp_policy', dest='gpt_policy', default='mctspolicy',
24+
parser.add_argument('--gtp_policy', dest='gpt_policy', default='greedypolicy',
2525
help='choose gtp bot player') # random,mctspolicy
2626
parser.add_argument('--num_playouts', type=int, dest='num_playouts', default=1600,
2727
help='The number of MC search per move, the more the better.')

utils/gtp_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ def make_gtp_instance(flags, hps):
116116
elif strategy_name == 'randompolicy':
117117
instance = RandomPolicyPlayer(n)
118118
elif strategy_name == 'mctspolicy':
119-
instance = MCTSPlayer(net=n, num_playouts=1600)
119+
instance = MCTSPlayer(net=n, num_playouts=flags.num_playouts)
120120
else:
121121
return None
122122
gtp_engine = gtp.Engine(instance)

utils/strategies.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -250,8 +250,9 @@ def __init__(self, policy_network):
250250
super().__init__()
251251

252252
def suggest_move(self, position):
253-
move_probabilities = self.policy_network.run(position)
254-
return select_most_likely(position, move_probabilities)
253+
move_probabilities = self.policy_network.run_many(bulk_extract_features([position]))[0][0]
254+
on_board_move_prob = np.reshape(move_probabilities[:-1], (go.N, go.N))
255+
return select_most_likely(position, on_board_move_prob)
255256

256257

257258
class RandomPolicyPlayerMixin:
@@ -260,5 +261,6 @@ def __init__(self, policy_network):
260261
super().__init__()
261262

262263
def suggest_move(self, position):
263-
move_probabilities = self.policy_network.run(position)
264-
return select_weighted_random(position, move_probabilities)
264+
move_probabilities = self.policy_network.run_many(bulk_extract_features([position]))[0][0]
265+
on_board_move_prob = np.reshape(move_probabilities[:-1], (go.N, go.N))
266+
return select_weighted_random(position, on_board_move_prob)

0 commit comments

Comments
 (0)