|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 | 4 | "cell_type": "code",
|
5 |
| - "execution_count": 1, |
| 5 | + "execution_count": 8, |
6 | 6 | "id": "b86297a7",
|
7 | 7 | "metadata": {},
|
8 |
| - "outputs": [ |
9 |
| - { |
10 |
| - "data": { |
11 |
| - "text/plain": [ |
12 |
| - "'1.24.2'" |
13 |
| - ] |
14 |
| - }, |
15 |
| - "execution_count": 1, |
16 |
| - "metadata": {}, |
17 |
| - "output_type": "execute_result" |
18 |
| - } |
19 |
| - ], |
| 8 | + "outputs": [], |
20 | 9 | "source": [
|
21 | 10 | "import numpy as np\n",
|
22 |
| - "np.__version__" |
| 11 | + "import math" |
23 | 12 | ]
|
24 | 13 | },
|
25 | 14 | {
|
26 | 15 | "cell_type": "code",
|
27 |
| - "execution_count": 2, |
| 16 | + "execution_count": 9, |
28 | 17 | "id": "e9f409ff",
|
29 | 18 | "metadata": {},
|
30 | 19 | "outputs": [],
|
|
48 | 37 | " return (state.reshape(-1) == 0).astype(np.uint8)\n",
|
49 | 38 | " \n",
|
50 | 39 | " def check_win(self, state, action):\n",
|
| 40 | + " if action == None:\n", |
| 41 | + " return False\n", |
| 42 | + " \n", |
51 | 43 | " row = action // self.column_count\n",
|
52 | 44 | " column = action % self.column_count\n",
|
53 | 45 | " player = state[row, column]\n",
|
|
75 | 67 | " return 0, True\n",
|
76 | 68 | " \n",
|
77 | 69 | " return 0, False\n",
|
| 70 | + " \n", |
| 71 | + " def get_opponent_value(self, value):\n", |
| 72 | + " return -value\n", |
| 73 | + " \n", |
| 74 | + " def change_perspective(self, state, player):\n", |
| 75 | + " return (state * player)\n", |
78 | 76 | " "
|
79 | 77 | ]
|
80 | 78 | },
|
81 | 79 | {
|
82 | 80 | "cell_type": "code",
|
83 |
| - "execution_count": 4, |
| 81 | + "execution_count": 10, |
| 82 | + "id": "c09a4301", |
| 83 | + "metadata": {}, |
| 84 | + "outputs": [], |
| 85 | + "source": [ |
class Node:
    """One node of the Monte Carlo search tree.

    A node holds a game state, a link to its parent, the action that led to
    it, and the running statistics (`visit_count`, `value_sum`) that are
    updated by `backpropagate`.

    Args:
        game: Game object; this class calls `get_valid_moves`,
            `get_next_state`, `change_perspective`,
            `get_value_and_terminated`, `get_opponent_value` and
            `get_opponent` on it.
        args: Mapping of search settings; `args['C']` is the exploration
            constant used in `get_ucb`.
        state: Game state stored on this node.
        parent: Parent `Node`, or None for the root.
        action_taken: Action that produced `state` from the parent, or None
            for the root.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        # Mask of moves that have not been turned into a child yet;
        # expand() zeroes the entry of every move it consumes.
        self.expandable_moves = game.get_valid_moves(state)

        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return True when no expandable move is left and a child exists."""
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (None if no children)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child):
        """Return the UCB score of `child` as seen from this node.

        An unvisited child scores `+inf` so that it is tried before any
        visited sibling, instead of raising ZeroDivisionError.
        """
        if child.visit_count == 0:
            return math.inf

        # backpropagate() flips the sign of the value at every level, so the
        # child's mean is inverted here. The (+1)/2 rescaling maps a mean in
        # [-1, 1] onto [0, 1] -- assumes game values lie in [-1, 1].
        q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        exploration = math.sqrt(math.log(self.visit_count) / child.visit_count)
        return q_value + self.args['C'] * exploration

    def expand(self):
        """Create, attach and return a child for one random unexpanded move."""
        action = np.random.choice(np.where(self.expandable_moves == 1)[0])
        self.expandable_moves[action] = 0

        # Play the move as player 1 on a copy, then flip the board so the
        # child state is stored from the perspective of the next player.
        child_state = self.state.copy()
        child_state = self.game.get_next_state(child_state, action, 1)
        child_state = self.game.change_perspective(child_state, player=-1)

        child = Node(self.game, self.args, child_state, self, action)
        self.children.append(child)

        return child

    def simulate(self):
        """Play random moves from this node's state until the game ends.

        Returns the final value; if this node's own state is already
        terminal, its (sign-flipped) value is returned without a rollout.
        """
        value, terminated = self.game.get_value_and_terminated(self.state, self.action_taken)
        value = self.game.get_opponent_value(value)

        if terminated:
            return value

        rollout_state = self.state.copy()
        rollout_player = 1
        while True:
            valid_moves = self.game.get_valid_moves(rollout_state)
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)

            value, terminated = self.game.get_value_and_terminated(rollout_state, action)
            if terminated:
                # The value belongs to the player who just moved; flip it when
                # that was player -1 so the result is expressed for player 1.
                if rollout_player == -1:
                    value = self.game.get_opponent_value(value)
                return value

            rollout_player = self.game.get_opponent(rollout_player)

    def backpropagate(self, value):
        """Add `value` to this node and the sign-flipped value to each ancestor."""
        self.value_sum += value
        self.visit_count += 1

        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)
| 161 | + " \n", |
| 162 | + " \n", |
class MCTS:
    """Monte Carlo tree search driver built on `Node`.

    Args:
        game: Game object passed on to every `Node`; `search` also reads
            `game.action_size` and calls `get_value_and_terminated` and
            `get_opponent_value` on it.
        args: Mapping of search settings; `args['num_searches']` is the
            number of iterations run by `search`, and the mapping is passed
            on to every `Node`.
    """

    def __init__(self, game, args):
        self.game = game
        self.args = args

    def search(self, state):
        """Run the search from `state` and return per-action visit shares.

        Returns:
            Array of length `game.action_size` holding each root child's
            visit count divided by the total. If the root ends up without
            any visited child (e.g. `num_searches` is 0 or the root has no
            legal move) the array is all zeros rather than NaN.
        """
        root = Node(self.game, self.args, state)

        for _ in range(self.args['num_searches']):
            node = root

            # Selection: walk down while every move of the node is expanded.
            while node.is_fully_expanded():
                node = node.select()

            value, terminated = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not terminated:
                node = node.expand()
                value = node.simulate()

            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count

        # Normalise only when there is something to normalise; dividing by a
        # zero total would fill the result with NaN.
        total_visits = np.sum(action_probs)
        if total_visits > 0:
            action_probs /= total_visits
        return action_probs
| 191 | + ] |
| 192 | + }, |
| 193 | + { |
| 194 | + "cell_type": "code", |
| 195 | + "execution_count": null, |
84 | 196 | "id": "e60e21f1",
|
85 | 197 | "metadata": {},
|
86 | 198 | "outputs": [
|
|
92 | 204 | " [0 0 0]\n",
|
93 | 205 | " [0 0 0]]\n",
|
94 | 206 | "valid_moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]\n",
|
95 |
| - "1: 1\n", |
96 |
| - "[[0 1 0]\n", |
| 207 | + "1: 0\n", |
| 208 | + "[[1 0 0]\n", |
97 | 209 | " [0 0 0]\n",
|
98 | 210 | " [0 0 0]]\n",
|
99 |
| - "valid_moves: [0, 2, 3, 4, 5, 6, 7, 8]\n", |
100 |
| - "-1: 3\n", |
101 |
| - "[[ 0 1 0]\n", |
102 |
| - " [-1 0 0]\n", |
| 211 | + "[[ 1 0 0]\n", |
| 212 | + " [ 0 -1 0]\n", |
103 | 213 | " [ 0 0 0]]\n",
|
104 |
| - "valid_moves: [0, 2, 4, 5, 6, 7, 8]\n", |
105 |
| - "1: 0\n", |
| 214 | + "valid_moves: [1, 2, 3, 5, 6, 7, 8]\n", |
| 215 | + "1: 1\n", |
106 | 216 | "[[ 1 1 0]\n",
|
107 |
| - " [-1 0 0]\n", |
108 |
| - " [ 0 0 0]]\n", |
109 |
| - "valid_moves: [2, 4, 5, 6, 7, 8]\n", |
110 |
| - "-1: 2\n", |
111 |
| - "[[ 1 1 -1]\n", |
112 |
| - " [-1 0 0]\n", |
| 217 | + " [ 0 -1 0]\n", |
113 | 218 | " [ 0 0 0]]\n",
|
114 |
| - "valid_moves: [4, 5, 6, 7, 8]\n", |
115 |
| - "1: 4\n", |
116 | 219 | "[[ 1 1 -1]\n",
|
117 |
| - " [-1 1 0]\n", |
| 220 | + " [ 0 -1 0]\n", |
118 | 221 | " [ 0 0 0]]\n",
|
119 |
| - "valid_moves: [5, 6, 7, 8]\n", |
120 |
| - "-1: 7\n", |
121 |
| - "[[ 1 1 -1]\n", |
122 |
| - " [-1 1 0]\n", |
123 |
| - " [ 0 -1 0]]\n", |
124 |
| - "valid_moves: [5, 6, 8]\n", |
| 222 | + "valid_moves: [3, 5, 6, 7, 8]\n", |
125 | 223 | "1: 6\n",
|
126 | 224 | "[[ 1 1 -1]\n",
|
127 |
| - " [-1 1 0]\n", |
128 |
| - " [ 1 -1 0]]\n", |
129 |
| - "valid_moves: [5, 8]\n", |
130 |
| - "-1: 8\n", |
| 225 | + " [ 0 -1 0]\n", |
| 226 | + " [ 1 0 0]]\n", |
131 | 227 | "[[ 1 1 -1]\n",
|
132 |
| - " [-1 1 0]\n", |
133 |
| - " [ 1 -1 -1]]\n", |
134 |
| - "valid_moves: [5]\n", |
135 |
| - "1: 5\n", |
136 |
| - "[[ 1 1 -1]\n", |
137 |
| - " [-1 1 1]\n", |
138 |
| - " [ 1 -1 -1]]\n", |
139 |
| - "Game drawn\n" |
| 228 | + " [-1 -1 0]\n", |
| 229 | + " [ 1 0 0]]\n", |
| 230 | + "valid_moves: [5, 7, 8]\n" |
140 | 231 | ]
|
141 | 232 | }
|
142 | 233 | ],
|
143 | 234 | "source": [
|
144 | 235 | "tictactoe = TicTacToe()\n",
|
145 | 236 | "player = 1\n",
|
| 237 | + "args = {\n", |
| 238 | + " 'C': 1.4142,\n", |
| 239 | + " 'num_searches': 1000\n", |
| 240 | + "}\n", |
| 241 | + "mcts = MCTS(tictactoe, args)\n", |
146 | 242 | "state = tictactoe.get_initial_state()\n",
|
147 | 243 | "\n",
|
148 | 244 | "while True:\n",
|
149 | 245 | " print(state)\n",
|
150 |
| - " valid_moves = tictactoe.get_valid_moves(state)\n", |
151 |
| - " print(\"valid_moves:\", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])\n", |
152 |
| - " action = int(input(f\"{player}: \"))\n", |
153 | 246 | " \n",
|
154 |
| - " if valid_moves[action] == 0:\n", |
155 |
| - " print(\"invalid action\")\n", |
156 |
| - " continue\n", |
| 247 | + " if player == 1:\n", |
| 248 | + " valid_moves = tictactoe.get_valid_moves(state)\n", |
| 249 | + " print(\"valid_moves:\", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])\n", |
| 250 | + " action = int(input(f\"{player}: \"))\n", |
| 251 | + "\n", |
| 252 | + " if valid_moves[action] == 0:\n", |
| 253 | + " print(\"invalid action\")\n", |
| 254 | + " continue\n", |
| 255 | + " else:\n", |
| 256 | + " neutral_state = tictactoe.change_perspective(state, player)\n", |
| 257 | + " mcts_probs = mcts.search(neutral_state)\n", |
| 258 | + " action = np.argmax(mcts_probs)\n", |
157 | 259 | " \n",
|
158 | 260 | " state = tictactoe.get_next_state(state, action, player)\n",
|
159 | 261 | " \n",
|
|
173 | 275 | {
|
174 | 276 | "cell_type": "code",
|
175 | 277 | "execution_count": null,
|
176 |
| - "id": "c09a4301", |
| 278 | + "id": "20fe08e6", |
177 | 279 | "metadata": {},
|
178 | 280 | "outputs": [],
|
179 |
| - "source": [ |
180 |
| - "class MCTS:\n", |
181 |
| - " def __init__(self, game, args):\n", |
182 |
| - " self.game = game\n", |
183 |
| - " self.args = args\n", |
184 |
| - " \n", |
185 |
| - " " |
186 |
| - ] |
| 281 | + "source": [] |
187 | 282 | }
|
188 | 283 | ],
|
189 | 284 | "metadata": {
|
|
0 commit comments