@@ -163,6 +163,13 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
163163 mask_head = mask_head ,
164164 mask_sampler = mask_sampler_obj ,
165165 mask_roi_aligner = mask_roi_aligner_obj ,
166+ class_agnostic_bbox_pred = detection_head_config .class_agnostic_bbox_pred ,
167+ cascade_class_ensemble = detection_head_config .cascade_class_ensemble ,
168+ min_level = model_config .min_level ,
169+ max_level = model_config .max_level ,
170+ num_scales = model_config .anchor .num_scales ,
171+ aspect_ratios = model_config .anchor .aspect_ratios ,
172+ anchor_size = model_config .anchor .anchor_size ,
166173 outer_boxes_scale = model_config .outer_boxes_scale ,
167174 use_gt_boxes_for_masks = model_config .use_gt_boxes_for_masks )
168175 return model
@@ -193,4 +200,9 @@ def build_model(self):
193200 if self .task_config .freeze_backbone :
194201 model .backbone .trainable = False
195202
203+ # Builds the model through warm-up call.
204+ dummy_images = tf .keras .Input (self .task_config .model .input_size )
205+ dummy_image_shape = tf .keras .layers .Input ([2 ])
206+ _ = model (dummy_images , image_shape = dummy_image_shape , training = False )
207+
196208 return model
0 commit comments