Skip to content

Commit

Permalink
Merge pull request #11 from kobanium/develop
Browse files Browse the repository at this point in the history
Support MCTS player
  • Loading branch information
kobanium authored Mar 2, 2023
2 parents 0bcdf30 + 05f85b2 commit 3757781
Show file tree
Hide file tree
Showing 15 changed files with 628 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ignore=CVS

# Add files or directories matching the regex patterns to the ignore-list. The
# regex matches against paths and can be in Posix or Windows format.
ignore-paths=LICENSE,README.md,model,
ignore-paths=LICENSE,README.md,model,requirements.txt

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths. The default value ignores emacs file
Expand Down
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"python.linting.enabled": true,
"python.linting.pylintEnabled": true,
"python.linting.pylintArgs": [
"--extension-pkg-whitelist=numpy,torch"
"--extension-pkg-whitelist=numpy,torch",
"--generated-members=numpy.*,torch.*"
]
}
27 changes: 18 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TamaGo
TamaGoはPythonで実装された囲碁の思考エンジンです。
人間の棋譜を利用した教師あり学習とGumbel AlphaZero方式の強化学習をお試しできるプログラムとなる予定です。
現在はランダムな着手を返すプログラムとなっています
Gumbel AlphaZero方式の強化学習をお試しできるプログラムとなる予定です。
現在はSGF形式の棋譜ファイルからの教師あり学習を実行でき、モンテカルロ木探索による着手生成ができます
Python 3.6で動作確認をしています。

* [使用する前提パッケージ](#requirements)
Expand Down Expand Up @@ -37,6 +37,7 @@ python main.py
| --superko | 超劫ルールの有効化 | true または false | true | false | Positional super koのみ対応しています。|
| --model | ネットワークモデルファイルパス | 学習済みモデルファイルパス | model/model.bin | なし | TamaGoのホームディレクトリからの相対パスで指定してください。指定がない場合はニューラルネットワークを使用せずにランダムに着手します。 |
| --use-gpu | GPU使用フラグ | true または false | true | false | |
| --policy-move | Policyの分布に従って着手するフラグ | true または false | true | false | Policyのみの強さを確認するときに使用します。 |

## プログラムの実行例は下記のとおりです
1) 碁盤のサイズを5、model/model.binを学習済みモデルとして使用し、GPUを使用せずに実行するケース
Expand All @@ -47,10 +48,15 @@ python main.py --size 5 --model model/model.bin --use-gpu false
```
python main.py --superko true
```
3) model/model.binを学習済みモデルとして使用し、Policyの分布に従って着手を生成するケース
```
python main.py --model model/model.bin --policy-move true
```

## 学習済みモデルファイルについて
学習済みのモデルファイルについては[こちら](https://github.com/kobanium/TamaGo/releases)から取得してください。modelフォルダ以下にmodel.binファイルを配置するとコマンドラインオプションの指定無しで動かせます。ニューラルネットワークの構造と学習済みモデルファイルが一致しないとロードできないので、取得したモデルファイルのリリースバージョンとTamaGoのバージョンが一致しているかに注意してください。
Version 0.2.1時点のモデルはGNUGo Level 10に対して約90eloほど強いです。
Version 0.3.0時点のモデルはGNUGo Level 10に対して約+90elo(勝率63.5%)程度の強さです。
モンテカルロ木探索で1手あたり100回探索すると、GNUGo Level 10に対して約+160elo(勝率71.8%)程度の強さです。

# How to execute supervised learning
教師あり学習の実行方法については[こちら](doc/ja/supervised_learning.md)をご参照ください。
Expand All @@ -77,15 +83,18 @@ Policyの値による色付けはPolicyの値が大きいほど赤く、小さ
- [x] Zobrist Hash
- [x] Super Koの判定処理
- 探索部の実装
- [ ] 木とノードのデータ構造
- [x] 木とノードのデータ構造
- [ ] モンテカルロ木探索
- [ ] クラシックなMCTS
- [ ] UCT
- [ ] RAVE
- [ ] ランダムシミュレーション
- [ ] PUCT探索
- ~~クラシックなMCTS~~
- ~~UCT~~
- ~~RAVE~~
- ~~ランダムシミュレーション~~
- [x] PUCT探索
- [x] PUCB値の計算
- [x] ニューラルネットワークのミニバッチ処理
- [ ] Sequential Halving applied to tree探索
- [ ] CGOS対応
- [ ] 持ち時間による探索時間制御
- 学習の実装
- [x] SGFファイルの読み込み処理
- [ ] 学習データ生成
Expand Down
26 changes: 22 additions & 4 deletions board/go_board.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import numpy as np
from board.constant import PASS, OB_SIZE, GTP_X_COORDINATE
from board.coordinate import Coordinate
from board.pattern import Pattern
from board.record import Record
from board.pattern import Pattern, copy_pattern
from board.record import Record, copy_record
from board.stone import Stone
from board.string import StringData
from board.string import StringData, copy_strings
from board.zobrist_hash import affect_stone_hash, affect_string_hash
from common.print_console import print_err

Expand Down Expand Up @@ -256,7 +256,7 @@ def get_all_legal_pos(self, color: Stone) -> List[int]:
Returns:
list[int]: 合法手の座標列。
"""
return [pos for pos in self.onboard_pos if self.is_legal_not_eye(pos, color)]
return [pos for pos in self.onboard_pos if self.is_legal(pos, color)]

def display(self, sym: int=0) -> NoReturn:
"""盤面を表示する。
Expand Down Expand Up @@ -312,3 +312,21 @@ def get_symmetrical_coordinate(self, pos: int, sym: int) -> int:
int: 指定した対称の座標。
"""
return self.sym_map[sym][pos]


def copy_board(dst: GoBoard, src: GoBoard):
"""盤面の情報をコピーする。
Args:
dst (GoBoard): コピー先の盤面情報のデータ。
src (GoBoard): コピー元の盤面情報のデータ。
"""
dst.board = src.board[:]
copy_pattern(dst.pattern, src.pattern)
copy_strings(dst.strings, src.strings)
copy_record(dst.record, src.record)
dst.ko_move = src.ko_move
dst.ko_pos = src.ko_pos
dst.prisoner = src.prisoner[:]
dst.positional_hash = src.positional_hash.copy()
dst.moves = src.moves
10 changes: 10 additions & 0 deletions board/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,13 @@ def get_pat3_symmetry8(pat3: int) -> List[int]:
symmetries[7] = pat3_rotate_90(symmetries[3])

return symmetries


def copy_pattern(dst: Pattern, src: Pattern) -> NoReturn:
"""配石パターンのデータをコピーする。
Args:
dst (Pattern): コピー先の配石パターンのデータ。
src (Pattern): コピー元の配石パターンのデータ。
"""
dst.pat3 = src.pat3.copy()
12 changes: 12 additions & 0 deletions board/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,15 @@ def get(self, moves: int) -> Tuple[Stone, int, np.array]:
(Stone, int, np.array): 着手の色、座標、ハッシュ値。
"""
return (self.color[moves], self.pos[moves], self.hash_value[moves])


def copy_record(dst: Record, src: Record) -> NoReturn:
"""着手履歴をコピーする。
Args:
dst (Record): コピー先の着手履歴データ。
src (Record): コピー元の着手履歴データ。
"""
dst.color = src.color[:]
dst.pos = src.pos[:]
dst.hash_value = src.hash_value.copy()
38 changes: 36 additions & 2 deletions board/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ def __init__(self, board_size: int, pos_func, get_neighbor4):
def clear(self) -> NoReturn:
"""全ての連を削除する。
"""
self.string_id = [0 for _ in self.string_id]
self.string_next = [0 for _ in self.string_next]
self.string_id = [0] * len(self.string_id)
self.string_next = [0] * len(self.string_next)
for string in self.string:
string.remove()

Expand Down Expand Up @@ -594,3 +594,37 @@ def display(self) -> NoReturn:
for nei in neighbors:
neighbor += " " + str(nei)
print_err(f"\tNeighbor {len(neighbors)} : {neighbors}")


def copy_string(dst: String, src: String) -> NoReturn:
"""連の情報をコピーする。
Args:
dst (String): コピー先の連のデータ。
src (String): コピー元の連のデータ。
"""
dst.color = src.color
dst.libs = src.libs
dst.lib = src.lib[:]
dst.neighbors = src.neighbors
dst.neighbor = src.neighbor[:]
dst.origin = src.origin
dst.size = src.size
dst.flag = src.flag


def copy_strings(dst: StringData, src: StringData) -> NoReturn:
"""全ての連の情報をコピーする。ただし、存在しない場合は存在フラグをオフにするだけにする。
Args:
dst (StringData): コピー先の連データ。
src (StringData): コピー元の連データ。
"""
dst.string_id = src.string_id[:]
dst.string_next = src.string_next[:]

for i, string in enumerate(src.string):
if string.exist():
copy_string(dst.string[i], string)
else:
dst.string[i].flag = False
28 changes: 18 additions & 10 deletions gtp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from common.print_console import print_err
from gtp.gogui import GoguiAnalyzeCommand, display_policy_distribution, \
display_policy_score
from mcts.tree import MCTSTree
from nn.policy_player import generate_move_from_policy
from nn.network.dual_net import DualNet
from nn.utility import get_torch_device
Expand All @@ -26,14 +27,15 @@ class GtpClient: # pylint: disable=R0903
"""_Go Text Protocolクライアントの実装クラス
"""
def __init__(self, board_size: int, superko: bool, \
model_file_path: str, use_gpu: bool) -> NoReturn:
model_file_path: str, use_gpu: bool, policy_move) -> NoReturn:
"""Go Text Protocolクライアントの初期化をする。
Args:
board_size (int): 碁盤の大きさ。
superko (bool): 超劫判定の有効化。
model_file_path (str): ネットワークパラメータファイルパス。
use_gpu (bool): GPU使用フラグ。
policy_move (bool): Policyの分布に従って着手するフラグ。
"""
self.gtp_commands = [
"version",
Expand Down Expand Up @@ -67,7 +69,7 @@ def __init__(self, board_size: int, superko: bool, \
GoguiAnalyzeCommand("sboard", "Display policy score (White)", \
"display_policy_white"),
]

self.policy_move = policy_move
self.use_network = False

try:
Expand All @@ -78,12 +80,14 @@ def __init__(self, board_size: int, superko: bool, \
torch.set_grad_enabled(False)
self.network.eval()
self.use_network = True
self.mcts = MCTSTree(network=self.network)
except FileNotFoundError:
print_err(f"Model file {model_file_path} is not found")
except RuntimeError:
print_err(f"Failed to load {model_file_path}")



def _respond_success(self, response: str) -> NoReturn:
"""コマンド処理成功時の応答メッセージを表示する。
Expand Down Expand Up @@ -199,16 +203,20 @@ def _genmove(self, color: str) -> NoReturn:
return

if self.use_network:
# Policy Networkから着手生成
pos = generate_move_from_policy(self.network, self.board, genmove_color)
_, previous_move, _ = self.board.record.get(self.board.moves - 1)
if self.board.moves > 1 and previous_move == PASS:
pos = PASS
if self.policy_move:
# Policy Networkから着手生成
pos = generate_move_from_policy(self.network, self.board, genmove_color)
_, previous_move, _ = self.board.record.get(self.board.moves - 1)
if self.board.moves > 1 and previous_move == PASS:
pos = PASS
else:
# モンテカルロ木探索で着手生成
pos = self.mcts.search_best_move(self.board, genmove_color)
else:
# ランダムに着手生成
legal_pos = self.board.get_all_legal_pos(genmove_color)

if len(legal_pos) > 0:
legal_pos = [pos for pos in self.board.onboard_pos \
if self.board.is_legal_not_eye(pos, genmove_color)]
if legal_pos:
pos = random.choice(legal_pos)
else:
pos = PASS
Expand Down
22 changes: 16 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,31 @@
from gtp.client import GtpClient
from board.constant import BOARD_SIZE

default_model_path = os.path.join("model", "model.bin")

@click.command()
@click.option('--size', type=click.IntRange(2, BOARD_SIZE), default=BOARD_SIZE, help="")
@click.option('--superko', type=click.BOOL, default=False, help="")
@click.option('--model', type=click.STRING, default=os.path.join("model", "model.bin"), help="")
@click.option('--use-gpu', type=click.BOOL, default=False, help="")
def gtp_main(size: int, superko: bool, model:str, use_gpu: bool):
@click.option('--size', type=click.IntRange(2, BOARD_SIZE), default=BOARD_SIZE, \
help=f"碁盤のサイズを指定。デフォルトは{BOARD_SIZE}。")
@click.option('--superko', type=click.BOOL, default=False, help="超劫の有効化フラグ。デフォルトはFalse。")
@click.option('--model', type=click.STRING, default=default_model_path, \
help=f"使用するニューラルネットワークのモデルパスを指定する。プログラムのホームディレクトリの相対パスで指定。\
デフォルトは{default_model_path}。")
@click.option('--use-gpu', type=click.BOOL, default=False, \
help="ニューラルネットワークの計算にGPUを使用するフラグ。デフォルトはFalse。")
@click.option('--policy-move', type=click.BOOL, default=False, \
help="Policyの分布に従った着手生成処理フラグ。デフォルトはFalse。")
def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, policy_move: bool):
"""GTPクライアントの起動。
Args:
size (int): 碁盤の大きさ。
superko (bool): 超劫の有効化フラグ。
model (str): プログラムのホームディレクトリからのモデルファイルの相対パス。
use_gpu (bool): ニューラルネットワークでのGPU使用フラグ。
policy_move (bool): Policyの分布に従った着手生成処理フラグ。デフォルトはFalse。
"""
program_dir = os.path.dirname(__file__)
client = GtpClient(size, superko, os.path.join(program_dir, model), use_gpu)
client = GtpClient(size, superko, os.path.join(program_dir, model), use_gpu, policy_move)
client.run()


Expand Down
34 changes: 34 additions & 0 deletions mcts/batch_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""ニューラルネットワーク計算用のキュー。
"""
from typing import List, Tuple
import numpy as np


class BatchQueue:
"""ミニバッチデータを保持するキュー。
"""
def __init__(self):
"""BatchQueueクラスのコンストラクタ。
"""
self.input_plane = []
self.path = []
self.node_index = []

def push(self, input_plane: np.array, path: List[Tuple[int, int]], node_index: int):
"""キューにデータをプッシュする。
Args:
input_plane (np.array): ニューラルネットワークへの入力データ。
path (List[Tuple[int, int]]): ルートから評価ノードへまでの経路。
node_index (int): ニューラルネットワークが評価する局面に対応するノードのインデックス。
"""
self.input_plane.append(input_plane)
self.path.append(path)
self.node_index.append(node_index)

def clear(self):
"""キューのデータを全て削除する。
"""
self.input_plane = []
self.path = []
self.node_index = []
14 changes: 14 additions & 0 deletions mcts/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""探索用のパラメータ設定
"""

# 未展開の子ノードのインデックス
NOT_EXPANDED = -1

# PUCBの第2項の重みパラメータ
PUCB_SECOND_TERM_WEIGHT = 1.0

# 1手ごとの探索回数
PLAYOUTS = 100

# 探索時のミニバッチサイズ
NN_BATCH_SIZE = 1
Loading

0 comments on commit 3757781

Please sign in to comment.