Skip to content

Commit 6c3b5ef

Browse files
Implemented Monte Carlo Tree Search
1 parent 0d69b4b commit 6c3b5ef

File tree

1 file changed

+161
-66
lines changed

1 file changed

+161
-66
lines changed

alpha-zero.ipynb

+161-66
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,18 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": 8,
66
"id": "b86297a7",
77
"metadata": {},
8-
"outputs": [
9-
{
10-
"data": {
11-
"text/plain": [
12-
"'1.24.2'"
13-
]
14-
},
15-
"execution_count": 1,
16-
"metadata": {},
17-
"output_type": "execute_result"
18-
}
19-
],
8+
"outputs": [],
209
"source": [
2110
"import numpy as np\n",
22-
"np.__version__"
11+
"import math"
2312
]
2413
},
2514
{
2615
"cell_type": "code",
27-
"execution_count": 2,
16+
"execution_count": 9,
2817
"id": "e9f409ff",
2918
"metadata": {},
3019
"outputs": [],
@@ -48,6 +37,9 @@
4837
" return (state.reshape(-1) == 0).astype(np.uint8)\n",
4938
" \n",
5039
" def check_win(self, state, action):\n",
40+
" if action == None:\n",
41+
" return False\n",
42+
" \n",
5143
" row = action // self.column_count\n",
5244
" column = action % self.column_count\n",
5345
" player = state[row, column]\n",
@@ -75,12 +67,132 @@
7567
" return 0, True\n",
7668
" \n",
7769
" return 0, False\n",
70+
" \n",
71+
" def get_opponent_value(self, value):\n",
72+
" return -value\n",
73+
" \n",
74+
" def change_perspective(self, state, player):\n",
75+
" return (state * player)\n",
7876
" "
7977
]
8078
},
8179
{
8280
"cell_type": "code",
83-
"execution_count": 4,
81+
"execution_count": 10,
82+
"id": "c09a4301",
83+
"metadata": {},
84+
"outputs": [],
85+
"source": [
86+
"class Node:\n",
87+
" def __init__(self, game, args, state, parent=None, action_taken=None):\n",
88+
" self.game = game\n",
89+
" self.args = args\n",
90+
" self.state = state\n",
91+
" self.parent = parent\n",
92+
" self.action_taken = action_taken\n",
93+
" \n",
94+
" self.children = []\n",
95+
" self.expandable_moves = game.get_valid_moves(state)\n",
96+
" \n",
97+
" self.visit_count = 0\n",
98+
" self.value_sum = 0\n",
99+
" \n",
100+
" def is_fully_expanded(self):\n",
101+
" return np.sum(self.expandable_moves) == 0 and len(self.children) > 0\n",
102+
" \n",
103+
" def select(self):\n",
104+
" best_child = None\n",
105+
" best_ucb = -np.inf\n",
106+
" \n",
107+
" for child in self.children:\n",
108+
" ucb = self.get_ucb(child)\n",
109+
" if ucb > best_ucb:\n",
110+
" best_child = child\n",
111+
" best_ucb = ucb\n",
112+
" \n",
113+
" return best_child\n",
114+
" \n",
115+
" def get_ucb(self, child):\n",
116+
" q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2\n",
117+
" return q_value + self.args['C'] * math.sqrt(math.log(self.visit_count) / child.visit_count)\n",
118+
"\n",
119+
" def expand(self):\n",
120+
" action = np.random.choice(np.where(self.expandable_moves == 1)[0])\n",
121+
" self.expandable_moves[action] = 0\n",
122+
" \n",
123+
" child_state = self.state.copy()\n",
124+
" child_state = self.game.get_next_state(child_state, action, 1)\n",
125+
" child_state = self.game.change_perspective(child_state, player = -1)\n",
126+
" \n",
127+
" child = Node(self.game, self.args, child_state, self, action)\n",
128+
" self.children.append(child)\n",
129+
" \n",
130+
" return child\n",
131+
" \n",
132+
" def simulate(self):\n",
133+
" value, terminated = self.game.get_value_and_terminated(self.state, self.action_taken)\n",
134+
" value = self.game.get_opponent_value(value)\n",
135+
" \n",
136+
" if terminated:\n",
137+
" return value\n",
138+
" \n",
139+
" rollout_state = self.state.copy()\n",
140+
" rollout_player = 1\n",
141+
" while True:\n",
142+
" valid_moves = self.game.get_valid_moves(rollout_state)\n",
143+
" action = np.random.choice(np.where(valid_moves == 1)[0])\n",
144+
" rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)\n",
145+
" \n",
146+
" value, terminated = self.game.get_value_and_terminated(rollout_state, action)\n",
147+
" if terminated:\n",
148+
" if rollout_player == -1:\n",
149+
" value = self.game.get_opponent_value(value)\n",
150+
" return value\n",
151+
" \n",
152+
" rollout_player = self.game.get_opponent(rollout_player)\n",
153+
" \n",
154+
" def backpropagate(self, value):\n",
155+
" self.value_sum += value\n",
156+
" self.visit_count += 1\n",
157+
" \n",
158+
" value = self.game.get_opponent_value(value)\n",
159+
" if self.parent is not None:\n",
160+
" self.parent.backpropagate(value)\n",
161+
" \n",
162+
" \n",
163+
"class MCTS:\n",
164+
" def __init__(self, game, args):\n",
165+
" self.game = game\n",
166+
" self.args = args\n",
167+
" \n",
168+
" def search(self, state):\n",
169+
" root = Node(self.game, self.args, state)\n",
170+
" \n",
171+
" for search in range(self.args['num_searches']):\n",
172+
" node = root\n",
173+
" \n",
174+
" while node.is_fully_expanded():\n",
175+
" node = node.select()\n",
176+
" \n",
177+
" value, terminated = self.game.get_value_and_terminated(node.state, node.action_taken)\n",
178+
" value = self.game.get_opponent_value(value)\n",
179+
" \n",
180+
" if not terminated:\n",
181+
" node = node.expand()\n",
182+
" value = node.simulate()\n",
183+
" \n",
184+
" node.backpropagate(value)\n",
185+
"\n",
186+
" action_probs = np.zeros(self.game.action_size)\n",
187+
" for child in root.children:\n",
188+
" action_probs[child.action_taken] = child.visit_count\n",
189+
" action_probs /= np.sum(action_probs)\n",
190+
" return action_probs"
191+
]
192+
},
193+
{
194+
"cell_type": "code",
195+
"execution_count": null,
84196
"id": "e60e21f1",
85197
"metadata": {},
86198
"outputs": [
@@ -92,68 +204,58 @@
92204
" [0 0 0]\n",
93205
" [0 0 0]]\n",
94206
"valid_moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]\n",
95-
"1: 1\n",
96-
"[[0 1 0]\n",
207+
"1: 0\n",
208+
"[[1 0 0]\n",
97209
" [0 0 0]\n",
98210
" [0 0 0]]\n",
99-
"valid_moves: [0, 2, 3, 4, 5, 6, 7, 8]\n",
100-
"-1: 3\n",
101-
"[[ 0 1 0]\n",
102-
" [-1 0 0]\n",
211+
"[[ 1 0 0]\n",
212+
" [ 0 -1 0]\n",
103213
" [ 0 0 0]]\n",
104-
"valid_moves: [0, 2, 4, 5, 6, 7, 8]\n",
105-
"1: 0\n",
214+
"valid_moves: [1, 2, 3, 5, 6, 7, 8]\n",
215+
"1: 1\n",
106216
"[[ 1 1 0]\n",
107-
" [-1 0 0]\n",
108-
" [ 0 0 0]]\n",
109-
"valid_moves: [2, 4, 5, 6, 7, 8]\n",
110-
"-1: 2\n",
111-
"[[ 1 1 -1]\n",
112-
" [-1 0 0]\n",
217+
" [ 0 -1 0]\n",
113218
" [ 0 0 0]]\n",
114-
"valid_moves: [4, 5, 6, 7, 8]\n",
115-
"1: 4\n",
116219
"[[ 1 1 -1]\n",
117-
" [-1 1 0]\n",
220+
" [ 0 -1 0]\n",
118221
" [ 0 0 0]]\n",
119-
"valid_moves: [5, 6, 7, 8]\n",
120-
"-1: 7\n",
121-
"[[ 1 1 -1]\n",
122-
" [-1 1 0]\n",
123-
" [ 0 -1 0]]\n",
124-
"valid_moves: [5, 6, 8]\n",
222+
"valid_moves: [3, 5, 6, 7, 8]\n",
125223
"1: 6\n",
126224
"[[ 1 1 -1]\n",
127-
" [-1 1 0]\n",
128-
" [ 1 -1 0]]\n",
129-
"valid_moves: [5, 8]\n",
130-
"-1: 8\n",
225+
" [ 0 -1 0]\n",
226+
" [ 1 0 0]]\n",
131227
"[[ 1 1 -1]\n",
132-
" [-1 1 0]\n",
133-
" [ 1 -1 -1]]\n",
134-
"valid_moves: [5]\n",
135-
"1: 5\n",
136-
"[[ 1 1 -1]\n",
137-
" [-1 1 1]\n",
138-
" [ 1 -1 -1]]\n",
139-
"Game drawn\n"
228+
" [-1 -1 0]\n",
229+
" [ 1 0 0]]\n",
230+
"valid_moves: [5, 7, 8]\n"
140231
]
141232
}
142233
],
143234
"source": [
144235
"tictactoe = TicTacToe()\n",
145236
"player = 1\n",
237+
"args = {\n",
238+
" 'C': 1.4142,\n",
239+
" 'num_searches': 1000\n",
240+
"}\n",
241+
"mcts = MCTS(tictactoe, args)\n",
146242
"state = tictactoe.get_initial_state()\n",
147243
"\n",
148244
"while True:\n",
149245
" print(state)\n",
150-
" valid_moves = tictactoe.get_valid_moves(state)\n",
151-
" print(\"valid_moves:\", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])\n",
152-
" action = int(input(f\"{player}: \"))\n",
153246
" \n",
154-
" if valid_moves[action] == 0:\n",
155-
" print(\"invalid action\")\n",
156-
" continue\n",
247+
" if player == 1:\n",
248+
" valid_moves = tictactoe.get_valid_moves(state)\n",
249+
" print(\"valid_moves:\", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])\n",
250+
" action = int(input(f\"{player}: \"))\n",
251+
"\n",
252+
" if valid_moves[action] == 0:\n",
253+
" print(\"invalid action\")\n",
254+
" continue\n",
255+
" else:\n",
256+
" neutral_state = tictactoe.change_perspective(state, player)\n",
257+
" mcts_probs = mcts.search(neutral_state)\n",
258+
" action = np.argmax(mcts_probs)\n",
157259
" \n",
158260
" state = tictactoe.get_next_state(state, action, player)\n",
159261
" \n",
@@ -173,17 +275,10 @@
173275
{
174276
"cell_type": "code",
175277
"execution_count": null,
176-
"id": "c09a4301",
278+
"id": "20fe08e6",
177279
"metadata": {},
178280
"outputs": [],
179-
"source": [
180-
"class MCTS:\n",
181-
" def __init__(self, game, args):\n",
182-
" self.game = game\n",
183-
" self.args = args\n",
184-
" \n",
185-
" "
186-
]
281+
"source": []
187282
}
188283
],
189284
"metadata": {

0 commit comments

Comments
 (0)