Skip to content

Commit e0769d5

Browse files
author
Ritvik Vasan
committed
utils for meshing sdf pcs
1 parent d707647 commit e0769d5

File tree

8 files changed

+727
-149
lines changed

8 files changed

+727
-149
lines changed

pointcloudutils/__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,3 @@
1111

1212
def get_module_version():
1313
return __version__
14-
15-
16-
from .example import Example # noqa: F401

pointcloudutils/datamodules/shapenet.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616
self,
1717
dataset_folder: str = "/allen/aics/modeling/ritvik/projects/occupancy_networks/data/ShapeNet",
1818
method: str = "shapenet_dfnet",
19+
x_label: str = 'pcloud',
1920
dataset_type: str = "partial_pointcloud",
2021
train_split: str = "train",
2122
val_split: str = "val",
@@ -47,6 +48,7 @@ def __init__(
4748
self.categories = categories
4849
self.train_split = train_split
4950
self.val_split = val_split
51+
self.x_label = x_label
5052
self.test_split = test_split
5153
self.points_subsample = points_subsample
5254
self.input_type = input_type
@@ -91,7 +93,7 @@ def _get_dataset(self, mode):
9193
)
9294

9395
if inputs_field is not None:
94-
fields["pcloud"] = inputs_field
96+
fields[self.x_label] = inputs_field
9597

9698
if self.return_idx:
9799
fields["idx"] = IndexField()
@@ -110,6 +112,7 @@ def _get_dataset(self, mode):
110112
self.splits[mode],
111113
self.categories,
112114
transform,
115+
self.x_label,
113116
)
114117
return dataset
115118

pointcloudutils/datamodules/shapenet_dataset/core.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
split=None,
4242
categories=None,
4343
transform=None,
44+
x_label='pcloud',
4445
):
4546
"""Initialization of the the 3D shape dataset.
4647
@@ -55,6 +56,7 @@ def __init__(
5556
# Attributes
5657
self.dataset_folder = dataset_folder
5758
self.fields = fields
59+
self.x_label = x_label
5860
self.transform = transform
5961
self.split = split
6062
# If categories is None, use all subfolders
@@ -141,15 +143,15 @@ def __getitem__(self, idx):
141143

142144
def transforms(self, data, transform_type=None):
143145
if "rotate" in transform_type:
144-
data["pcloud"], data["points"], data["points_iou"] = rotate_pointcloud(
145-
pointcloud=data["pcloud"],
146+
data[self.x_label], data["points"], data["points_iou"] = rotate_pointcloud(
147+
pointcloud=data[self.x_label],
146148
points=data["points"],
147149
points_iou=data.get("points_iou"),
148150
)
149151

150152
if "translate" in transform_type:
151-
data["pcloud"], data["points"], data["points_iou"] = translate_pointcloud(
152-
pointcloud=data["pcloud"],
153+
data[self.x_label], data["points"], data["points_iou"] = translate_pointcloud(
154+
pointcloud=data[self.x_label],
153155
points=data["points"],
154156
points_iou=data.get("points_iou"),
155157
)
@@ -164,13 +166,13 @@ def transforms(self, data, transform_type=None):
164166

165167
if data.get("points_iou") is not None:
166168
(
167-
data["pcloud"],
168-
data["pcloud"],
169+
data[self.x_label],
170+
data[self.x_label],
169171
data["points_iou"],
170172
points_df,
171173
points_iou_df,
172174
) = single_translate_pointcloud(
173-
pointcloud=data["pcloud"],
175+
pointcloud=data[self.x_label],
174176
points=data["points"],
175177
points_iou=data["points_iou"],
176178
points_df=points_df,
@@ -181,8 +183,8 @@ def transforms(self, data, transform_type=None):
181183
if points_iou_df is not None:
182184
data["points_iou.df"] = points_iou_df
183185
else:
184-
data["pcloud"], data["points"], points_df = single_translate_pointcloud(
185-
pointcloud=data["pcloud"],
186+
data[self.x_label], data["points"], points_df = single_translate_pointcloud(
187+
pointcloud=data[self.x_label],
186188
points=data["points"],
187189
points_df=points_df,
188190
)

pointcloudutils/example.py

-135
This file was deleted.

pointcloudutils/networks/equiv_transformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch import nn
33

44
from cyto_dl import utils
5-
from .vnn import VNLinear, VNRotationMatrix
5+
from cyto_dl.nn.point_cloud.vnn import VNLinear, VNRotationMatrix
66

77
log = utils.get_pylogger(__name__)
88

0 commit comments

Comments
 (0)