Skip to content

Commit 346052b

Browse files
authored
Merge pull request #2 from AlbertSuarez/add-underlying-models
Add underlying models
2 parents 113acf1 + 5fe9ac6 commit 346052b

File tree

6 files changed

+10
-6
lines changed

6 files changed

+10
-6
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ _That's it_! You have ObjectCut running on port 80 routing traffic using _traefi
113113
114114
### Change underlying model
115115
116-
This project was built using [BASNet](https://github.com/NathanUA/BASNet) as the model for inferring the Salient Object Detection. However, in order to test other ones we added the support to select also [U^2-Net](https://github.com/NathanUA/U-2-Net), also implemented by [Xuebin Qin](https://github.com/NathanUA), in the Inference container specifying it as a environment variable called `MODEL`. You can do that setting your model name at [docker-compose.yml](docker-compose.yml):
116+
This project was built using [BASNet](https://github.com/NathanUA/BASNet) as the model for inferring the Salient Object Detection. However, in order to test other ones we added the support to select also the different versions of [U^2-Net](https://github.com/NathanUA/U-2-Net) (`U2NET`, `U2NETP` and `U2NETPORTRAIT`), also implemented by [Xuebin Qin](https://github.com/NathanUA), in the Inference container specifying it as a environment variable called `MODEL`. You can do that setting your model name at [docker-compose.yml](docker-compose.yml):
117117
118118
```yaml
119119
inference:
@@ -130,7 +130,7 @@ inference:
130130
- object_cut
131131
restart: always
132132
environment:
133-
- MODEL=BASNet # Can also be `U2NET`
133+
- MODEL=BASNet # Can also be `U2NET`, `U2NETP` or `U2NETPORTRAIT`
134134
```
135135

136136
### Integrations

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ services:
4444
- object_cut
4545
restart: always
4646
environment:
47-
- MODEL=BASNet # Can also be `U2NET`
47+
- MODEL=U2NETP # Can also be `BASNet`, `U2NET` or `U2NETPORTRAIT`
4848
labels:
4949
- 'traefik.enable=true'
5050
- 'traefik.docker.network=object_cut'

inference/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ ADD ./requirements.lock ${HOME}/requirements.lock
4646
RUN ${HOME}/gdrive_download.sh 1s52ek_4YTDRt_EOkx1FS53u-vJa0c4nu ${HOME}/data/basnet.pth
4747
RUN ${HOME}/gdrive_download.sh 1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ ${HOME}/data/u2net.pth
4848
RUN ${HOME}/gdrive_download.sh 1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy ${HOME}/data/u2netp.pth
49+
RUN ${HOME}/gdrive_download.sh 1IG3HdpcRiDoWNookbncQjeaPN28t90yW ${HOME}/data/u2netportrait.pth
4950

5051
# Install dependencies
5152
RUN python3 -m pip install pip --upgrade

inference/src/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
# Load model
2828
model_name = os.environ.get('MODEL', Model.BASNet.name) # BASNet as default
29+
log.info('Model name: [{}]'.format(model_name))
2930
assert model_name in Model.list()
3031
model_path = os.path.join('data', '{}.pth'.format(model_name.lower()))
3132
log.info('Model path: [{}]'.format(model_path))

inference/src/utils/model_enum.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from enum import Enum
22

3-
from src.u2_net.model import U2NET
3+
from src.u2_net.model import U2NET, U2NETP
44
from src.bas_net.model import BASNet
55

66

77
class Model(Enum):
88

99
U2NET = U2NET # U2NET
10+
U2NETP = U2NETP # U2NETP
11+
U2NETPORTRAIT = U2NET # U2NETPORTRAIT
1012
BASNet = BASNet # BASNet
1113

1214
def __str__(self):
1315
return self.name
1416

1517
@staticmethod
1618
def list():
17-
return [m.name for m in Model]
19+
return [m for m in Model.__members__.keys()]

multiplexer/test/api/test_remove.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class MultiplexerRemoveTest(BaseTestClass):
88

99
def setUp(self):
1010
self.secret_access = env.get_secret_access()
11-
self.img_url = 'https://objectcut.com/docs/images/object-cut.png'
11+
self.img_url = 'https://objectcut.com/assets/img/raven.jpg'
1212
self.img_url_wrong = 'https://example.com/not-existing.jpg'
1313
self.img_base64_wrong = 'not-a-base64'
1414
self.img_path = os.path.join('test', 'data', 'person.jpg')

0 commit comments

Comments
 (0)