Commit ce01725

istranical and gitbook-bot authored and committed
GITBOOK-11: change request with no subject merged in GitBook
1 parent fe32944 commit ce01725

File tree

1 file changed: +24 −7 lines changed

tutorials/deep-learning/training-models/training-an-object-detection-and-segmentation-model-in-pytorch.md

Lines changed: 24 additions & 7 deletions
@@ -48,6 +48,10 @@ num_classes = len(ds_train.categories.info.class_names)
 
 For complex datasets like this one, it's critical to carefully define the pre-processing function that returns the torch tensors used for training. Here we use an [Albumentations](https://github.com/albumentations-team/albumentations) augmentation pipeline combined with additional pre-processing steps that are necessary for this particular model.
 
+{% hint style="danger" %}
+**Note:** This tutorial assumes that the number of masks and bounding boxes for each image is equal.
+{% endhint %}
+
 ```python
 # Augmentation pipeline using Albumentations
 tform_train = A.Compose([
@@ -58,7 +62,7 @@ tform_train = A.Compose([
 ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels', 'bbox_ids'], min_area=25, min_visibility=0.6)) # 'label_fields' and 'bbox_ids' are all the fields that will be cut when a bounding box is cut.
 
 
-# Transformation function for pre-processing the deeplake sample before sending it to the model
+# Transformation function for pre-processing the Deep Lake sample before sending it to the model
 def transform(sample_in):
 
     # Convert boxes to Pascal VOC format
@@ -71,12 +75,25 @@ def transform(sample_in):
 
     # Pass all data to the Albumentations transformation
     # Mask must be converted to a list
-    transformed = tform_train(image = images,
-                              masks = [sample_in['masks'][:,:,i].astype(np.uint8) for i in range(sample_in['masks'].shape[2])],
-                              bboxes = boxes,
-                              bbox_ids = np.arange(boxes.shape[0]),
-                              class_labels = sample_in['categories'],
-                              )
+    masks = sample_in['masks']
+    mask_shape = masks.shape
+
+    # This if-else statement was not necessary in Albumentations <1.3.x, because the empty-mask scenario was handled gracefully inside Albumentations. In Albumentations >=1.3.x, an empty list of masks fails.
+    if mask_shape[2] > 0:
+        transformed = tform_train(image = images,
+                                  masks = [masks[:,:,i].astype(np.uint8) for i in range(mask_shape[2])],
+                                  bboxes = boxes,
+                                  bbox_ids = np.arange(boxes.shape[0]),
+                                  class_labels = sample_in['categories'],
+                                  )
+    else:
+        transformed = tform_train(image = images,
+                                  bboxes = boxes,
+                                  bbox_ids = np.arange(boxes.shape[0]),
+                                  class_labels = sample_in['categories'],
+                                  )
+
+
 
     # Convert boxes and labels from lists to torch tensors, because Albumentations does not do that automatically.
     # Be very careful with rounding and casting to integers, because that can create bounding boxes with invalid dimensions.
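The branching introduced in this hunk can be sketched in isolation. Below is a minimal, library-free sketch of the same guard: `fake_transform` and `transform_sample` are hypothetical stand-ins for `tform_train` and `transform`, and `fake_transform` imitates (under the assumption stated in the diff's own comment) the Albumentations >=1.3.x behavior of rejecting an empty `masks` list.

```python
# Hypothetical stand-in for the Albumentations pipeline (`tform_train` in
# the tutorial). It rejects an empty mask list, mimicking the >=1.3.x
# behavior described in the commit, and otherwise echoes its inputs.
def fake_transform(image, bboxes, class_labels, masks=None):
    if masks is not None and len(masks) == 0:
        raise ValueError("empty mask list is not supported")
    return {"image": image, "bboxes": bboxes,
            "class_labels": class_labels, "masks": masks or []}

# Mirror of the if/else guard added in this commit: only pass `masks`
# to the pipeline when there is at least one mask for the image.
def transform_sample(image, masks, bboxes, labels):
    if len(masks) > 0:
        return fake_transform(image=image, masks=masks,
                              bboxes=bboxes, class_labels=labels)
    else:
        return fake_transform(image=image,
                              bboxes=bboxes, class_labels=labels)

# A sample with no annotated instances no longer raises.
out = transform_sample(image=[[0]], masks=[], bboxes=[], labels=[])
```

Without the guard, the empty-mask sample above would hit the `ValueError`; with it, the pipeline is simply called without the `masks` argument.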
