-
Notifications
You must be signed in to change notification settings - Fork 28
Open
Description
Hello, I want to add a mask branch on top of your object detection code. Is the following code correct? The problem is that the mask loss stays at 0 during training. Thank you!
# Detection branch: backbone + FPN features, then the RPN-style head.
feature_pyramid = self.build_base_network(input_img_batch)  # [P3, P4, P5, P6, P7]
rpn_cls_score, rpn_cls_prob, rpn_cnt_scores, rpn_box = self.rpn_net(feature_pyramid)

# Fuse the classification probability with the sigmoid of the second score
# branch (rpn_cnt_scores), broadcast over the last axis of rpn_cls_prob.
rpn_cnt_prob = tf.nn.sigmoid(rpn_cnt_scores)
rpn_cnt_prob = tf.expand_dims(rpn_cnt_prob, axis=2)
rpn_cnt_prob = tf.broadcast_to(
    rpn_cnt_prob,
    [self.batch_size, tf.shape(rpn_cls_prob)[1], tf.shape(rpn_cls_prob)[2]])
rpn_prob = rpn_cls_prob * rpn_cnt_prob

# Pyramid levels P3..P7 as an ordered list for ROI pooling.
ftmaps = [feature_pyramid['P%d' % level] for level in range(3, 8)]

# MASK
with tf.variable_scope('mask_target', reuse=tf.AUTO_REUSE):
    # rpn_box: (2, ?, 4)
    final_box = []
    for i in range(self.batch_size):
        boxes, _, _ = postprocess_detctions(rpn_bbox=rpn_box[i, :, :],
                                            rpn_cls_prob=rpn_prob[i, :, :],
                                            img_shape=img_shape,
                                            is_training=self.is_training)
        final_box.append(boxes)
    # NOTE(review): tf.stack needs every image to yield the same number of
    # boxes -- confirm postprocess_detctions pads/truncates to a fixed count.
    final_box = tf.stack(final_box, axis=0)

    # rois: (2, ?, 14, 14, 256)
    croped_rois = self.PyramidROIAlign(final_box, ftmaps, img_shape)

    mask = []
    for i in range(self.batch_size):
        m = croped_rois[i]
        # Explicit scope names are required for weight sharing: a slim layer
        # called without `scope` gets a uniquified default scope ('Conv',
        # 'Conv_1', ...), so the enclosing AUTO_REUSE never reuses anything
        # and every image in the batch would train its own mask head.
        # (Renaming the scopes changes variable names, so checkpoints saved
        # with the old default names will not restore these layers.)
        for j in range(4):
            m = slim.conv2d(m, 256, [3, 3], stride=1, padding='SAME',
                            activation_fn=tf.nn.relu,
                            scope='mask_conv%d' % j)
        # to 28 x 28
        m = slim.conv2d_transpose(m, 256, 2, stride=2, padding='VALID',
                                  activation_fn=tf.nn.relu,
                                  scope='mask_deconv')
        tf.add_to_collection('__TRANSPOSED__', m)
        m = slim.conv2d(m, cfgs.CLASS_NUM + 1, [1, 1], stride=1, padding='VALID',
                        activation_fn=None, scope='mask_logits')
        # NOTE(review): these are probabilities, not logits. If the mask loss
        # uses a *_with_logits op, feed it the pre-sigmoid tensor instead.
        m = tf.nn.sigmoid(m)
        mask.append(m)
    mask = tf.stack(mask, axis=0)
    # mask: (2, ?, 28, 28, 81)
    # NOTE(review): a mask loss of exactly 0 cannot be diagnosed from this
    # snippet; check that mask targets are built for these same boxes and that
    # positive ROIs are actually selected before the loss is reduced.
Metadata
Metadata
Assignees
Labels
No labels