@@ -41,6 +41,7 @@ def __init__(
41
41
split = None ,
42
42
categories = None ,
43
43
transform = None ,
44
+ x_label = 'pcloud' ,
44
45
):
45
46
"""Initialization of the the 3D shape dataset.
46
47
@@ -55,6 +56,7 @@ def __init__(
55
56
# Attributes
56
57
self .dataset_folder = dataset_folder
57
58
self .fields = fields
59
+ self .x_label = x_label
58
60
self .transform = transform
59
61
self .split = split
60
62
# If categories is None, use all subfolders
@@ -141,15 +143,15 @@ def __getitem__(self, idx):
141
143
142
144
def transforms (self , data , transform_type = None ):
143
145
if "rotate" in transform_type :
144
- data ["pcloud" ], data ["points" ], data ["points_iou" ] = rotate_pointcloud (
145
- pointcloud = data ["pcloud" ],
146
+ data [self . x_label ], data ["points" ], data ["points_iou" ] = rotate_pointcloud (
147
+ pointcloud = data [self . x_label ],
146
148
points = data ["points" ],
147
149
points_iou = data .get ("points_iou" ),
148
150
)
149
151
150
152
if "translate" in transform_type :
151
- data ["pcloud" ], data ["points" ], data ["points_iou" ] = translate_pointcloud (
152
- pointcloud = data ["pcloud" ],
153
+ data [self . x_label ], data ["points" ], data ["points_iou" ] = translate_pointcloud (
154
+ pointcloud = data [self . x_label ],
153
155
points = data ["points" ],
154
156
points_iou = data .get ("points_iou" ),
155
157
)
@@ -164,13 +166,13 @@ def transforms(self, data, transform_type=None):
164
166
165
167
if data .get ("points_iou" ) is not None :
166
168
(
167
- data ["pcloud" ],
168
- data ["pcloud" ],
169
+ data [self . x_label ],
170
+ data [self . x_label ],
169
171
data ["points_iou" ],
170
172
points_df ,
171
173
points_iou_df ,
172
174
) = single_translate_pointcloud (
173
- pointcloud = data ["pcloud" ],
175
+ pointcloud = data [self . x_label ],
174
176
points = data ["points" ],
175
177
points_iou = data ["points_iou" ],
176
178
points_df = points_df ,
@@ -181,8 +183,8 @@ def transforms(self, data, transform_type=None):
181
183
if points_iou_df is not None :
182
184
data ["points_iou.df" ] = points_iou_df
183
185
else :
184
- data ["pcloud" ], data ["points" ], points_df = single_translate_pointcloud (
185
- pointcloud = data ["pcloud" ],
186
+ data [self . x_label ], data ["points" ], points_df = single_translate_pointcloud (
187
+ pointcloud = data [self . x_label ],
186
188
points = data ["points" ],
187
189
points_df = points_df ,
188
190
)
0 commit comments