Skip to content

Commit 82fdd8a

Browse files
authored
Merge pull request #106 from kaorahi/mcts_step2
MCTS過程のアニメーション(#93)の実装
2 parents eecde0b + b67aae2 commit 82fdd8a

File tree

5 files changed

+141
-18
lines changed

5 files changed

+141
-18
lines changed

animation/animation.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import sys
2+
import select
3+
import time
4+
5+
6+
def animate_mcts(mcts, board, to_move, pv_wait_sec, move_wait_sec):
7+
previous_pv = []
8+
def callback(path):
9+
_animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, previous_pv)
10+
finished = _stdin_has_data()
11+
return finished
12+
mcts.search_with_callback(board, to_move, callback)
13+
14+
15+
def _stdin_has_data():
16+
rlist, _, _ = select.select([sys.stdin], [], [], 0)
17+
return bool(rlist)
18+
19+
20+
def _animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, previous_pv):
21+
# 今回探索した系列の属性値
22+
root_index, i = path[0]
23+
root = mcts.node[root_index]
24+
if root.children_visits[i] == 0:
25+
return
26+
coordinate = board.coordinate
27+
move = coordinate.convert_to_gtp_format(root.action[i])
28+
pv = [coordinate.convert_to_gtp_format(mcts.node[index].action[child_index]) for (index, child_index) in path]
29+
pv_visits = [str(mcts.node[index].children_visits[child_index]) for (index, child_index) in path]
30+
pv_winrate = [str(int(10000 * _get_winrate(mcts, index, child_index, depth))) for depth, (index, child_index) in enumerate(path)]
31+
32+
# lz-analyze の本来の出力内容を加工
33+
children_status_list = root.get_analysis_status_list(board, mcts.get_pv_lists)
34+
fake_status_list = [status.copy() for status in children_status_list]
35+
target = next((status for status in fake_status_list if status["move"] == move), None)
36+
if target is None:
37+
return # can't happen
38+
# 今回探索した系列の初手を最善手と偽って順位をふり直す
39+
target["order"] = -1
40+
fake_status_list.sort(key=lambda status: status["order"])
41+
for order, status in enumerate(fake_status_list):
42+
status["order"] = order
43+
44+
# PV 欄を差しかえながら複数回出力することで一手ずつアニメーション
45+
for k in range(1, len(pv) + 1):
46+
# 前回の系列と共通な手順はスキップ
47+
if pv[:k] == previous_pv[:k]:
48+
continue
49+
50+
target["pv"] = " ".join(pv[:k])
51+
target["pvVisits"] = " ".join(pv_visits[:k])
52+
target["pvWinrate"] = " ".join(pv_winrate[:k])
53+
54+
sys.stdout.write(root.get_analysis_from_status_list("lz", fake_status_list))
55+
sys.stdout.flush()
56+
time.sleep(max(move_wait_sec, 0.0))
57+
58+
previous_pv[:] = pv
59+
time.sleep(max(pv_wait_sec, 0.0))
60+
61+
62+
def _get_winrate(mcts, index, child_index, depth):
63+
node = mcts.node[index]
64+
i = child_index
65+
visits = node.children_visits[i]
66+
value = node.children_value_sum[i] / visits if visits > 0 else node.children_value[i]
67+
winrate = value if depth % 2 == 0 else 1.0 - value
68+
return winrate

gtp/client.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from nn.policy_player import generate_move_from_policy
2020
from nn.utility import load_network
2121
from sgf.reader import SGFReader
22+
from animation.animation import animate_mcts
2223

2324

2425
gtp_command_id = ""
@@ -30,7 +31,8 @@ class GtpClient: # pylint: disable=R0902,R0903
3031
def __init__(self, board_size: int, superko: bool, model_file_path: str, \
3132
use_gpu: bool, policy_move: bool, use_sequential_halving: bool, \
3233
komi: float, mode: TimeControl, visits: int, const_time: float, \
33-
time: float, batch_size: int, tree_size: int, cgos_mode: bool): # pylint: disable=R0913
34+
time: float, batch_size: int, tree_size: int, cgos_mode: bool, \
35+
animation_pv_wait: float, animation_move_wait:float): # pylint: disable=R0913
3436
"""Go Text Protocolクライアントの初期化をする。
3537
3638
Args:
@@ -92,6 +94,8 @@ def __init__(self, board_size: int, superko: bool, model_file_path: str, \
9294
self.policy_move = policy_move
9395
self.use_sequential_halving = use_sequential_halving
9496
self.use_network = False
97+
self.animation_pv_wait = animation_pv_wait
98+
self.animation_move_wait = animation_move_wait
9599

96100
if mode is TimeControl.CONSTANT_PLAYOUT or mode is TimeControl.STRICT_PLAYOUT:
97101
self.time_manager = TimeManager(mode=mode, constant_visits=visits)
@@ -400,6 +404,18 @@ def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float):
400404
return error_value
401405
return (to_move, interval)
402406

407+
def _analyze_or_animate(self, mode: str, arg_list: List[str]) -> NoReturn:
408+
if max(self.animation_pv_wait, self.animation_move_wait) >= 0:
409+
self._animate(arg_list, self.animation_pv_wait, self.animation_move_wait)
410+
else:
411+
self._analyze(mode, arg_list)
412+
413+
def _animate(self, arg_list: List[str], pv_wait: float, move_wait: float) -> NoReturn:
414+
to_move, _ = self._decode_analyze_arg(arg_list)
415+
respond_success("", ongoing=True)
416+
animate_mcts(self.mcts, self.board, to_move, pv_wait, move_wait)
417+
print_out("")
418+
403419
def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
404420
"""analyzeコマンド(lz-analyze, cgos-analyze)を実行する。
405421
@@ -565,7 +581,7 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
565581
self.board.display_self_atari(Stone.WHITE)
566582
respond_success("")
567583
elif input_gtp_command == "lz-analyze":
568-
self._analyze("lz", command_list[1:])
584+
self._analyze_or_animate("lz", command_list[1:])
569585
print("")
570586
elif input_gtp_command == "lz-genmove_analyze":
571587
self._genmove_analyze("lz", command_list[1:])

main.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,14 @@
4444
help=f"探索木を構成するノードの最大数。デフォルトはMCTS_TREE_SIZE = {MCTS_TREE_SIZE}。")
4545
@click.option('--cgos-mode', type=click.BOOL, default=False, \
4646
help="全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。")
47+
@click.option('--animation-pv-wait', type=click.FLOAT, default=-1.0, \
48+
help="lz-analyzeの出力をMCTSアニメーションに差しかえて、系列ごとに指定秒停止。")
49+
@click.option('--animation-move-wait', type=click.FLOAT, default=-1.0, \
50+
help="lz-analyzeの出力をMCTSアニメーションに差しかえて、一手ごとに指定秒停止。")
4751
def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halving: bool, \
4852
policy_move: bool, komi: float, visits: int, strict_visits: int, const_time: float, time: float, \
49-
batch_size: int, tree_size: int, cgos_mode: bool):
53+
batch_size: int, tree_size: int, cgos_mode: bool, \
54+
animation_pv_wait: float, animation_move_wait: float):
5055
"""GTPクライアントの起動。
5156
5257
Args:
@@ -64,6 +69,8 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv
6469
batch_size (int): 探索実行時のニューラルネットワークのミニバッチサイズ。デフォルトはNN_BATCH_SIZE。
6570
tree_size (int): 探索木を構成するノードの最大数。デフォルトはMCTS_TREE_SIZE。
6671
cgos_mode (bool): 全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。
72+
animation_pv_wait (float): lz-analyzeの出力をMCTSアニメーションに差しかえて、系列ごとに指定秒停止。
73+
animation_move_wait (float): lz-analyzeの出力をMCTSアニメーションに差しかえて、一手ごとに指定秒停止。
6774
"""
6875
mode = TimeControl.CONSTANT_PLAYOUT
6976

@@ -78,7 +85,7 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv
7885
program_dir = os.path.dirname(__file__)
7986
client = GtpClient(size, superko, os.path.join(program_dir, model), use_gpu, policy_move, \
8087
sequential_halving, komi, mode, visits, const_time, time, batch_size, tree_size, \
81-
cgos_mode)
88+
cgos_mode, animation_pv_wait, animation_move_wait)
8289
client.run()
8390

8491

mcts/node.py

+13
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,12 @@ def get_analysis(self, board: GoBoard, mode: str, \
408408
Returns:
409409
str: GTP応答用解析結果文字列。
410410
"""
411+
children_status_list = self.get_analysis_status_list(board, pv_lists_func)
412+
return self.get_analysis_from_status_list(mode, children_status_list)
413+
414+
415+
def get_analysis_status_list(self, board: GoBoard, \
416+
pv_lists_func: Callable[[List[str], int], List[str]]):
411417
sorted_list = []
412418
for i in range(self.num_children):
413419
sorted_list.append((self.children_visits[i], i))
@@ -439,7 +445,10 @@ def get_analysis(self, board: GoBoard, mode: str, \
439445
}
440446
)
441447
order += 1
448+
return children_status_list
449+
442450

451+
def get_analysis_from_status_list(self, mode, children_status_list):
443452
out = ""
444453
if mode == "cgos":
445454
cgos_dict = {
@@ -457,6 +466,10 @@ def get_analysis(self, board: GoBoard, mode: str, \
457466
out += f"lcb {int(10000 * status['lcb'])} "
458467
out += f"order {status['order']} "
459468
out += f"pv {status['pv']}"
469+
# if "pvVisits" in status:
470+
# out += f" pvVisits {status['pvVisits']}"
471+
# if "pvWinrate" in status:
472+
# out += f" lizgobanPvWinrate {status['pvWinrate']}"
460473
out += " "
461474
elif mode == "cgos":
462475
cgos_dict["moves"].append(status)

mcts/tree.py

+33-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""モンテカルロ木探索の実装。
22
"""
3-
from typing import Any, Dict, List, NoReturn, Tuple
3+
from typing import Any, Dict, List, NoReturn, Tuple, Callable
44
import sys
55
import select
66
import copy
@@ -46,6 +46,14 @@ def __init__(self, network: DualNet, tree_size: int=MCTS_TREE_SIZE, \
4646
self.to_move = Stone.BLACK
4747

4848

49+
def _initialize_search(self, board: GoBoard, color: Stone) -> NoReturn:
50+
self.num_nodes = 0
51+
self.current_root = self.expand_node(board, color)
52+
input_plane = generate_input_planes(board, color, 0)
53+
self.batch_queue.push(input_plane, [], self.current_root)
54+
self.process_mini_batch(board)
55+
56+
4957
def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
5058
analysis_query: Dict[str, Any]) -> int:
5159
"""モンテカルロ木探索を実行して最善手を返す。
@@ -58,16 +66,10 @@ def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManag
5866
Returns:
5967
int: 着手する座標。
6068
"""
61-
self.num_nodes = 0
69+
self._initialize_search(board, color)
6270

6371
time_manager.start_timer()
6472

65-
self.current_root = self.expand_node(board, color)
66-
input_plane = generate_input_planes(board, color, 0)
67-
self.batch_queue.push(input_plane, [], self.current_root)
68-
69-
self.process_mini_batch(board)
70-
7173
root = self.node[self.current_root]
7274

7375
# 候補手が1つしかない場合はPASSを返す
@@ -111,12 +113,7 @@ def ponder(self, board: GoBoard, color: Stone, analysis_query: Dict[str, Any]) -
111113
color (Stone): 思考する手番の色。
112114
analysis_query (Dict): 解析情報。
113115
"""
114-
self.num_nodes = 0
115-
116-
self.current_root = self.expand_node(board, color)
117-
input_plane = generate_input_planes(board, color, 0)
118-
self.batch_queue.push(input_plane, [], self.current_root)
119-
self.process_mini_batch(board)
116+
self._initialize_search(board, color)
120117

121118
# 探索を実行する
122119
max_visits = 999999999
@@ -177,6 +174,28 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
177174
sys.stdout.flush()
178175

179176

177+
def search_with_callback(self, board: GoBoard, color: Stone, callback: Callable[List[Tuple[int, int]], bool]) -> NoReturn:
178+
"""探索を実行し、探索系列をコールバック関数へ渡す動作をくり返す。
179+
コールバック関数の戻り値が真になれば終了する。
180+
Args:
181+
board (GoBoard): 現在の局面情報。
182+
color (Stone): 現局面の手番の色。
183+
callback (Callable[List[Tuple[int, int]], bool]): コールバック関数。
184+
"""
185+
original_batch_size = self.batch_size
186+
self.batch_size = 1
187+
self._initialize_search(board, color)
188+
search_board = copy.deepcopy(board)
189+
while True:
190+
path = []
191+
copy_board(dst=search_board, src=board)
192+
self.search_mcts(search_board, color, self.current_root, path)
193+
finished = callback(path)
194+
if finished:
195+
break
196+
self.batch_size = original_batch_size
197+
198+
180199
def search_mcts(self, board: GoBoard, color: Stone, current_index: int, \
181200
path: List[Tuple[int, int]]) -> NoReturn:
182201
"""モンテカルロ木探索を実行する。

0 commit comments

Comments
 (0)