Skip to content

Commit 0d69b4b

Browse files
Implemented TicTacToe game
1 parent 23c64b9 commit 0d69b4b

File tree

1 file changed

+163
-4
lines changed

1 file changed

+163
-4
lines changed

alpha-zero.ipynb

+163-4
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,185 @@
55
"execution_count": 1,
66
"id": "b86297a7",
77
"metadata": {},
8+
"outputs": [
9+
{
10+
"data": {
11+
"text/plain": [
12+
"'1.24.2'"
13+
]
14+
},
15+
"execution_count": 1,
16+
"metadata": {},
17+
"output_type": "execute_result"
18+
}
19+
],
20+
"source": [
21+
"import numpy as np\n",
22+
"np.__version__"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 2,
28+
"id": "e9f409ff",
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"class TicTacToe:\n",
33+
" def __init__(self):\n",
34+
" self.row_count = 3\n",
35+
" self.column_count = 3\n",
36+
" self.action_size = self.row_count * self.column_count\n",
37+
" \n",
38+
" def get_initial_state(self):\n",
39+
" return np.zeros((self.row_count, self.column_count), dtype=int)\n",
40+
" \n",
41+
" def get_next_state(self, state, action, player):\n",
42+
" row = action // self.column_count\n",
43+
" column = action % self.column_count\n",
44+
" state[row, column] = player\n",
45+
" return state\n",
46+
" \n",
47+
" def get_valid_moves(self, state):\n",
48+
" return (state.reshape(-1) == 0).astype(np.uint8)\n",
49+
" \n",
50+
" def check_win(self, state, action):\n",
51+
" row = action // self.column_count\n",
52+
" column = action % self.column_count\n",
53+
" player = state[row, column]\n",
54+
" \n",
55+
" return (\n",
56+
" np.sum(state[row, :]) == player * self.column_count\n",
57+
" or np.sum(state[:, column]) == player * self.row_count\n",
58+
" or np.sum(np.diag(state)) == player * self.row_count # change to diagonal length\n",
59+
" or np.sum(np.diag(np.flip(state, axis = 0))) == player * self.row_count # change to diagonal length\n",
60+
" )\n",
61+
" \n",
62+
" def check_draw(self, state):\n",
63+
" if np.sum(self.get_valid_moves(state)) == 0:\n",
64+
" return True\n",
65+
" else:\n",
66+
" return False\n",
67+
" \n",
68+
" def get_opponent(self, player):\n",
69+
" return -player\n",
70+
" \n",
71+
" def get_value_and_terminated(self, state, action):\n",
72+
" if self.check_win(state, action):\n",
73+
" return 1, True\n",
74+
" if self.check_draw(state):\n",
75+
" return 0, True\n",
76+
" \n",
77+
" return 0, False\n",
78+
" "
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": 4,
84+
"id": "e60e21f1",
85+
"metadata": {},
886
"outputs": [
987
{
1088
"name": "stdout",
1189
"output_type": "stream",
1290
"text": [
13-
"Hello\n"
91+
"[[0 0 0]\n",
92+
" [0 0 0]\n",
93+
" [0 0 0]]\n",
94+
"valid_moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]\n",
95+
"1: 1\n",
96+
"[[0 1 0]\n",
97+
" [0 0 0]\n",
98+
" [0 0 0]]\n",
99+
"valid_moves: [0, 2, 3, 4, 5, 6, 7, 8]\n",
100+
"-1: 3\n",
101+
"[[ 0 1 0]\n",
102+
" [-1 0 0]\n",
103+
" [ 0 0 0]]\n",
104+
"valid_moves: [0, 2, 4, 5, 6, 7, 8]\n",
105+
"1: 0\n",
106+
"[[ 1 1 0]\n",
107+
" [-1 0 0]\n",
108+
" [ 0 0 0]]\n",
109+
"valid_moves: [2, 4, 5, 6, 7, 8]\n",
110+
"-1: 2\n",
111+
"[[ 1 1 -1]\n",
112+
" [-1 0 0]\n",
113+
" [ 0 0 0]]\n",
114+
"valid_moves: [4, 5, 6, 7, 8]\n",
115+
"1: 4\n",
116+
"[[ 1 1 -1]\n",
117+
" [-1 1 0]\n",
118+
" [ 0 0 0]]\n",
119+
"valid_moves: [5, 6, 7, 8]\n",
120+
"-1: 7\n",
121+
"[[ 1 1 -1]\n",
122+
" [-1 1 0]\n",
123+
" [ 0 -1 0]]\n",
124+
"valid_moves: [5, 6, 8]\n",
125+
"1: 6\n",
126+
"[[ 1 1 -1]\n",
127+
" [-1 1 0]\n",
128+
" [ 1 -1 0]]\n",
129+
"valid_moves: [5, 8]\n",
130+
"-1: 8\n",
131+
"[[ 1 1 -1]\n",
132+
" [-1 1 0]\n",
133+
" [ 1 -1 -1]]\n",
134+
"valid_moves: [5]\n",
135+
"1: 5\n",
136+
"[[ 1 1 -1]\n",
137+
" [-1 1 1]\n",
138+
" [ 1 -1 -1]]\n",
139+
"Game drawn\n"
14140
]
15141
}
16142
],
17143
"source": [
18-
"print(\"Hello\")"
144+
"tictactoe = TicTacToe()\n",
145+
"player = 1\n",
146+
"state = tictactoe.get_initial_state()\n",
147+
"\n",
148+
"while True:\n",
149+
" print(state)\n",
150+
" valid_moves = tictactoe.get_valid_moves(state)\n",
151+
" print(\"valid_moves:\", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])\n",
152+
" action = int(input(f\"{player}: \"))\n",
153+
" \n",
154+
" if valid_moves[action] == 0:\n",
155+
" print(\"invalid action\")\n",
156+
" continue\n",
157+
" \n",
158+
" state = tictactoe.get_next_state(state, action, player)\n",
159+
" \n",
160+
" value, terminated = tictactoe.get_value_and_terminated(state, action)\n",
161+
" \n",
162+
" if terminated:\n",
163+
" print(state)\n",
164+
" if value == 1:\n",
165+
" print(\"Player \", player, \" won\")\n",
166+
" else:\n",
167+
" print(\"Game drawn\")\n",
168+
" break\n",
169+
" \n",
170+
" player = tictactoe.get_opponent(player)"
19171
]
20172
},
21173
{
22174
"cell_type": "code",
23175
"execution_count": null,
24-
"id": "e9f409ff",
176+
"id": "c09a4301",
25177
"metadata": {},
26178
"outputs": [],
27-
"source": []
179+
"source": [
180+
"class MCTS:\n",
181+
" def __init__(self, game, args):\n",
182+
" self.game = game\n",
183+
" self.args = args\n",
184+
" \n",
185+
" "
186+
]
28187
}
29188
],
30189
"metadata": {

0 commit comments

Comments
 (0)