-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDisCustomSession.py
129 lines (99 loc) · 3.66 KB
/
DisCustomSession.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
from typing import List
import cv2
import onnxruntime as ort
from PIL import Image
from PIL.Image import Image as PILImage
from rembg.sessions.base import BaseSession
class DisCustomSession(BaseSession):
    """
    Session that runs an ISNet-style model loaded from a caller-supplied file.

    The model location comes from the ``model_path`` keyword argument; see
    ``download_models`` for how it is resolved.
    """

    def __init__(
        self,
        model_name: str,
        sess_opts: ort.SessionOptions = None,
        providers=None,
        *args,
        **kwargs
    ):
        """
        Validate ``model_path`` and build the underlying session.

        When ``sess_opts`` is omitted, a default ``ort.SessionOptions`` is
        created. If the ``OMP_NUM_THREADS`` environment variable is set, its
        integer value is applied to both the inter-op and intra-op thread
        counts of that default object. Caller-supplied options are passed
        through untouched.

        Parameters:
        model_name (str): The name of the model.
        sess_opts (ort.SessionOptions): Session options, or None to build
        defaults as described above.
        providers: Execution providers, forwarded to the base class.
        *args: Additional positional arguments, forwarded to the base class.
        **kwargs: Additional keyword arguments; must contain ``model_path``.

        Raises:
        ValueError: If ``model_path`` is missing or None, or if
        ``OMP_NUM_THREADS`` is set to something ``int()`` cannot parse.
        """
        if kwargs.get("model_path") is None:
            raise ValueError("model_path is required")

        if sess_opts is None:
            sess_opts = ort.SessionOptions()
            omp_threads = os.environ.get("OMP_NUM_THREADS")
            if omp_threads is not None:
                thread_count = int(omp_threads)
                sess_opts.inter_op_num_threads = thread_count
                sess_opts.intra_op_num_threads = thread_count

        super().__init__(model_name, sess_opts, providers, *args, **kwargs)

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Run the model on an image and return a binary mask.

        Parameters:
        img (PILImage): The input image.
        *args: Unused.
        **kwargs: Unused.

        Returns:
        List[PILImage]: A single-element list holding an "L"-mode mask
        resized to ``img.size``.
        """
        run_inference = self.inner_session.run
        # NOTE(review): the three tuples are presumably mean, std and target
        # size for BaseSession.normalize - confirm against rembg.
        model_input = self.normalize(
            img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)
        )
        outputs = run_inference(None, model_input)

        # Take channel 0 of the first output and flatten it to a 1024x1024 map.
        saliency = outputs[0][:, 0, :, :].reshape((1024, 1024))

        # Invert the scaled map, then apply an inverse binary threshold:
        # inverted values above 220 become 0, everything else becomes 255.
        inverted = 255.0 - saliency * 255.0
        _, binary = cv2.threshold(inverted, 220, 255, cv2.THRESH_BINARY_INV)

        mask = Image.fromarray(binary.astype("uint8"), mode="L")
        return [mask.resize(img.size, Image.Resampling.LANCZOS)]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Resolve the local model file; nothing is downloaded here.

        Parameters:
        *args: Unused.
        **kwargs: May contain ``model_path``, the path to the model file.

        Returns:
        str: The absolute, user-expanded ``model_path``, or None when
        ``model_path`` is missing or None.
        """
        custom_path = kwargs.get("model_path")
        if custom_path is not None:
            return os.path.abspath(os.path.expanduser(custom_path))
        return None

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session type.

        Parameters:
        *args: Unused.
        **kwargs: Unused.

        Returns:
        str: Always "isnet-custom".
        """
        return "isnet-custom"
if __name__ == "__main__":
    # Command-line demo: strip the background from one image using the custom
    # model whose path is configured in config.ini, then display the result.
    from configparser import ConfigParser
    from argparse import ArgumentParser
    from rembg import remove

    # ConfigParser.read() silently skips a missing file, so the lookup below
    # falls back to None when config.ini or the option is absent.
    config = ConfigParser()
    config.read('config.ini')

    parser = ArgumentParser()
    parser.add_argument("image", type=str)
    args = parser.parse_args()

    # A None model_path makes DisCustomSession raise
    # ValueError("model_path is required").
    model_path = config.get('Settings', 'rembg_model_path', fallback=None)
    session = DisCustomSession("isnet-custom", model_path=model_path,
                               providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    # Open the input inside a context manager so its file handle is released
    # as soon as processing finishes, rather than staying open until garbage
    # collection. Only the result of remove() is used after the block.
    with Image.open(args.image) as source_image:
        image = remove(source_image, session=session)
    image.show()