Skip to content

Commit bf1d2bf

Browse files
authored
Refactoring (qubvel-org#528)
* Reorganize decoders, add deprecation for utils, add dataset * Fix imports
1 parent 4f94380 commit bf1d2bf

31 files changed

+174
-43
lines changed

requirements.txt

+3
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@ torchvision>=0.5.0
22
pretrainedmodels==0.7.4
33
efficientnet-pytorch==0.6.3
44
timm==0.4.12
5+
6+
tqdm
7+
opencv-python-headless

segmentation_models_pytorch/__init__.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from .unet import Unet
2-
from .unetplusplus import UnetPlusPlus
3-
from .manet import MAnet
4-
from .linknet import Linknet
5-
from .fpn import FPN
6-
from .pspnet import PSPNet
7-
from .deeplabv3 import DeepLabV3, DeepLabV3Plus
8-
from .pan import PAN
1+
from .decoders.unet import Unet
2+
from .decoders.unetplusplus import UnetPlusPlus
3+
from .decoders.manet import MAnet
4+
from .decoders.linknet import Linknet
5+
from .decoders.fpn import FPN
6+
from .decoders.pspnet import PSPNet
7+
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
8+
from .decoders.pan import PAN
99

1010
from . import encoders
11-
from . import utils
11+
from . import decoders
1212
from . import losses
1313

1414
from .__version__ import __version__

segmentation_models_pytorch/base/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def forward(self, x):
2323

2424
return masks
2525

26+
@torch.no_grad()
2627
def predict(self, x):
2728
"""Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
2829
@@ -36,7 +37,6 @@ def predict(self, x):
3637
if self.training:
3738
self.eval()
3839

39-
with torch.no_grad():
40-
x = self.forward(x)
40+
x = self.forward(x)
4141

4242
return x
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .oxford_pet import OxfordPetDataset, SimpleOxfordPetDataset
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import os
2+
import cv2
3+
import shutil
4+
import numpy as np
5+
from tqdm import tqdm
6+
from urllib.request import urlretrieve
7+
8+
9+
class OxfordPetDataset(torch.utils.data.Dataset):
10+
11+
def __init__(self, root, mode="train", transform=None):
12+
13+
assert mode in {"train", "valid", "test"}
14+
15+
self.root = root
16+
self.mode = mode
17+
self.transform = transform
18+
19+
self._download_dataset() # download only if it does not exist
20+
21+
self.images_directory = os.path.join(self.root, "images")
22+
self.masks_directory = os.path.join(self.root, "annotations", "trimaps")
23+
24+
self.filenames = self._read_split() # read train/valid/test splits
25+
26+
def __len__(self):
27+
return len(self.filenames)
28+
29+
def __getitem__(self, idx):
30+
31+
filename = self.filenames[idx]
32+
image_path = os.path.join(self.images_directory, filename + ".jpg")
33+
mask_path = os.path.join(self.masks_directory, filename + ".png")
34+
35+
image = cv2.imread(image_path)
36+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
37+
38+
trimap = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
39+
mask = self._preprocess_mask(trimap)
40+
41+
sample = dict(image=image, mask=mask, trimap=trimap)
42+
if self.transform is not None:
43+
sample = self.transform(**sample)
44+
45+
return sample
46+
47+
@staticmethod
48+
def _preprocess_mask(mask):
49+
mask = mask.astype(np.float32)
50+
mask[mask == 2.0] = 0.0
51+
mask[(mask == 1.0) | (mask == 3.0)] = 1.0
52+
return mask
53+
54+
def _read_split(self):
55+
split_filename = "test.txt" if self.mode == "test" else "trainval.txt"
56+
split_filepath = os.path.join(self.root, "annotations", split_filename)
57+
with open(split_filepath) as f:
58+
split_data = f.read().strip("\n").split("\n")
59+
filenames = [x.split(" ")[0] for x in split_data]
60+
if self.mode == "train": # 90% for train
61+
filenames = [x for i, x in enumerate(filenames) if i % 10 != 0]
62+
elif self.mode == "valid": # 10% for validation
63+
filenames = [x for i, x in enumerate(filenames) if i % 10 == 0]
64+
return filenames
65+
66+
def _download_dataset(self):
67+
68+
# load images
69+
filepath = os.path.join(self.root, "images.tar.gz")
70+
download_url(
71+
url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", filepath=filepath,
72+
)
73+
extract_archive(filepath)
74+
75+
# load annotations
76+
filepath = os.path.join(self.root, "annotations.tar.gz")
77+
download_url(
78+
url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", filepath=filepath,
79+
)
80+
extract_archive(filepath)
81+
82+
83+
class SimpleOxfordPetDataset(OxfordPetDataset):
84+
"""Dataset for example without augmentations and transforms"""
85+
86+
def __getitem__(self, *args, **kwargs):
87+
88+
sample = super().__getitem__(*args, **kwargs)
89+
90+
# resize images
91+
image = cv2.resize(sample["image"], (256, 256), cv2.INTER_LINEAR)
92+
mask = cv2.resize(sample["mask"], (256, 256), cv2.INTER_NEAREST)
93+
trimap = cv2.resize(sample["trimap"], (256, 256), cv2.INTER_NEAREST)
94+
95+
# convert to other format HWC -> CHW
96+
sample["image"] = np.moveaxis(image, -1, 0)
97+
sample["mask"] = np.expand_dims(mask, 0)
98+
sample["trimap"] = np.expand_dims(trimap, 0)
99+
100+
return sample
101+
102+
103+
class TqdmUpTo(tqdm):
104+
def update_to(self, b=1, bsize=1, tsize=None):
105+
if tsize is not None:
106+
self.total = tsize
107+
self.update(b * bsize - self.n)
108+
109+
110+
def download_url(url, filepath):
111+
directory = os.path.dirname(os.path.abspath(filepath))
112+
os.makedirs(directory, exist_ok=True)
113+
if os.path.exists(filepath):
114+
return
115+
116+
with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=os.path.basename(filepath)) as t:
117+
urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None)
118+
t.total = t.n
119+
120+
121+
def extract_archive(filepath):
122+
extract_dir = os.path.dirname(os.path.abspath(filepath))
123+
dst_dir = os.path.splitext(filepath)[0]
124+
if not os.path.exists(dst_dir):
125+
shutil.unpack_archive(filepath, extract_dir)

segmentation_models_pytorch/deeplabv3/model.py segmentation_models_pytorch/decoders/deeplabv3/model.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import torch.nn as nn
2-
1+
from torch import nn
32
from typing import Optional
4-
from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
5-
from ..base import SegmentationModel, SegmentationHead, ClassificationHead
6-
from ..encoders import get_encoder
73

4+
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
5+
from segmentation_models_pytorch.encoders import get_encoder
6+
from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
87

98
class DeepLabV3(SegmentationModel):
109
"""DeepLabV3_ implementation from "Rethinking Atrous Convolution for Semantic Image Segmentation"

segmentation_models_pytorch/fpn/model.py segmentation_models_pytorch/decoders/fpn/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Optional, Union
2-
from .decoder import FPNDecoder
3-
from ..base import SegmentationModel, SegmentationHead, ClassificationHead
4-
from ..encoders import get_encoder
52

3+
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
4+
from segmentation_models_pytorch.encoders import get_encoder
5+
from .decoder import FPNDecoder
66

77
class FPN(SegmentationModel):
88
"""FPN_ is a fully convolution neural network for image semantic segmentation.

segmentation_models_pytorch/linknet/decoder.py segmentation_models_pytorch/decoders/linknet/decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from ..base import modules
3+
from segmentation_models_pytorch.base import modules
44

55

66
class TransposeX2(nn.Sequential):

segmentation_models_pytorch/linknet/model.py segmentation_models_pytorch/decoders/linknet/model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Optional, Union
2+
3+
from segmentation_models_pytorch.base import SegmentationHead, SegmentationModel, ClassificationHead
4+
from segmentation_models_pytorch.encoders import get_encoder
25
from .decoder import LinknetDecoder
3-
from ..base import SegmentationHead, SegmentationModel, ClassificationHead
4-
from ..encoders import get_encoder
56

67

78
class Linknet(SegmentationModel):

segmentation_models_pytorch/manet/decoder.py segmentation_models_pytorch/decoders/manet/decoder.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
22
import torch.nn as nn
33
import torch.nn.functional as F
4-
from ..base import modules as md
4+
5+
from segmentation_models_pytorch.base import modules as md
56

67

78
class PAB(nn.Module):

segmentation_models_pytorch/manet/model.py segmentation_models_pytorch/decoders/manet/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Optional, Union, List
2+
3+
from segmentation_models_pytorch.encoders import get_encoder
4+
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
25
from .decoder import MAnetDecoder
3-
from ..encoders import get_encoder
4-
from ..base import SegmentationModel
5-
from ..base import SegmentationHead, ClassificationHead
66

77

88
class MAnet(SegmentationModel):

segmentation_models_pytorch/pan/model.py segmentation_models_pytorch/decoders/pan/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Optional, Union
2+
3+
from segmentation_models_pytorch.encoders import get_encoder
4+
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
25
from .decoder import PANDecoder
3-
from ..encoders import get_encoder
4-
from ..base import SegmentationModel
5-
from ..base import SegmentationHead, ClassificationHead
66

77

88
class PAN(SegmentationModel):

segmentation_models_pytorch/pspnet/decoder.py segmentation_models_pytorch/decoders/pspnet/decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
from ..base import modules
5+
from segmentation_models_pytorch.base import modules
66

77

88
class PSPBlock(nn.Module):

segmentation_models_pytorch/pspnet/model.py segmentation_models_pytorch/decoders/pspnet/model.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from typing import Optional, Union
22

3+
from segmentation_models_pytorch.encoders import get_encoder
4+
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
35
from .decoder import PSPDecoder
4-
from ..encoders import get_encoder
5-
6-
from ..base import SegmentationModel
7-
from ..base import SegmentationHead, ClassificationHead
86

97

108
class PSPNet(SegmentationModel):

segmentation_models_pytorch/unet/decoder.py segmentation_models_pytorch/decoders/unet/decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
from ..base import modules as md
5+
from segmentation_models_pytorch.base import modules as md
66

77

88
class DecoderBlock(nn.Module):

segmentation_models_pytorch/unet/model.py segmentation_models_pytorch/decoders/unet/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Optional, Union, List
2+
3+
from segmentation_models_pytorch.encoders import get_encoder
4+
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
25
from .decoder import UnetDecoder
3-
from ..encoders import get_encoder
4-
from ..base import SegmentationModel
5-
from ..base import SegmentationHead, ClassificationHead
66

77

88
class Unet(SegmentationModel):

segmentation_models_pytorch/unetplusplus/decoder.py segmentation_models_pytorch/decoders/unetplusplus/decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
from ..base import modules as md
5+
from segmentation_models_pytorch.base import modules as md
66

77

88
class DecoderBlock(nn.Module):

segmentation_models_pytorch/unetplusplus/model.py segmentation_models_pytorch/decoders/unetplusplus/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Optional, Union, List
2+
3+
from segmentation_models_pytorch.encoders import get_encoder
4+
from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
25
from .decoder import UnetPlusPlusDecoder
3-
from ..encoders import get_encoder
4-
from ..base import SegmentationModel
5-
from ..base import SegmentationHead, ClassificationHead
66

77

88
class UnetPlusPlus(SegmentationModel):

segmentation_models_pytorch/metrics/.gitkeep

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import warnings
2+
warnings.warn("`smp.utils` module is deprecated and will be removed in future releases.", DeprecationWarning)
3+
14
from . import train
25
from . import losses
3-
from . import metrics
6+
from . import metrics

0 commit comments

Comments
 (0)