Skip to content

Commit eecde0b

Browse files
authored
Merge pull request #103 from kaorahi/plot_tree2
Improve #94 (plot_tree)
2 parents 8982970 + e751c4c commit eecde0b

File tree

5 files changed

+117
-31
lines changed

5 files changed

+117
-31
lines changed

Diff for: board/go_board.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,11 @@ def get_all_legal_pos(self, color: Stone) -> List[int]:
411411
def display(self, sym: int=0) -> NoReturn:
412412
"""盤面を表示する。
413413
"""
414+
print_err(self.get_board_string(sym=sym))
415+
416+
def get_board_string(self, sym: int=0) -> str:
417+
"""盤面を表わす文字列を返す。
418+
"""
414419
board_string = f"Move : {self.moves}\n"
415420
board_string += f"Prisoner(Black) : {self.prisoner[0]}\n"
416421
board_string += f"Prisoner(White) : {self.prisoner[1]}\n"
@@ -432,7 +437,7 @@ def display(self, sym: int=0) -> NoReturn:
432437

433438
board_string += " +" + "-" * (self.board_size * 2 + 1) + "+\n"
434439

435-
print_err(board_string)
440+
return board_string
436441

437442

438443
def display_self_atari(self, color: Stone) -> NoReturn:
@@ -546,6 +551,13 @@ def get_handicap_history(self) -> List[int]:
546551
"""
547552
return self.record.handicap_pos[:]
548553

554+
def set_history(self, move_history, handicap_history):
555+
self.clear()
556+
for handicap in handicap_history:
557+
self.board.put_handicap_stone(handicap, Stone.BLACK)
558+
for (color, pos, _) in move_history:
559+
self.put_stone(pos, color)
560+
549561
def count_score(self) -> int: # pylint: disable=R0912
550562
"""領地を簡易的にカウントする。
551563

Diff for: board/pattern.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@
1818
[0xfffc, 0x00000001, 0x00000002],
1919
], dtype=np.uint32)
2020

21+
nb4_empty = [0] * 65536
22+
for i, _ in enumerate(nb4_empty):
23+
if ((i >> 2) & 0x3) == 0:
24+
nb4_empty[i] += 1
25+
if ((i >> 6) & 0x3) == 0:
26+
nb4_empty[i] += 1
27+
if ((i >> 8) & 0x3) == 0:
28+
nb4_empty[i] += 1
29+
if ((i >> 12) & 0x3) == 0:
30+
nb4_empty[i] += 1
31+
2132

2233
class Pattern:
2334
"""配石パターンクラス。
@@ -38,17 +49,6 @@ def __init__(self, board_size: int, pos_func: Callable[[int], int]):
3849
-1, 1, board_size_with_ob - 1, board_size_with_ob, board_size_with_ob + 1
3950
]
4051

41-
self.nb4_empty = [0] * 65536
42-
for i, _ in enumerate(self.nb4_empty):
43-
if ((i >> 2) & 0x3) == 0:
44-
self.nb4_empty[i] += 1
45-
if ((i >> 6) & 0x3) == 0:
46-
self.nb4_empty[i] += 1
47-
if ((i >> 8) & 0x3) == 0:
48-
self.nb4_empty[i] += 1
49-
if ((i >> 12) & 0x3) == 0:
50-
self.nb4_empty[i] += 1
51-
5252
# 眼のパターン
5353
eye_pat3 = [
5454
# +OO XOO +O+ XO+
@@ -148,7 +148,7 @@ def get_n_neighbors_empty(self, pos: int) -> int:
148148
Returns:
149149
int: 上下左右の空点数(最大4)
150150
"""
151-
return self.nb4_empty[self.pat3[pos]]
151+
return nb4_empty[self.pat3[pos]]
152152

153153
def get_eye_color(self, pos: int) -> Stone:
154154
"""指定した座標の眼の色を取得する。

Diff for: graph/plot_tree.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ def plot_tree_main(input_json_path: str, output_image_path: str, around_pv: bool
3636
cd tamago
3737
(echo 'tamago-readsgf (;SZ[9]KM[7];B[fe];W[de];B[ec])';
3838
echo 'lz-genmove_analyze 7777777';
39+
echo 'undo';
3940
echo 'tamago-dump_tree') \\
4041
| python3 main.py --model model/model.bin --strict-visits 100 \\
4142
| grep dump_version | gzip > tree.json.gz
4243
python3 graph/plot_tree.py tree.json.gz tree_graph
43-
display tree_graph.png
44+
display tree_graph.svg
4445
"""
4546

4647
opener = gzip.open if input_json_path.endswith('.gz') else open
@@ -61,13 +62,15 @@ def plot_tree_main(input_json_path: str, output_image_path: str, around_pv: bool
6162

6263
for index in sorted_indices_list:
6364
item = node[index]
65+
item_id = get_graphviz_id(index, node)
6466
# ルートノードの場合
6567
if "parent_index" not in item:
66-
dot.node(str(index), label=f"root\n{item['node_visits']} visits")
68+
dot.node(item_id, label=f"root\n{item['node_visits']} visits")
6769
continue
6870

6971
parent_index = item['parent_index']
7072
parent = node[parent_index]
73+
parent_id = get_graphviz_id(parent_index, node)
7174
# around_pv が指定された場合は、PV とその直下の子のみ表示する。
7275
if around_pv and any(order > 0 for order in parent["orders_along_path"]):
7376
continue
@@ -86,7 +89,7 @@ def plot_tree_main(input_json_path: str, output_image_path: str, around_pv: bool
8689
raw_wr = int(raw_winrate * 100)
8790
label = f"{move}\n{wr}%" if visits < 10 else f"{move}\n{wr}% (raw {raw_wr}%)\n{visits} visits"
8891
dot.node(
89-
str(index),
92+
item_id,
9093
label=label,
9194
color=border_color,
9295
fillcolor=node_color,
@@ -99,10 +102,14 @@ def plot_tree_main(input_json_path: str, output_image_path: str, around_pv: bool
99102
)
100103

101104
# エッジの作成
105+
freshness = (item['index'] + 1) / len(node)
106+
whiteness = 0.9
107+
c = f"{int(freshness * whiteness * 255):02x}"
108+
color = f"#{c}{c}{c}"
102109
penwidth = max(0.5, item['policy'] * 10)
103-
dot.edge(str(parent_index), str(index), penwidth=f"{penwidth}")
110+
dot.edge(parent_id, item_id, color=color, penwidth=f"{penwidth}")
104111

105-
dot.render(output_image_path, format='png', view=False, cleanup=True)
112+
dot.render(output_image_path, format='svg', view=False, cleanup=True)
106113

107114
def get_color(value, colormap):
108115
emphasis = 1.5 # 色の違いを強調
@@ -115,5 +122,13 @@ def get_size(visits, shape):
115122
size = size0 if shape == 'square' else size0 * 2 / (math.pi ** 0.5)
116123
return str(size)
117124

125+
def get_graphviz_id(index, node):
126+
max_board_str_len = 400 # 9路盤が340文字程度
127+
index_str = f"node{index}"
128+
# IDが文字「:」を含むと、graphvizで不具合が生じる。
129+
board_str = node[index]['board_string'].replace(':', ' ')
130+
too_long = len(board_str) > max_board_str_len
131+
return index_str if too_long else f"{index_str}\n{board_str}"
132+
118133
if __name__ == "__main__":
119134
plot_tree_main()

Diff for: gtp/client.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,7 @@ def _undo(self) -> NoReturn:
179179

180180
handicap_history = self.board.get_handicap_history()
181181

182-
self.board.clear()
183-
184-
for handicap in handicap_history:
185-
self.board.put_handicap_stone(handicap, Stone.BLACK)
186-
187-
for (color, pos, _) in history[:-1]:
188-
self.board.put_stone(pos, color)
182+
self.board.set_history(history[:-1], handicap_history)
189183

190184
respond_success("")
191185

Diff for: mcts/dump.py

+71-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import json
2-
from typing import Any, Dict, NoReturn
2+
from typing import Any, Tuple, List, Dict, NoReturn
33

44
from program import PROGRAM_NAME, VERSION, PROTOCOL_VERSION
5-
from board.go_board import GoBoard
5+
from board.go_board import GoBoard, copy_board
66
from board.coordinate import Coordinate
77
from board.stone import Stone
88
from mcts.constant import NOT_EXPANDED
@@ -19,10 +19,12 @@ def dump_mcts_to_json(tree_dict: Dict[str, Any], board: GoBoard, superko: bool)
1919
str: MCTSの状態を表すJSON文字列。
2020
"""
2121
state = {
22-
"dump_version": 1,
22+
"dump_version": 2,
2323
"tree": tree_dict,
2424
"board_size": board.get_board_size(),
2525
"komi": board.get_komi(),
26+
"move_history": _serializable_move_history(board.get_move_history()),
27+
"handicap_history": board.get_handicap_history(),
2628
"superko": superko,
2729
"name": PROGRAM_NAME,
2830
"version": VERSION,
@@ -36,7 +38,12 @@ def enrich_mcts_dict(state: Dict[str, Any]) -> NoReturn:
3638
Args:
3739
state (Dict[str, Any]): MCTSの状態を表す辞書。
3840
"""
39-
coord = Coordinate(board_size=state["board_size"])
41+
root_board = GoBoard(board_size=state["board_size"], komi=state["komi"], \
42+
check_superko=state["superko"])
43+
root_board.set_history(_recovered_move_history(state["move_history"]), \
44+
state["handicap_history"])
45+
46+
coord = Coordinate(board_size=root_board.get_board_size())
4047
tree = state["tree"]
4148
node = tree["node"]
4249

@@ -73,24 +80,31 @@ def enrich_mcts_dict(state: Dict[str, Any]) -> NoReturn:
7380
nodes_pool += expanded_children
7481

7582
# その他いろいろな便利項目を追加
83+
initial_move_color = _str_to_stone(tree["to_move"])
7684
for item in node:
7785
is_root = "parent_index" not in item
7886
if is_root:
7987
item["level"] = 0
8088
item["orders_along_path"] = []
89+
item["gtp_moves_along_path"] = []
8190
item["to_move"] = tree["to_move"]
91+
item["board_string"] = root_board.get_board_string()
8292
continue
8393
parent = node[item["parent_index"]]
94+
index_in_brother = item["index_in_brother"]
95+
gtp_move = coord.convert_to_gtp_format(parent["action"][index_in_brother])
8496
item["level"] = parent["level"] + 1
8597
item["orders_along_path"] = [*parent["orders_along_path"], item["order"]]
8698
item["to_move"] = _opposite_color(parent["to_move"])
99+
item["gtp_moves_along_path"] = [*parent["gtp_moves_along_path"], gtp_move]
100+
item["board_string"] = _get_updated_board_string(root_board, initial_move_color, \
101+
item["gtp_moves_along_path"])
87102
# ルートノードは以下の項目を持たないことに注意
88-
index_in_brother = item["index_in_brother"]
89103
item["policy"] = parent["children_policy"][index_in_brother]
90104
item["visits"] = parent["children_visits"][index_in_brother]
91105
item["value"] = parent["children_value"][index_in_brother]
92106
item["value_sum"] = parent["children_value_sum"][index_in_brother]
93-
item["gtp_move"] = coord.convert_to_gtp_format(parent["action"][index_in_brother])
107+
item["gtp_move"] = gtp_move
94108
item["mean_value"] = item["value_sum"] / item["visits"]
95109
last_move_color = _opposite_color(item["to_move"])
96110
item["raw_black_winrate"] = _black_winrate(item["value"], last_move_color)
@@ -101,3 +115,54 @@ def _opposite_color(color):
101115

102116
def _black_winrate(value, last_move_color):
103117
return value if last_move_color == "black" else 1.0 - value
118+
119+
def _serializable_move_history(move_history: List[Tuple[Stone, int, Any]]) -> List[Tuple[str, int]]:
120+
"""着手の履歴をシリアライズ可能な値に変換する。ただしハッシュ値は廃棄する。
121+
122+
Args:
123+
move_history (List[Tuple[Stone, int, np.array]]): 着手の履歴。
124+
125+
Returns:
126+
Lizt[Tuple[str, int]]: シリアライズ可能なよう変換された着手履歴。
127+
"""
128+
return [(_stone_to_str(color), pos) for (color, pos, _) in move_history]
129+
130+
def _recovered_move_history(converted_move_history: List[Tuple[str, int]]) -> List[Tuple[Stone, int, Any]]:
131+
"""_serializable_move_historyで変換された着手履歴から元の着手履歴を復元する。
132+
ただしハッシュ値はNoneに置きかえられる。
133+
134+
Args:
135+
converted_move_history (Lizt[Tuple[str, int]]): 変換された着手履歴。
136+
137+
Returns:
138+
List[Tuple[Stone, int, Any]]: 復元された着手履歴。
139+
"""
140+
return [(_str_to_stone(color_str), pos, None) for (color_str, pos) in converted_move_history]
141+
142+
def _stone_to_str(color: Stone) -> str:
143+
return 'black' if color == Stone.BLACK else 'white'
144+
145+
def _str_to_stone(color_str: str) -> str:
146+
return Stone.BLACK if color_str == 'black' else Stone.WHITE
147+
148+
def _get_updated_board_string(root_board: GoBoard, initial_move_color: Stone, gtp_moves_along_path: List[str]) -> str:
149+
"""一連の着手後の盤面を表わす文字列を返す。
150+
151+
Args:
152+
root_board (GoBoard): 着手前の盤面。
153+
initial_move_color (Stone): 最初の着手の色。
154+
gtp_moves_along_path (List[str]): 着手位置のリスト。
155+
156+
Returns:
157+
str: 着手後の盤面を表わす文字列。
158+
"""
159+
coord = Coordinate(board_size=root_board.get_board_size())
160+
move_color = initial_move_color
161+
# 「board = copy.deepcopy(root_board)」は遅いので避ける。
162+
board = GoBoard(board_size=root_board.get_board_size(), komi=root_board.get_komi(), check_superko=root_board.check_superko)
163+
copy_board(dst=board, src=root_board)
164+
for (k, move) in enumerate(gtp_moves_along_path):
165+
pos = coord.convert_from_gtp_format(move)
166+
board.put_stone(pos, move_color)
167+
move_color = Stone.get_opponent_color(move_color)
168+
return board.get_board_string()

0 commit comments

Comments
 (0)