-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path trav.py
92 lines (74 loc) · 2.8 KB
/
trav.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import sys
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.onnx
MAX_TRAV = 21  # default cut-off distance, in grid cells; larger distances are clamped to it


def _sweep_axis(sq_distances: torch.Tensor, dim: int, max_trav: int) -> torch.Tensor:
    """Relax squared distances along one axis by repeated neighbour propagation.

    Iteration ``i`` (1-based) lets every cell take the value of either
    neighbour along ``dim`` plus ``2 * i - 1``. Because
    ``1 + 3 + ... + (2k - 1) == k ** 2``, after ``max_trav`` iterations each
    cell holds the minimum of its own value and ``value[j] + offset_j ** 2``
    over all cells ``j`` at most ``max_trav`` steps away along ``dim``.
    Positions beyond the edge of the tensor are treated as holding
    ``max_trav ** 2``.

    Args:
        sq_distances: Integer tensor of squared distances.
        dim: Axis to sweep along (negative values allowed).
        max_trav: Number of propagation steps (the cut-off distance).

    Returns:
        A new tensor with the same shape, dtype and device as ``sq_distances``.
    """
    ndim = sq_distances.dim()
    dim = dim % ndim
    drop_first = [slice(None)] * ndim
    drop_last = [slice(None)] * ndim
    drop_first[dim] = slice(1, None)  # element k+1 moved to position k
    drop_last[dim] = slice(None, -1)  # element k-1 moved to position k
    drop_first_idx = tuple(drop_first)
    drop_last_idx = tuple(drop_last)

    # Build the edge padding with the SAME dtype/device as the data, using a
    # plain Python fill value: torch.cat then needs no implicit type promotion,
    # and no device->host sync (or tracer constant-folding warning) occurs.
    pad_shape = list(sq_distances.shape)
    pad_shape[dim] = 1
    edge = torch.full(
        pad_shape, max_trav ** 2, dtype=sq_distances.dtype, device=sq_distances.device
    )

    for step in range(1, max_trav * 2 + 1, 2):
        from_next = torch.cat([sq_distances[drop_first_idx] + step, edge], dim)
        from_prev = torch.cat([edge, sq_distances[drop_last_idx] + step], dim)
        sq_distances = torch.minimum(sq_distances, torch.minimum(from_next, from_prev))
    return sq_distances


class Traversability(nn.Module):
    """Distance from every grid cell to the nearest untraversable cell, capped.

    The input is a mask whose last two axes form a 2-D grid (for example
    ``(batch, H, W)`` or ``(H, W)``); non-zero / ``True`` entries mark
    untraversable cells. The output has the same shape and holds, per grid,
    the Euclidean distance (in cells) to the nearest untraversable cell,
    clamped to ``max_trav``; untraversable cells themselves get 0.

    The transform is separable: a squared-distance sweep along the last axis,
    then along the second-to-last axis. It uses only slicing, ``cat`` and
    ``minimum``, so the traced graph exports to ONNX.
    """

    def __init__(self, max_trav: int = MAX_TRAV) -> None:
        """Create the module.

        Args:
            max_trav: Cut-off distance in cells; must be >= 1.

        Raises:
            ValueError: If ``max_trav`` is smaller than 1.
        """
        super().__init__()
        if max_trav < 1:
            raise ValueError(f"max_trav must be >= 1, got {max_trav}")
        self.max_trav = int(max_trav)

    def forward(self, untraversable: torch.Tensor) -> torch.Tensor:
        """Compute capped distances to the nearest untraversable cell.

        Args:
            untraversable: Mask of shape ``(..., H, W)``; non-zero / ``True``
                marks an untraversable cell.

        Returns:
            Floating-point tensor (torch's default float dtype) of the same
            shape, with values in ``[0, max_trav]``.

        Raises:
            ValueError: If the input has fewer than two dimensions.
        """
        if untraversable.dim() < 2:
            raise ValueError(
                f"expected a mask with at least 2 dims (..., H, W), got shape {tuple(untraversable.shape)}"
            )
        if untraversable.dtype != torch.bool:
            # torch.where needs a boolean condition.
            untraversable = untraversable != 0

        max_trav_sq = torch.tensor(
            self.max_trav ** 2, device=untraversable.device, dtype=torch.int32
        )
        # Obstacles start at 0; every other cell starts at the cap.
        sq_distances = torch.where(untraversable, 0, max_trav_sq)
        sq_distances = _sweep_axis(sq_distances, -1, self.max_trav)  # along rows
        sq_distances = _sweep_axis(sq_distances, -2, self.max_trav)  # along columns
        return sq_distances.sqrt()
# Single shared instance used by save_onnx() and plot_test_input() below.
model: Traversability = Traversability()
def save_onnx(onnx_file_path="traversability.onnx", input_shape=(3, 224, 224)):
    """Export the module-level ``model`` to an ONNX file.

    ``model`` is switched to eval mode and traced with a random boolean mask
    of ``input_shape``. No ``dynamic_axes`` are passed, so the exported graph
    declares its input with that fixed shape.

    Args:
        onnx_file_path: Destination path of the ``.onnx`` file.
        input_shape: Shape of the dummy boolean mask used for tracing.

    Returns:
        The path the model was written to.
    """
    model.eval()
    # The values of the dummy input do not matter for tracing; only its shape
    # and dtype do.
    dummy_input = torch.randint(0, 2, tuple(input_shape), dtype=torch.bool)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_file_path,
        input_names=["input"],
        output_names=["output"],
    )
    print("Model converted to ONNX format. Saved as:", onnx_file_path)
    return onnx_file_path
def plot_test_input(device=None):
    """Run ``model`` on three hand-made obstacle masks and plot them.

    Each row of the figure shows an input mask (gray) next to the model's
    output (viridis).

    Args:
        device: Device to run on. Defaults to CUDA when available, otherwise
            CPU, so the demo also works on machines without a GPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    untrav = torch.zeros((3, 224, 224), dtype=torch.bool, device=device)
    untrav[0, 4:40, 70:75] = True
    untrav[1, 60:80, 90:95] = True
    untrav[2, 20:30, 120:130] = True
    with torch.no_grad():
        trav = model(untrav)

    # matplotlib needs host memory; copy each tensor once instead of per plot.
    untrav_cpu = untrav.cpu()
    trav_cpu = trav.cpu()
    n_rows = untrav_cpu.shape[0]

    plt.figure()
    for row in range(n_rows):
        plt.subplot(n_rows, 2, 2 * row + 1)
        plt.imshow(untrav_cpu[row], cmap="gray")
        plt.subplot(n_rows, 2, 2 * row + 2)
        plt.imshow(trav_cpu[row], cmap="viridis")
    plt.show()
if __name__ == "__main__":
    # Every recognised word on the command line triggers its action; both
    # "save" and "plot" may be given, and unknown words are ignored.
    requested = set(sys.argv[1:])
    if "save" in requested:
        save_onnx()
    if "plot" in requested:
        plot_test_input()