@@ -7,7 +7,7 @@ The main features of this library are:
77
88 - High level API (just two lines to create neural network)
99 - 4 models architectures for binary and multi class segmentation (including legendary Unet)
10- - 31 available encoders for each architecture
10+ - 45 available encoders for each architecture
1111 - All encoders have pre-trained weights for faster and better convergence
1212
1313### Table of content
@@ -16,10 +16,14 @@ The main features of this library are:
1616 3 . [ Models] ( #models )
1717 1 . [ Architectures] ( #architectires )
1818 2 . [ Encoders] ( #encoders )
19- 3 . [ Pretrained weights] ( #weights )
2019 4 . [ Models API] ( #api )
20+ 1 . [ Input channels] ( #input-channels )
21+ 2 . [ Auxiliary classification output] ( #auxiliary-classification-output )
22+ 3 . [ Depth] ( #depth )
2123 5 . [ Installation] ( #installation )
22- 6 . [ License] ( #license )
24+ 6 . [ Competitions won with the library] ( #competitions-won-with-the-library )
25+ 7 . [ License] ( #license )
26+ 8 . [ Contributing] ( #contributing )
2327
2428### Quick start <a name =" start " ></a >
2529Since the library is built on the PyTorch framework, created segmentation model is just a PyTorch nn.Module, which can be created as easy as:
@@ -60,33 +64,95 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
6064
6165#### Encoders <a name =" encoders " ></a >
6266
63- | Type | Encoder names |
64- | ------------| ---------------------------------------------------------------------------------------------|
65- | VGG | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn |
66- | DenseNet | densenet121, densenet169, densenet201, densenet161 |
67- | DPN | dpn68, dpn68b, dpn92, dpn98, dpn107, dpn131 |
68- | Inception | inceptionresnetv2 |
69- | ResNet | resnet18, resnet34, resnet50, resnet101, resnet152 |
70- | ResNeXt | resnext50_32x4d, resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
71- | SE-ResNet | se_resnet50, se_resnet101, se_resnet152 |
72- | SE-ResNeXt | se_resnext50_32x4d, se_resnext101_32x4d |
73- | SENet | senet154 |
74- | EfficientNet | efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7
75-
76- #### Weights <a name =" weights " ></a >
77-
78- | Weights name | Encoder names |
79- | ---------------------------------------------------------------------------| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
80- | imagenet+5k | dpn68b, dpn92, dpn107 |
81- | imagenet | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, <br > densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, <br > inceptionresnetv2, <br > resnet18, resnet34, resnet50, resnet101, resnet152, <br > resnext50_32x4d, resnext101_32x8d, <br > se_resnet50, se_resnet101, se_resnet152, <br > se_resnext50_32x4d, se_resnext101_32x4d, <br > senet154, <br > efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7 |
82- | [ instagram] ( https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ ) | resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
67+ | Encoder | Weights | Params, M |
68+ | --------------------------------| :------------------------------:| :------------------------------:|
69+ | resnet18 | imagenet | 11M |
70+ | resnet34 | imagenet | 21M |
71+ | resnet50 | imagenet | 23M |
72+ | resnet101 | imagenet | 42M |
73+ | resnet152 | imagenet | 58M |
74+ | resnext50_32x4d | imagenet | 22M |
75+ | resnext101_32x8d | imagenet<br >instagram | 86M |
76+ | resnext101_32x16d | instagram | 191M |
77+ | resnext101_32x32d | instagram | 466M |
78+ | resnext101_32x48d | instagram | 826M |
79+ | dpn68 | imagenet | 11M |
80+ | dpn68b | imagenet+5k | 11M |
81+ | dpn92 | imagenet+5k | 34M |
82+ | dpn98 | imagenet | 58M |
83+ | dpn107 | imagenet+5k | 84M |
84+ | dpn131 | imagenet | 76M |
85+ | vgg11 | imagenet | 9M |
86+ | vgg11_bn | imagenet | 9M |
87+ | vgg13 | imagenet | 9M |
88+ | vgg13_bn | imagenet | 9M |
89+ | vgg16 | imagenet | 14M |
90+ | vgg16_bn | imagenet | 14M |
91+ | vgg19 | imagenet | 20M |
92+ | vgg19_bn | imagenet | 20M |
93+ | senet154 | imagenet | 113M |
94+ | se_resnet50 | imagenet | 26M |
95+ | se_resnet101 | imagenet | 47M |
96+ | se_resnet152 | imagenet | 64M |
97+ | se_resnext50_32x4d | imagenet | 25M |
98+ | se_resnext101_32x4d | imagenet | 46M |
99+ | densenet121 | imagenet | 6M |
100+ | densenet169 | imagenet | 12M |
101+ | densenet201 | imagenet | 18M |
102+ | densenet161 | imagenet | 26M |
103+ | inceptionresnetv2 | imagenet<br >imagenet+background | 54M |
104+ | inceptionv4 | imagenet<br >imagenet+background | 41M |
105+ | efficientnet-b0 | imagenet | 4M |
106+ | efficientnet-b1 | imagenet | 6M |
107+ | efficientnet-b2 | imagenet | 7M |
108+ | efficientnet-b3 | imagenet | 10M |
109+ | efficientnet-b4 | imagenet | 17M |
110+ | efficientnet-b5 | imagenet | 28M |
111+ | efficientnet-b6 | imagenet | 40M |
112+ | efficientnet-b7 | imagenet | 63M |
113+ | mobilenet_v2 | imagenet | 2M |
83114
84115### Models API <a name =" api " ></a >
116+
85117 - ` model.encoder ` - pretrained backbone to extract features of different spatial resolution
86- - ` model.decoder ` - segmentation head, depends on models architecture (` Unet ` /` Linknet ` /` PSPNet ` /` FPN ` )
87- - ` model.activation ` - output activation function, one of ` sigmoid ` , ` softmax `
88- - ` model.forward(x) ` - sequentially pass ` x ` through model\` s encoder and decoder (return logits!)
89- - ` model.predict(x) ` - inference method, switch model to ` .eval() ` mode, call ` .forward(x) ` and apply activation function with ` torch.no_grad() `
118+ - ` model.decoder ` - depends on models architecture (` Unet ` /` Linknet ` /` PSPNet ` /` FPN ` )
119+ - ` model.segmentation_head ` - last block to produce required number of mask channels (include also optional upsampling and activation)
120+ - ` model.classification_head ` - optional block which create classification head on top of encoder
121+ - ` model.forward(x) ` - sequentially pass ` x ` through model\` s encoder, decoder and segmentation head (and classification head if specified)
122+
123+ ##### Input channels
124+ Input channels parameter allow you to create models, which process tensors with arbitrary number of channels.
125+ If you use pretrained weights from imagenet - weights of first convolution will be reused for
126+ 1- or 2- channels inputs, for input channels > 4 weights of first convolution will be initialized randomly.
127+ ``` python
128+ model = smp.FPN(' resnet34' , in_channels = 1 )
129+ mask = model(torch.ones([1 , 1 , 64 , 64 ]))
130+ ```
131+
132+ ##### Auxiliary classification output
133+ All models support ` aux_params ` parameters, which is default set to ` None ` .
134+ If ` aux_params = None ` than classification auxiliary output is not created, else
135+ model produce not only ` mask ` , but also ` label ` output with shape ` NC ` .
136+ Classification head consist of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be
137+ configured by ` aux_params ` as follows:
138+ ``` python
139+ aux_params= dict (
140+ pooling = ' avg' , # one of 'avg', 'max'
141+ dropout = 0.5 , # dropout ratio, default is None
142+ activation = ' sigmoid' , # activation function, default is None
143+ classes = 4 , # define number of output labels
144+ )
145+ model = smp.Unet(' resnet34' , classes = 4 , aux_params = aux_params)
146+ mask, label = model(x)
147+ ```
148+
149+ ##### Depth
150+ Depth parameter specify a number of downsampling operations in encoder, so you can make
151+ your model lighted if specify smaller ` depth ` .
152+ ``` python
153+ model = smp.FPN(' resnet34' , depth = 4 )
154+ ```
155+
90156
91157### Installation <a name =" installation " ></a >
92158PyPI version:
@@ -97,11 +163,24 @@ Latest version from source:
97163` ` ` bash
98164$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
99165` ` ` `
166+
167+ # ## Competitions won with the library
168+
169+ ` Segmentation Models` package is widely used in the image segmentation competitions.
170+ [Here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/HALLOFFAME.md) you can find competitions, names of the winners and links to their solutions.
171+
172+
100173# ## License <a name="license"></a>
101174Project is distributed under [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/master/LICENSE)
102175
103- # ## Run tests
176+
177+ # ## Contributing
178+
179+ # #### Run test
180+ ` ` ` bash
181+ $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
182+ ```
183+ ##### Generate table
104184``` bash
105- $ docker build -f docker/Dockerfile.dev -t smp:dev .
106- $ docker run --rm smp:dev pytest -p no:cacheprovider
185+ $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py
107186```
0 commit comments