# tsp_utils.py
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from msdlib import msd
from ortools.constraint_solver import pywrapcp, routing_enums_pb2
from scipy.spatial import distance_matrix
from tqdm import tqdm

from pointer_net import PN_Actor, PN_Critic

def create_data_model(input_data, start_index=0):
    """Stores the data for the problem."""
    data = {}
    data['distance_matrix'] = input_data
    data['num_vehicles'] = 1
    data['depot'] = start_index
    return data

def get_solution(manager, routing, solution, start_index):
    """Extracts the node visiting order from an OR-Tools solution."""
    index = routing.Start(start_index)
    solutions = []
    while not routing.IsEnd(index):
        solutions.append(manager.IndexToNode(index))
        index = solution.Value(routing.NextVar(index))
    solutions.append(manager.IndexToNode(index))
    return solutions

def solve_tsp(input_data, start_index=0):
    """Solves a single TSP instance with OR-Tools and returns the route."""
    # Instantiate the data problem.
    dist_mat = distance_matrix(input_data, input_data)
    data = create_data_model(dist_mat, start_index)

    # Create the routing index manager.
    manager = pywrapcp.RoutingIndexManager(len(data['distance_matrix']),
                                           data['num_vehicles'], data['depot'])

    # Create the routing model.
    routing = pywrapcp.RoutingModel(manager)

    def distance_callback(from_index, to_index):
        """Returns the distance between the two nodes."""
        # Convert from routing variable index to distance matrix node index.
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return data['distance_matrix'][from_node][to_node]

    transit_callback_index = routing.RegisterTransitCallback(distance_callback)

    # Define the cost of each arc.
    routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)

    # Set the first-solution heuristic.
    search_parameters = pywrapcp.DefaultRoutingSearchParameters()
    search_parameters.first_solution_strategy = (
        routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC)

    # Solve the problem and return the route if one was found.
    solution = routing.SolveWithParameters(search_parameters)
    if solution:
        return get_solution(manager, routing, solution, start_index)
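
# Minimal usage sketch for solve_tsp (illustrative): the returned route starts
# and ends at `start_index`, so it has len(points) + 1 entries. OR-Tools works
# with integer arc costs, so coordinates are usually scaled up before solving
# (see solve_tsp_ortools below).
#
#     points = (np.random.rand(10, 2) * 1000).tolist()
#     route = solve_tsp(points)   # e.g. [0, 4, 7, ..., 0]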

def solve_tsp_ortools(_data):
    # scale coordinates up so the distances survive OR-Tools' integer arc
    # costs with enough precision (raw coordinates are in [0, 1))
    data = _data * 1000
    or_solutions = []
    runtime = []
    for i in tqdm(range(data.shape[0])):
        t = time.time()
        or_solutions.append(solve_tsp(data[i].tolist()))
        runtime.append(time.time() - t)
    return or_solutions, runtime

def calculate_distances(instances, solutions):
    """Returns the total Euclidean tour length of each solved instance."""
    return [np.sqrt(np.square(instances[i][solutions[i][1:]] - instances[i][solutions[i][:-1]]).sum(axis=1)).sum()
            for i in range(instances.shape[0])]
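
# For one instance, calculate_distances sums the Euclidean leg lengths
# ||p_{k+1} - p_k|| over consecutive route nodes; because the routes returned
# by solve_tsp repeat the start node at the end, the closing leg of the tour
# is included automatically.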

# `solutions` is a dict of the form <name>: {'solution': [], 'distance': [], 'time': []}
def plot_tsp_solutions(instances, solutions, num_plot=5, plot_index=None, title=''):
    n_algo = len(solutions)
    if title == '':
        title = 'TSP routes'
    if plot_index is None or len(plot_index) == 0:
        indices = np.random.choice(
            instances.shape[0], size=num_plot, replace=False)
    else:
        indices = list(plot_index)
    colors = msd.get_named_colors()
    for i in sorted(indices):
        fig_title = title + ' (index: %d)' % i
        fig, ax = plt.subplots(figsize=(5 * n_algo, 5), ncols=n_algo)
        if n_algo == 1:
            ax = [ax]
        fig.suptitle(fig_title, y=1.04, fontsize=12.5, fontweight='bold')
        for j, name in enumerate(solutions):
            ax[j].plot(instances[i][solutions[name]['solution'][i], 0],
                       instances[i][solutions[name]['solution'][i], 1],
                       color=colors[j], marker='o')
            ax[j].set_title('%s solution\ndistance: %.3f; time: %.4f' % (
                name, solutions[name]['distance'][i], solutions[name]['time'][i]))
        fig.tight_layout()
        plt.show()

def generate_tsp_instances(config, inference=False, num_test=None):
    if inference:
        if num_test is None:
            num_test = config.num_test
        data = torch.rand(num_test, config.problem_size, config.dimension)
    else:
        data = torch.rand(config.batch_size,
                          config.problem_size, config.dimension)
    return data
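
# Usage sketch (illustrative; e.g. with batch_size=128, problem_size=20 and
# dimension=2 on the config):
#
#     train_batch = generate_tsp_instances(config)   # shape (128, 20, 2)
#     test_data = generate_tsp_instances(config, inference=True, num_test=1000)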

class Agent():
    """Actor-critic agent wrapping the pointer-network models for TSP."""

    def __init__(self, config, dtype=torch.float32):
        self.dtype = dtype
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() and 'cuda' in config.device else "cpu")
        self.config = config

        # initializing models, their optimizers and step-wise LR schedulers
        self.actor = PN_Actor(config).to(device=self.device)
        self.critic = PN_Critic(config).to(device=self.device)
        self.opt_actor = torch.optim.Adam(self.actor.parameters(), lr=self.config.lr)
        self.decay_actor = torch.optim.lr_scheduler.StepLR(
            self.opt_actor, step_size=self.config.lr_step, gamma=self.config.lr_step_decay)
        self.opt_critic = torch.optim.Adam(self.critic.parameters(), lr=self.config.lr)
        self.decay_critic = torch.optim.lr_scheduler.StepLR(
            self.opt_critic, step_size=self.config.lr_step, gamma=self.config.lr_step_decay)

        # setting model mode
        if self.config.inference:
            self.actor.eval()
            self.critic.eval()
        else:
            if self.config.resume_training:
                self.load_weights()
            self.actor.train()
            self.critic.train()

        # initializing other values
        self.range_index = torch.tensor(list(range(self.config.batch_size)), device=self.device).long()
        self.reload_factor = .5 if self.config.maximize else 1.5
        self.ct_reward = -1e8 if self.config.maximize else 1e8
        self.push_factor = (1 + .02) if self.config.maximize else (1 - .02)
        self.loss_direction = -1 if self.config.maximize else 1
        self.grad_clip = 1
        self.it = 0
        self.reward_stack = []
        self.others_stack = []
    def check_input(self, x):
        # formatting input
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        x = x.to(dtype=self.dtype, device=self.device)
        return x
    def solve_agent(self, x):
        x = self.check_input(x)
        # setting temperature
        self.actor.T = self.config.T
        # actor output
        out_index = self.actor(x)
        return out_index
    def learn(self, x):
        # setting temperature to 1 during training
        self.actor.T = 1

        # predicting a tour from the actor along with its log-probability
        x = self.check_input(x)
        indices = self.actor(x)
        log_prob = self.actor.log_prob

        # reward signal: total tour length, detached from the graph
        # (self.loss_direction decides whether it is minimized or maximized)
        reward = self.get_reward(x, indices).detach()

        # baseline prediction from the critic; an exponential moving-average
        # baseline, reward * (1 - alpha) + avg_baseline * alpha, is a simpler
        # alternative that was tried here earlier
        V = self.critic(x)

        # REINFORCE loss with the critic baseline as advantage, plus an MSE
        # loss regressing the critic onto the observed reward
        actor_loss = self.loss_direction * ((reward - V).detach() * log_prob).mean()
        critic_loss = ((reward - V) ** 2).mean()

        # zeroing gradients
        self.opt_actor.zero_grad()
        self.opt_critic.zero_grad()

        # back-propagation for the actor
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.grad_clip)
        self.opt_actor.step()
        self.decay_actor.step()

        # back-propagation for the critic
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.grad_clip)
        self.opt_critic.step()
        self.decay_critic.step()

        if not self.config.active_search:
            self.take_summary(reward.mean().tolist(), [V.mean().tolist()])

        # printing progress
        if self.config.print_progress:
            print('\repoch: %06d| actor_loss: %.4f, critic_loss: %.4f, avg_distance: %.3f, avg_critic_out: %.3f '
                  % (self.it, actor_loss.tolist(), critic_loss.tolist(), reward.mean().tolist(), V.mean().tolist()), end='')
    def get_reward(self, x, indices):
        # calculate the tsp tour length of each instance in the batch
        distances = torch.tensor([torch.sqrt(torch.square(x[i, indices[i, 1:], :] - x[i, indices[i, :-1], :]).sum(axis=1)).sum()
                                  for i in range(self.config.batch_size)], dtype=self.dtype, device=self.device)
        return distances
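
    # Note: the Python loop in get_reward could be replaced by a batched
    # version, sketched here (untested; assumes `indices` has shape
    # (batch, tour_len) and `x` has shape (batch, num_nodes, dim)):
    #
    #     rows = torch.arange(x.shape[0], device=x.device).unsqueeze(1)
    #     legs = x[rows, indices[:, 1:]] - x[rows, indices[:, :-1]]
    #     distances = torch.linalg.norm(legs, dim=-1).sum(dim=-1)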
    def take_summary(self, reward, others=None):
        # stacking reward and other values for the learning plot
        self.reward_stack.append(reward)
        self.others_stack.append(others if others is not None else [])
        # checking whether to store or reload the model
        self.it += 1
        if self.it % self.config.save_after == 0 or self.it == self.config.epoch:
            last_reward = np.mean(self.reward_stack[-self.config.save_after:])
            if (self.ct_reward < last_reward and self.config.maximize) or (self.ct_reward > last_reward and not self.config.maximize):
                self.save_weights()
                # track the best recent reward so the reload threshold below is meaningful
                self.ct_reward = last_reward
            if (self.ct_reward * self.reload_factor > last_reward and self.config.maximize) or (self.ct_reward * self.reload_factor < last_reward and not self.config.maximize):
                self.load_weights()
    def save_weights(self):
        torch.save(self.actor.state_dict(), self.config.savepath + '/actor_model_weights.pt')
        torch.save(self.critic.state_dict(), self.config.savepath + '/critic_model_weights.pt')

    def load_weights(self):
        self.actor.load_state_dict(torch.load(self.config.loadpath + '/actor_model_weights.pt'))
        self.critic.load_state_dict(torch.load(self.config.loadpath + '/critic_model_weights.pt'))
    def plot_learning(self):
        rwd_name = 'distance'
        other_names = ['critic_out']
        same_srs_names = []  # names from the others stack to plot on the same axes
        df = pd.DataFrame(self.others_stack, columns=other_names)
        df[rwd_name] = self.reward_stack
        df[rwd_name + '_rolling'] = df[rwd_name].rolling(10).mean()
        same_srs_cols = [rwd_name, rwd_name + '_rolling'] + same_srs_names
        srs_cols = [c for c in df.columns if c not in same_srs_cols]
        same_srs = [df[c] for c in same_srs_cols]
        srs = [df[c] for c in srs_cols]
        segs = 1
        msd.plot_time_series(same_srs=same_srs, srs=srs, segs=segs, fig_y=5)
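
# Minimal end-to-end sketch (illustrative; assumes a `config` object exposing
# the attributes referenced above, e.g. batch_size, problem_size, dimension,
# epoch, lr, lr_step, lr_step_decay, device, inference, resume_training,
# maximize, active_search, print_progress, save_after, savepath, loadpath, T):
#
#     agent = Agent(config)
#     for _ in range(config.epoch):
#         batch = generate_tsp_instances(config)
#         agent.learn(batch)
#     agent.plot_learning()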