-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
111 lines (92 loc) · 3.99 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torch
import pytorch3d
from pytorch3d.renderer import (
AlphaCompositor,
PointsRasterizationSettings,
PointsRenderer,
PointsRasterizer,
)
import imageio
import numpy as np
def save_checkpoint(epoch, model, args, best=False):
    """
    Serialize ``model.state_dict()`` into ``args.checkpoint_dir``.

    Args:
        epoch: Epoch number embedded in the filename when ``best`` is falsy.
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        args: Namespace-like object providing ``checkpoint_dir``.
        best (bool): If truthy, write ``best_model.pt`` instead of the
            per-epoch ``model_epoch_<epoch>.pt`` file.
    """
    filename = 'best_model.pt' if best else 'model_epoch_{}.pt'.format(epoch)
    target = os.path.join(args.checkpoint_dir, filename)
    torch.save(model.state_dict(), target)
def create_dir(directory):
    """
    Creates a directory (and any missing parents) if it does not already exist.

    ``exist_ok=True`` replaces the previous "check, then create" sequence,
    which could raise ``FileExistsError`` when another process created the
    directory between the existence check and ``os.makedirs``.

    Args:
        directory (str): Path of the directory to create.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)
def get_points_renderer(
    image_size=256, device=None, radius=0.01, background_color=(1, 1, 1)
):
    """
    Returns a Pytorch3D renderer for point clouds.

    Args:
        image_size (int): The rendered image size.
        device (torch.device): The torch device to use (CPU or GPU). If not specified,
            will automatically use GPU if available, otherwise CPU.
        radius (float): The radius of the rendered point in NDC.
        background_color (tuple): The background color of the rendered image.

    Returns:
        PointsRenderer, moved to ``device``.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    raster_settings = PointsRasterizationSettings(image_size=image_size, radius=radius)
    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(raster_settings=raster_settings),
        compositor=AlphaCompositor(background_color=background_color),
    )
    # The resolved device used to be computed and then discarded; apply it so
    # the documented ``device`` argument actually takes effect.
    return renderer.to(device)
def viz_cls(verts, path, device, *, image_size=256, num_views=30, fps=15):
    """
    Visualize a classification input as a 360-degree turntable GIF.

    Args:
        verts (torch.Tensor): Point coordinates. ``repeat(num_views, 1, 1)`` is
            applied, so both (N, 3) and (1, N, 3) inputs yield
            (num_views, N, 3); presumably xyz — TODO confirm with callers.
        path (str): Output GIF path passed to ``imageio.mimsave``.
        device: Torch device for the cameras, point cloud and renderer.
        image_size (int): Rendered frame size in pixels (square).
        num_views (int): Number of frames, evenly spaced over 360 degrees of
            azimuth.
        fps (int): GIF frame rate.
    """
    background_color = (1, 1, 1)
    # Camera viewpoints: fixed distance and elevation, with azimuth sweeping a
    # full circle starting at 180 degrees (12-degree steps for 30 views).
    dist = 3
    elev = 0
    step = 360 / num_views
    azim = [180 - step * i for i in range(num_views)]
    R, T = pytorch3d.renderer.cameras.look_at_view_transform(dist=dist, elev=elev, azim=azim, device=device)
    cameras = pytorch3d.renderer.FoVPerspectiveCameras(R=R, T=T, fov=60, device=device)

    sample_verts = verts.repeat(num_views, 1, 1).to(torch.float)
    # Build the uniform point color on the same device as the points:
    # Pointclouds expects features to live on the points' device, and the
    # previous CPU-only color tensor broke this for GPU inputs.
    sample_colors = torch.tensor(
        [0.7, 0.7, 1.0], dtype=torch.float, device=sample_verts.device
    ).repeat(num_views, sample_verts.shape[1], 1)

    point_cloud = pytorch3d.structures.Pointclouds(points=sample_verts, features=sample_colors).to(device)
    renderer = get_points_renderer(image_size=image_size, background_color=background_color, device=device)
    rend = renderer(point_cloud, cameras=cameras).cpu().numpy()  # (num_views, image_size, image_size, 3)
    rend = (np.clip(rend, 0, 1) * 255).astype(np.uint8)
    imageio.mimsave(path, rend, fps=fps, loop=0)
def viz_seg(verts, labels, path, device, *, image_size=256, num_views=30, fps=15):
    """
    Visualize segmentation result as a 360-degree gif.

    Args:
        verts (torch.Tensor): Point coordinates, presumably (N, 3). It must have
            at most 2 dims, because a view axis is prepended with
            ``unsqueeze(0)`` before ``repeat(num_views, 1, 1)``.
        labels (torch.Tensor): 1-D per-point segmentation labels of length N.
            Points whose label is not in 0..5 are rendered black.
        path (str): Output GIF path passed to ``imageio.mimsave``.
        device: Torch device for the cameras, point cloud and renderer.
        image_size (int): Rendered frame size in pixels (square).
        num_views (int): Number of frames, evenly spaced over 360 degrees of
            azimuth.
        fps (int): GIF frame rate.
    """
    background_color = (1, 1, 1)
    # One RGB color per label id; note label 0 is white, the same as the background.
    colors = [[1.0, 1.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
    # Camera viewpoints: fixed distance and elevation, with azimuth sweeping a
    # full circle starting at 180 degrees (12-degree steps for 30 views).
    dist = 3
    elev = 0
    step = 360 / num_views
    azim = [180 - step * i for i in range(num_views)]
    R, T = pytorch3d.renderer.cameras.look_at_view_transform(dist=dist, elev=elev, azim=azim, device=device)
    cameras = pytorch3d.renderer.FoVPerspectiveCameras(R=R, T=T, fov=60, device=device)

    sample_verts = verts.unsqueeze(0).repeat(num_views, 1, 1).to(torch.float)

    # Colorize each point once, then share the (N, 3) colors across all views.
    # The palette and labels are placed on the points' device, so the masked
    # assignment never mixes CPU and GPU tensors.
    point_colors = torch.zeros_like(sample_verts[0])
    palette = torch.tensor(colors, dtype=point_colors.dtype, device=point_colors.device)
    point_labels = labels.to(point_colors.device)
    for label_id, color in enumerate(palette):
        point_colors[point_labels == label_id] = color
    sample_colors = point_colors.unsqueeze(0).repeat(num_views, 1, 1)

    point_cloud = pytorch3d.structures.Pointclouds(points=sample_verts, features=sample_colors).to(device)
    renderer = get_points_renderer(image_size=image_size, background_color=background_color, device=device)
    rend = renderer(point_cloud, cameras=cameras).cpu().numpy()  # (num_views, image_size, image_size, 3)
    rend = (np.clip(rend, 0, 1) * 255).astype(np.uint8)
    imageio.mimsave(path, rend, fps=fps, loop=0)