Skip to content

Commit a469f86

Browse files
authored
Add metrics (qubvel-org#531)
* Add metrics * Add docs * Add example notebook
1 parent bf1d2bf commit a469f86

13 files changed

+4751
-31
lines changed

README.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ The main features of this library are:
1212

1313
- High level API (just two lines to create a neural network)
1414
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
15-
- 113 available encoders
15+
- 113 available encoders (and 400+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
1616
- All encoders have pre-trained weights for faster and better convergence
17+
- Popular metrics and losses for training routines
1718

1819
### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
1920

@@ -68,9 +69,10 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
6869
Congratulations! You are done! Now you can train your model with your favorite framework!
6970

7071
### 💡 Examples <a name="examples"></a>
72+
- Training model for pets binary segmentation with Pytorch-Lightning [notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/master/examples/binary_segmentation_intro.ipynb) and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/examples/binary_segmentation_intro.ipynb)
7173
- Training model for cars segmentation on CamVid dataset [here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/examples/cars%20segmentation%20(camvid).ipynb).
72-
- Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/master/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/segmentation-tutorial.ipynb)
73-
- Training SMP model with [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io) framework - [here](https://github.com/ternaus/cloths_segmentation) (clothes binary segmentation by [@teranus](https://github.com/ternaus)).
74+
- Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb)
75+
- Training SMP model with [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io) framework - [here](https://github.com/ternaus/cloths_segmentation) (clothes binary segmentation by [@ternaus](https://github.com/ternaus)).
7476

7577
### 📦 Models <a name="models"></a>
7678

docs/conf.py

+8
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def get_version():
4646
'sphinx.ext.napoleon',
4747
'sphinx.ext.viewcode',
4848
'sphinx.ext.mathjax',
49+
'autodocsumm',
4950
]
5051

5152
# Add any paths that contain templates here, relative to this directory.
@@ -95,6 +96,8 @@ def get_version():
9596
'tqdm',
9697
'numpy',
9798
'timm',
99+
'cv2',
100+
'PIL',
98101
'pretrainedmodels',
99102
'torchvision',
100103
'efficientnet-pytorch',
@@ -118,3 +121,8 @@ def f(app, obj, bound_method):
118121

119122
def setup(app):
120123
app.connect('autodoc-before-process-signature', f)
124+
125+
126+
# Custom configuration --------------------------------------------------------
127+
128+
autodoc_member_order = 'bysource'

docs/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Welcome to Segmentation Models's documentation!
1616
encoders
1717
encoders_timm
1818
losses
19+
metrics
1920
insights
2021

2122

docs/metrics.rst

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
📈 Metrics
2+
==========
3+
4+
Functional metrics
5+
~~~~~~~~~~~~~~~~~~
6+
.. automodule:: segmentation_models_pytorch.metrics.functional
7+
:members:
8+
:autosummary:

docs/requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
faculty-sphinx-theme==0.2.2
2-
six==1.15.0
2+
six==1.15.0
3+
autodocsumm

examples/binary_segmentation_intro.ipynb

+4,090
Large diffs are not rendered by default.

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ efficientnet-pytorch==0.6.3
44
timm==0.4.12
55

66
tqdm
7-
opencv-python-headless
7+
pillow

segmentation_models_pytorch/__init__.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
from . import datasets
2+
from . import encoders
3+
from . import decoders
4+
from . import losses
5+
from . import metrics
6+
17
from .decoders.unet import Unet
28
from .decoders.unetplusplus import UnetPlusPlus
39
from .decoders.manet import MAnet
@@ -7,27 +13,23 @@
713
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
814
from .decoders.pan import PAN
915

10-
from . import encoders
11-
from . import decoders
12-
from . import losses
13-
1416
from .__version__ import __version__
1517

16-
from typing import Optional
17-
import torch
18+
# some private imports for create_model function
19+
from typing import Optional as _Optional
20+
import torch as _torch
1821

1922

2023
def create_model(
2124
arch: str,
2225
encoder_name: str = "resnet34",
23-
encoder_weights: Optional[str] = "imagenet",
26+
encoder_weights: _Optional[str] = "imagenet",
2427
in_channels: int = 3,
2528
classes: int = 1,
2629
**kwargs,
27-
) -> torch.nn.Module:
28-
"""Models wrapper. Allows to create any model just with parametes
29-
30-
"""
30+
) -> _torch.nn.Module:
31+
"""Models entrypoint, allows to create any model architecture just with
32+
parameters, without using its class"""
3133

3234
archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
3335
archs_dict = {a.__name__.lower(): a for a in archs}

segmentation_models_pytorch/datasets/oxford_pet.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
2-
import cv2
2+
import torch
33
import shutil
44
import numpy as np
5+
6+
from PIL import Image
57
from tqdm import tqdm
68
from urllib.request import urlretrieve
79

@@ -15,8 +17,6 @@ def __init__(self, root, mode="train", transform=None):
1517
self.root = root
1618
self.mode = mode
1719
self.transform = transform
18-
19-
self._download_dataset() # download only if it does not exist
2020

2121
self.images_directory = os.path.join(self.root, "images")
2222
self.masks_directory = os.path.join(self.root, "annotations", "trimaps")
@@ -32,10 +32,9 @@ def __getitem__(self, idx):
3232
image_path = os.path.join(self.images_directory, filename + ".jpg")
3333
mask_path = os.path.join(self.masks_directory, filename + ".png")
3434

35-
image = cv2.imread(image_path)
36-
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
35+
image = np.array(Image.open(image_path).convert("RGB"))
3736

38-
trimap = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
37+
trimap = np.array(Image.open(mask_path))
3938
mask = self._preprocess_mask(trimap)
4039

4140
sample = dict(image=image, mask=mask, trimap=trimap)
@@ -63,34 +62,33 @@ def _read_split(self):
6362
filenames = [x for i, x in enumerate(filenames) if i % 10 == 0]
6463
return filenames
6564

66-
def _download_dataset(self):
65+
@staticmethod
66+
def download(root):
6767

6868
# load images
69-
filepath = os.path.join(self.root, "images.tar.gz")
69+
filepath = os.path.join(root, "images.tar.gz")
7070
download_url(
7171
url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", filepath=filepath,
7272
)
7373
extract_archive(filepath)
7474

7575
# load annotations
76-
filepath = os.path.join(self.root, "annotations.tar.gz")
76+
filepath = os.path.join(root, "annotations.tar.gz")
7777
download_url(
7878
url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", filepath=filepath,
7979
)
8080
extract_archive(filepath)
8181

8282

8383
class SimpleOxfordPetDataset(OxfordPetDataset):
84-
"""Dataset for example without augmentations and transforms"""
85-
8684
def __getitem__(self, *args, **kwargs):
8785

8886
sample = super().__getitem__(*args, **kwargs)
8987

9088
# resize images
91-
image = cv2.resize(sample["image"], (256, 256), cv2.INTER_LINEAR)
92-
mask = cv2.resize(sample["mask"], (256, 256), cv2.INTER_NEAREST)
93-
trimap = cv2.resize(sample["trimap"], (256, 256), cv2.INTER_NEAREST)
89+
image = np.array(Image.fromarray(sample["image"]).resize((256, 256), Image.LINEAR))
90+
mask = np.array(Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST))
91+
trimap = np.array(Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST))
9492

9593
# convert to other format HWC -> CHW
9694
sample["image"] = np.moveaxis(image, -1, 0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from .functional import (
2+
get_stats,
3+
fbeta_score,
4+
f1_score,
5+
iou_score,
6+
accuracy,
7+
precision,
8+
recall,
9+
sensitivity,
10+
specificity,
11+
balanced_accuracy,
12+
positive_predictive_value,
13+
negative_predictive_value,
14+
false_negative_rate,
15+
false_positive_rate,
16+
false_discovery_rate,
17+
false_omission_rate,
18+
positive_likelihood_ratio,
19+
negative_likelihood_ratio,
20+
)

0 commit comments

Comments
 (0)