Skip to content

Commit

Permalink
Eval with more background objects
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikolaos Gkanatsios committed Oct 6, 2022
1 parent 817b707 commit 6b1e11b
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/joint_det_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from src.visual_data_handlers import Scan
from .scannet_classes import REL_ALIASES, VIEW_DEP_RELS


NUM_CLASSES = 485
DC = ScannetDatasetConfig(NUM_CLASSES)
DC18 = ScannetDatasetConfig(18)
Expand Down Expand Up @@ -525,11 +526,12 @@ def _get_scene_objects(self, scan):
for ind in range(len(scan.three_d_objects))
])[:MAX_NUM_OBJ]
keep = np.array([False] * MAX_NUM_OBJ)
keep[:len(keep_)] = keep_
keep[:len(keep_)] = True # keep_

# Class ids
cid = np.array([
DC.nyu40id2class[self.label_map[scan.get_object_instance_label(k)]]
if keep_[k] else 325 # this is the 'object' class
for k, kept in enumerate(keep) if kept
])
class_ids = np.zeros((MAX_NUM_OBJ,))
Expand Down Expand Up @@ -705,7 +707,10 @@ def __getitem__(self, index):
all_detected_bbox_label_mask = all_bbox_label_mask
detected_class_ids = np.zeros((len(all_bboxes,)))
classes = np.array(self.cls_results[anno['scan_id']])
detected_class_ids[all_bbox_label_mask] = classes[classes > -1]
# detected_class_ids[all_bbox_label_mask] = classes[classes > -1]
classes[classes == -1] = 325 # 'object' class
_k = all_bbox_label_mask.sum()
detected_class_ids[:_k] = classes[:_k]

# Visualize for debugging
if self.visualize and anno['dataset'].startswith('sr3d'):
Expand Down

0 comments on commit 6b1e11b

Please sign in to comment.