train.py
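"""Training script for an AlphaZero-style chess agent.

Couples MCTS self-play data generation with an A2C-style update: an
advantage-weighted policy loss, an entropy bonus, and a value MSE loss.
"""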
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from concurrent.futures import ProcessPoolExecutor
from torch.utils.data import DataLoader, TensorDataset
from self_play import SelfPlay
from game import ChessGame
from model import AlphaZeroNet
from utils import flatten


class AlphaZeroTrainer:
"""
    A trainer for an AlphaZero-like approach that incorporates A2C elements such as an advantage-weighted policy loss and an entropy bonus.
Attributes:
board_size (int): Board dimension (for chess, usually 8).
num_ware (int): Number of channels in board representation (e.g., 14 for chess).
action_size (int): Number of possible actions.
num_simulations (int): Number of MCTS rollouts per move.
temperature (float): Exploration parameter for MCTS.
gamma (float): Discount factor for reward accumulation.
lr (float): Learning rate.
weight_decay (float): Weight decay for regularization.
batch_size (int): Mini-batch size for training.
entropy_coef (float): Weight for entropy bonus.
        value_loss_coef (float): Weight for value loss in total loss.
        k_epochs (int): Number of optimization passes over each batch of self-play data.
model (AlphaZeroNet): Neural network model for policy/value.
optimizer (torch.optim.Optimizer): Optimizer for parameter updates.
criterion_value (nn.Module): Value loss function (MSE by default).
"""
def __init__(
self,
board_size: int,
num_ware: int,
action_size: int,
num_simulations: int,
num_res_blocks: int,
in_channels: int,
mid_channels: int,
temperature: float = 1.0,
gamma: float = 0.99,
lr: float = 0.001,
weight_decay: float = 1e-4,
batch_size: int = 8,
entropy_coef: float = 0.01,
value_loss_coef: float = 1.0,
k_epochs: int = 8,
):
"""
Initializes the AlphaZero-like trainer.
Args:
board_size (int): Chess board size, typically 8.
num_ware (int): Number of input channels (planes).
action_size (int): Number of possible moves (e.g., 4032).
num_simulations (int): Number of MCTS rollouts.
num_res_blocks (int): Number of residual blocks in the network.
in_channels (int): Initial in-channels for conv layers.
mid_channels (int): Mid-channels used inside each residual block.
temperature (float): MCTS temperature for exploration.
gamma (float): Discount factor.
lr (float): Learning rate.
weight_decay (float): Weight decay for optimizer.
batch_size (int): Training batch size.
entropy_coef (float): Weight for entropy bonus.
            value_loss_coef (float): Weight for value loss in total loss.
            k_epochs (int): Number of optimization passes over each batch of self-play data.
"""
self.board_size = board_size
self.num_ware = num_ware
self.action_size = action_size
self.num_simulations = num_simulations
self.temperature = temperature
self.gamma = gamma
self.batch_size = batch_size
self.entropy_coef = entropy_coef
self.value_loss_coef = value_loss_coef
self.k_epochs = k_epochs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AlphaZeroNet(
            board_size, num_ware, action_size, num_res_blocks, in_channels, mid_channels
        ).to(self.device)
self.optimizer = optim.Adam(
self.model.parameters(), lr=lr, weight_decay=weight_decay
)
        self.criterion_value = nn.MSELoss()

def train(self, epochs: int, num_games_per_epoch: int, max_workers=None) -> None:
"""
Orchestrates multiple epochs of self-play data generation and model training.
Args:
epochs (int): Number of epochs to run.
num_games_per_epoch (int): Number of self-play games per epoch.
max_workers (int|None): Maximum number of parallel workers for self-play.
"""
for epoch in range(epochs):
print(f"===== Epoch {epoch + 1}/{epochs} =====")
training_data = self.generate_self_play_data_parallel(
num_games_per_epoch, max_workers
)
loss = self.update_model(training_data)
            print(f"Loss after epoch {epoch + 1}: {loss:.6f}")

def generate_self_play_data_parallel(self, num_games: int, max_workers=None):
"""
Runs multiple self-play games in parallel to gather training data.
Args:
num_games (int): Number of self-play games to generate.
max_workers (int|None): Max parallel processes.
Returns:
A list of (state, policy, reward) tuples.
"""
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(self._generate_single_game_data)
for _ in range(num_games)
]
results = [f.result() for f in futures]
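            # Each worker returns one game's (state, policy, reward) tuples;
            # merge them into a single flat list of training examples.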
data = [item for result in results for item in result]
        return data

def _generate_single_game_data(self):
"""
Runs one self-play game using MCTS and returns the collected (state, policy, reward) tuples.
"""
game = ChessGame()
self_play = SelfPlay(self.model, game, self.num_simulations, self.temperature)
states, policies, rewards = self_play.play()
        return list(zip(states, policies, rewards))

def update_model(self, training_data) -> float:
"""
        Runs one update cycle on the model using an A2C-like loss: advantage-weighted policy loss, value loss, and an entropy bonus.
Args:
training_data: A list of (state, policy, reward) tuples.
Returns:
            float: Training loss averaged over all batches and k_epochs passes (for logging purposes).
"""
self.model.train()
states, policies, rewards = zip(*training_data)
        # Compute discounted returns per game, walking each reward list
        # backwards: R_t = r_t + gamma * R_{t+1}
        accumulative_rewards = [[] for _ in range(len(rewards))]
        for i, r_l in enumerate(rewards):
            R = 0.0
            for reward in reversed(r_l):
                R = reward + self.gamma * R
                accumulative_rewards[i].insert(0, R)
# Flatten
states = flatten(states)
policies = flatten(policies)
accumulative_rewards = flatten(accumulative_rewards)
# Convert to tensors
        device = self.device
states_tensor = torch.tensor(np.array(states), dtype=torch.float32).to(device)
policies_tensor = torch.tensor(np.array(policies), dtype=torch.float32).to(
device
)
accumulative_rewards_tensor = (
torch.tensor(np.array(accumulative_rewards), dtype=torch.float32)
.to(device)
.unsqueeze(1)
)
dataset = TensorDataset(
states_tensor, policies_tensor, accumulative_rewards_tensor
)
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
loss_ = 0
for _ in range(self.k_epochs):
epoch_loss = 0
for batch_states, batch_policies, batch_returns in dataloader:
pred_policies, pred_values = self.model(batch_states)
                # Squeeze only the trailing value dimension so a batch of
                # size 1 keeps its batch dimension
                pred_values = pred_values.squeeze(-1)
                batch_returns = batch_returns.squeeze(-1)
# Advantage = returns - predicted value
advantage = batch_returns - pred_values
# Policy loss with advantage
log_probs = torch.log(pred_policies + 1e-10)
# Summation across action dimension
policy_loss_term = (batch_policies * log_probs).sum(dim=1)
policy_loss = -torch.mean(policy_loss_term * advantage.detach())
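                # Note: the advantage is detached in the policy loss above so the
                # policy-gradient term does not backpropagate into the value head;
                # the value head is trained only through value_loss below.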
# Entropy bonus
entropy = -torch.sum(pred_policies * log_probs, dim=1).mean()
entropy_bonus = self.entropy_coef * entropy
# Value loss (MSE)
value_loss = self.criterion_value(pred_values, batch_returns)
# Combine into final loss
# A2C typical: policy_loss + c1*value_loss - c2*entropy
loss = policy_loss + (self.value_loss_coef * value_loss) - entropy_bonus
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
epoch_loss += loss.item()
epoch_loss /= len(dataloader)
loss_ += epoch_loss
        return loss_ / self.k_epochs

def save_model(self, filepath: str) -> None:
"""
Saves the current model's state dictionary.
Args:
filepath (str): File path to save to.
"""
torch.save(self.model.state_dict(), filepath)
        print(f"Model saved to {filepath}.")

def load_model(self, filepath: str) -> None:
"""
Loads a model state dictionary from the specified file path.
Args:
filepath (str): File path from which to load.
"""
        self.model.load_state_dict(torch.load(filepath, map_location=self.device))
self.model.eval()
        print(f"Model loaded from {filepath}.")


if __name__ == "__main__":
# Example usage
trainer = AlphaZeroTrainer(
board_size=8,
num_ware=14,
action_size=4032,
num_simulations=3,
num_res_blocks=2,
in_channels=32,
mid_channels=8,
temperature=1.5,
gamma=0.99,
lr=1e-3,
weight_decay=1e-4,
batch_size=4,
entropy_coef=0.01,
value_loss_coef=1.0,
)
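    # Illustrative run (hypothetical values and checkpoint name): a short
    # self-play/training cycle followed by saving the weights.
    trainer.train(epochs=1, num_games_per_epoch=2, max_workers=2)
    trainer.save_model("alphazero_chess.pth")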