Span prediction release
Nikolaos Gkanatsios committed Nov 26, 2022
1 parent 6b1e11b commit 872bcfd
Showing 6 changed files with 502 additions and 30 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -33,6 +33,8 @@ If you need to use a different version, you can try to modify `environment.yml`

- Download [object detector's outputs](https://drive.google.com/file/d/1OAArYe2NIfwSURiv6_ORbKAlYbOwfpVS/view?usp=sharing). Unzip inside `DATA_ROOT`.

- Download span predictor's outputs inside `DATA_ROOT`: [ScanRefer_train](https://zenodo.org/record/7363895/files/scanrefer_pred_spans_train.json?download=1), [ScanRefer_val](https://zenodo.org/record/7363895/files/scanrefer_pred_spans_val.json?download=1), [SR3D](https://zenodo.org/record/7363895/files/sr3d_pred_spans.json?download=1), [NR3D](https://zenodo.org/record/7363895/files/nr3d_pred_spans.json?download=1).

- (optional) Download PointNet++ [checkpoint](https://drive.google.com/file/d/1JwMTOaMWfK0JgOBBHU_2oBGXp9ORo9Q3/view?usp=sharing) into `DATA_ROOT`.

- Run `python prepare_data.py --data_root DATA_ROOT`, specifying your `DATA_ROOT`. This creates two .pkl files and needs to run only once.
@@ -53,13 +55,17 @@ The above scripts will run training and evaluation on SR3D. You can edit the following:

- To train on multiple datasets, e.g. on SR3D and NR3D simultaneously, set `--TRAIN_DATASET sr3d nr3d`.

- On NR3D and ScanRefer, many more training epochs are needed to converge. It's better to monitor the validation accuracy and decrease the learning rate accordingly. For example, in the `det` setup, we decrease the learning rate at epochs 80 and 90 for NR3D and at epoch 65 for ScanRefer. To disable automatic learning-rate decay, remove `--lr_decay_epochs` from the train script and manually decrease the learning rate when the validation accuracy converges. Be sure to add the `--reduce_lr` flag when decreasing the learning rate and continuing from a checkpoint, so that the optimizer state is loaded correctly (see the sketch after this list).

- (Optional) To train a span predictor, `cd src` and run `python text_cls.py --dataset DATASET`.
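As a rough sketch of what one manual decay step amounts to (assuming a standard PyTorch optimizer restored from the checkpoint; the stand-in model and the 0.1 factor are illustrative, not the repository's exact code):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# ... restore model/optimizer state from the checkpoint here ...

# One manual decay step, applied when validation accuracy plateaus.
for group in optimizer.param_groups:
    group['lr'] *= 0.1
```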

## Pre-trained checkpoints
Download our checkpoints for [SR3D_det](https://zenodo.org/record/6430189/files/sr3d_butd_det_52.1_27.pth?download=1), [NR3D_det](https://zenodo.org/record/6430189/files/bdetr_nr3d_43.3.pth?download=1), [ScanRefer_det](https://zenodo.org/record/6430189/files/scanrefer_det_52.2.pth?download=1), [SR3D_cls](https://zenodo.org/record/6430189/files/bdetr_sr3d_cls_67.1.pth?download=1), [NR3D_cls](https://zenodo.org/record/6430189/files/bdetr_nr3d_cls_55.4.pth?download=1). Add `--checkpoint_path CKPT_NAME` to the above scripts in order to utilize the stored checkpoints.

Note that these checkpoints were stored while using `DistributedDataParallel`. To load them into a model that is not wrapped in `DistributedDataParallel`, take a look [here](https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686).
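The usual fix, sketched below, is to strip the `module.` prefix that `DistributedDataParallel` adds to parameter names. The `model_state_dict` key is an assumption; adjust it to the checkpoint's actual layout.

```python
import torch

# `model` stands for the unwrapped network, built the same way as during training.
ckpt = torch.load('sr3d_butd_det_52.1_27.pth', map_location='cpu')
state_dict = ckpt.get('model_state_dict', ckpt)  # assumed key
state_dict = {
    k[len('module.'):] if k.startswith('module.') else k: v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
```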

Lastly, we also release the checkpoints for span prediction ([ScanRefer](https://zenodo.org/record/7363895/files/scanrefer.pt?download=1), [SR3D](https://zenodo.org/record/7363895/files/sr3d.pt?download=1), [NR3D](https://zenodo.org/record/7363895/files/nr3d_unfrozen.pt?download=1)).

## How does the evaluation work?
- For each object query, we compute per-token confidence scores and regress bounding boxes.
- Given a target span, we keep the most confident query for it. This is our model's best guess.
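As a hedged sketch of this selection step (shapes and names are illustrative, not the repository's exact API), each query's confidence on the target span is the probability mass its token distribution places on the span's tokens:

```python
import torch

def pick_best_query(token_logits, span_mask):
    """token_logits: (num_queries, num_tokens); span_mask: (num_tokens,) bool."""
    probs = token_logits.softmax(dim=-1)
    span_scores = probs[:, span_mask].sum(dim=-1)  # confidence on the target span
    return int(span_scores.argmax())               # index of the most confident query

# The box regressed by the winning query is the model's best guess:
# best = pick_best_query(logits, mask); predicted_box = boxes[best]
```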
2 changes: 1 addition & 1 deletion data/cls_results.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions models/losses.py
@@ -531,6 +531,7 @@ def forward(self, outputs, targets):
)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / dist.get_world_size(), min=1).item()

# Compute all the requested losses
losses = {}
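For context, the added line averages the per-rank target-box count across distributed processes and floors it at 1. A self-contained sketch of the same pattern, assuming a process group has already been initialized as in the training script:

```python
import torch
import torch.distributed as dist

def normalized_num_boxes(local_count, device='cpu'):
    num_boxes = torch.tensor([float(local_count)], device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(num_boxes)                     # sum counts over all ranks
        num_boxes = num_boxes / dist.get_world_size()  # back to a per-rank mean
    return torch.clamp(num_boxes, min=1).item()        # floor at 1 to avoid dividing by zero
```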
63 changes: 36 additions & 27 deletions src/joint_det_dataset.py
@@ -134,6 +134,8 @@ def load_sr3d_annos(self, dset='sr3d'):
split = 'test'
with open('data/meta_data/sr3d_%s_scans.txt' % split) as f:
scan_ids = set(eval(f.read()))
with open(self.data_path + 'sr3d_pred_spans.json', 'r') as f:
pred_spans = json.load(f)
with open(self.data_path + 'refer_it_3d/%s.csv' % dset) as f:
csv_reader = csv.reader(f)
headers = next(csv_reader)
@@ -147,9 +149,11 @@ def load_sr3d_annos(self, dset='sr3d'):
'target': line[headers['instance_type']],
'anchors': eval(line[headers['anchors_types']]),
'anchor_ids': eval(line[headers['anchor_ids']]),
'dataset': dset
'dataset': dset,
'pred_pos_map': pred_spans[i]['span'], # predicted span
'span_utterance': pred_spans[i]['utterance'] # for assert
}
for line in csv_reader
for i, line in enumerate(csv_reader)
if line[headers['scan_id']] in scan_ids
and
str(line[headers['mentions_target_class']]).lower() == 'true'
@@ -163,6 +167,8 @@ def load_nr3d_annos(self):
split = 'test'
with open('data/meta_data/nr3d_%s_scans.txt' % split) as f:
scan_ids = set(eval(f.read()))
with open(self.data_path + 'nr3d_pred_spans.json', 'r') as f:
pred_spans = json.load(f)
with open(self.data_path + 'refer_it_3d/nr3d.csv') as f:
csv_reader = csv.reader(f)
headers = next(csv_reader)
@@ -175,9 +181,11 @@ def load_nr3d_annos(self):
'utterance': line[headers['utterance']],
'anchor_ids': [],
'anchors': [],
'dataset': 'nr3d'
'dataset': 'nr3d',
'pred_pos_map': pred_spans[i]['span'], # predicted span
'span_utterance': pred_spans[i]['utterance'] # for assert
}
for line in csv_reader
for i, line in enumerate(csv_reader)
if line[headers['scan_id']] in scan_ids
and
str(line[headers['mentions_target_class']]).lower() == 'true'
@@ -197,8 +205,6 @@ def load_nr3d_annos(self):
== anno['target']
and ind != anno['target_id']
]
# Filter out sentences that do not explicitly mention the target class
annos = [anno for anno in annos if anno['target'] in anno['utterance']]
return annos

def load_scanrefer_annos(self):
@@ -211,6 +217,9 @@ def load_scanrefer_annos(self):
scan_ids = [line.rstrip().strip('\n') for line in f.readlines()]
with open(_path + '_%s.json' % split) as f:
reader = json.load(f)
with open(self.data_path + f'scanrefer_pred_spans_{split}.json') as f:
pred_spans = json.load(f)

annos = [
{
'scan_id': anno['scene_id'],
@@ -220,19 +229,14 @@ def load_scanrefer_annos(self):
'target': ' '.join(str(anno['object_name']).split('_')),
'anchors': [],
'anchor_ids': [],
'dataset': 'scanrefer'
'dataset': 'scanrefer',
'pred_pos_map': pred_spans[i]['span'], # predicted span
'span_utterance': pred_spans[i]['utterance'] # for assert
}
for anno in reader
for i, anno in enumerate(reader)
if anno['scene_id'] in scan_ids
]
# Fix missing target reference
for anno in annos:
if anno['target'] not in anno['utterance']:
anno['utterance'] = (
' '.join(anno['utterance'].split(' . ')[0].split()[:-1])
+ ' ' + anno['target'] + ' . '
+ ' . '.join(anno['utterance'].split(' . ')[1:])
)

# Add distractor info
scene2obj = defaultdict(list)
sceneobj2used = defaultdict(list)
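All three loaders assume the same pred-spans layout: a JSON list aligned one-to-one with the filtered annotations, each entry carrying the utterance (used for the sanity check in `__getitem__`) and a flattened span map that is later reshaped to `(-1, 256)`. A hedged sketch of reading and checking that assumed schema (the path is illustrative):

```python
import json

# Assumed schema, inferred from this diff: [{"utterance": str, "span": [float, ...]}, ...]
with open('DATA_ROOT/sr3d_pred_spans.json') as f:
    pred_spans = json.load(f)

entry = pred_spans[0]
assert {'span', 'utterance'} <= set(entry)
assert len(entry['span']) % 256 == 0  # flattened (num_targets x 256) token map
```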
@@ -356,7 +360,7 @@ def _augment(self, pc, color, rotate):

# Rotate/flip only if we don't have a view_dep sentence
if rotate:
theta_z = 90*np.random.randint(0, 4) + (2*np.random.rand() - 1) * 5
theta_z = 90 * np.random.randint(0, 4) + 10 * np.random.rand() - 5
# Flipping along the YZ plane
augmentations['yz_flip'] = np.random.random() > 0.5
if augmentations['yz_flip']:
Expand All @@ -366,15 +370,15 @@ def _augment(self, pc, color, rotate):
if augmentations['xz_flip']:
pc[:, 1] = -pc[:, 1]
else:
theta_z = (2*np.random.rand() - 1) * 5
theta_z = (2 * np.random.rand() - 1) * 5
augmentations['theta_z'] = theta_z
pc[:, :3] = rot_z(pc[:, :3], theta_z)
# Rotate around x
theta_x = (2*np.random.rand() - 1) * 2.5
theta_x = (2 * np.random.rand() - 1) * 2.5
augmentations['theta_x'] = theta_x
pc[:, :3] = rot_x(pc[:, :3], theta_x)
# Rotate around y
theta_y = (2*np.random.rand() - 1) * 2.5
theta_y = (2 * np.random.rand() - 1) * 2.5
augmentations['theta_y'] = theta_y
pc[:, :3] = rot_y(pc[:, :3], theta_y)

@@ -388,13 +392,13 @@ def _augment(self, pc, color, rotate):
pc[:, :3] += augmentations['shift']

# Scale
augmentations['scale'] = 0.98 + 0.04*np.random.random()
augmentations['scale'] = 0.98 + 0.04 * np.random.random()
pc[:, :3] *= augmentations['scale']

# Color
if color is not None:
color += self.mean_rgb
color *= 0.98 + 0.04*np.random.random((len(color), 3))
color *= 0.98 + 0.04 * np.random.random((len(color), 3))
color -= self.mean_rgb
return pc, color, augmentations
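`_augment` relies on `rot_z`, `rot_x`, and `rot_y` helpers defined elsewhere in the repository. A minimal sketch of what `rot_z` presumably does, assuming angles are given in degrees (consistent with the 90-degree multiples above):

```python
import numpy as np

def rot_z(points, theta_deg):
    """Rotate (N, 3) points about the z-axis by theta_deg degrees (sketch)."""
    t = np.deg2rad(theta_deg)
    c, s = np.cos(t), np.sin(t)
    rot = np.array([
        [c, -s, 0.0],
        [s,  c, 0.0],
        [0.0, 0.0, 1.0],
    ])
    return points @ rot.T
```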

@@ -465,7 +469,7 @@ def _get_token_positive_map(self, anno):
len_ = len(cat_name)
if start_span < 0:
start_span = caption.find(' ' + cat_name)
len_ = len(caption[start_span+1:].split()[0])
len_ = len(caption[start_span + 1:].split()[0])
if start_span < 0:
start_span = caption.find(cat_name)
orig_start_span = start_span
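For intuition, here is a hedged sketch of turning a character span like the one located above into a token-level positive map with a HuggingFace fast tokenizer; the `roberta-base` choice and the 256-token width are assumptions based on the surrounding code, not the repository's exact pipeline:

```python
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base')  # assumed text encoder

def span_to_positive_map(caption, start, end, num_tokens=256):
    enc = tokenizer(caption)
    pos = np.zeros(num_tokens, dtype=np.float32)
    beg_tok, end_tok = enc.char_to_token(start), enc.char_to_token(end - 1)
    if beg_tok is not None and end_tok is not None:
        pos[beg_tok:end_tok + 1] = 1.0
        pos /= pos.sum()  # normalize to a distribution over the span's tokens
    return pos
```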
@@ -511,7 +515,7 @@ def _get_target_boxes(self, anno, scan):
bboxes[:, 3:] - bboxes[:, :3]
), 1)
if self.split == 'train' and self.augment: # jitter boxes
bboxes[:len(tids)] *= (0.95 + 0.1*np.random.random((len(tids), 6)))
bboxes[:len(tids)] *= 0.95 + 0.1 * np.random.random((len(tids), 6))
bboxes[len(tids):, :3] = 1000
box_label_mask = np.zeros(MAX_NUM_OBJ)
box_label_mask[:len(tids)] = 1
@@ -550,7 +554,7 @@ def _get_scene_objects(self, scan):
), 1)
all_bboxes[keep] = all_bboxes_
if self.split == 'train' and self.augment:
all_bboxes *= (0.95 + 0.1*np.random.random((len(all_bboxes), 6)))
all_bboxes *= 0.95 + 0.1 * np.random.random((len(all_bboxes), 6))

# Which boxes we're interested for
all_bbox_label_mask = keep
@@ -682,7 +686,13 @@ def __getitem__(self, index):
self._get_target_boxes(anno, scan)

# Positive map for soft-token and contrastive losses
tokens_positive, positive_map = self._get_token_positive_map(anno)
if anno['dataset'] == 'scannet':
_, positive_map = self._get_token_positive_map(anno)
else:
assert anno['utterance'] == anno['span_utterance'] # sanity check
positive_map = np.zeros((MAX_NUM_OBJ, 256))
positive_map_ = np.array(anno['pred_pos_map']).reshape(-1, 256)
positive_map[:len(positive_map_)] = positive_map_

# Scene gt boxes
(
Expand Down Expand Up @@ -738,7 +748,6 @@ def __getitem__(self, index):
' '.join(anno['utterance'].replace(',', ' ,').split())
+ ' . not mentioned'
),
"tokens_positive": tokens_positive.astype(np.int64),
"positive_map": positive_map.astype(np.float32),
"relation": (
self._find_rel(anno['utterance'])
