6
6
)
7
7
from projects .ibase_project import IBaseProject
8
8
from utils .VariableClass import VariableClass
9
+ from ultralytics import YOLO
10
+
9
11
import yaml
12
+ import os
13
+ import torch
10
14
11
15
12
16
class BaseProject (IBaseProject ):
@@ -24,6 +28,7 @@ def __init__(self):
24
28
self .proj_dir = None
25
29
self .mapping = None
26
30
self .device = None
31
+ self .models = []
27
32
28
33
def condition_func (self , total_results ):
29
34
"""
@@ -59,6 +64,9 @@ def connect_models(self):
59
64
raise NotImplemented ('Should override this!!!' )
60
65
61
66
def __read_config__ (self , path ):
67
+ """
68
+ See ibase_project.py
69
+ """
62
70
with open (path , 'r' ) as file :
63
71
config = yaml .safe_load (file )
64
72
@@ -72,3 +80,27 @@ def __read_config__(self, path):
72
80
73
81
raise TypeError ('Error while reading configuration file, '
74
82
'make sure models and allowed_classes have the same size' )
83
+
84
+ def __connect_models__ (self ):
85
+ """
86
+ See ibase_project.py
87
+ """
88
+ _cur_dir = os .getcwd ()
89
+ # initialise the yolo model, additionally use the device parameter to specify the device to run the model on.
90
+ self .device = 'cuda' if torch .cuda .is_available () else 'cpu'
91
+ _cur_dir = pdirname (pabspath (__file__ ))
92
+ model_dir = pjoin (_cur_dir , f'../models' )
93
+ model_dir = pabspath (model_dir ) # normalise the link
94
+
95
+ models = []
96
+ for model_name in self ._config .get ('models' ):
97
+ model = YOLO (pjoin (model_dir , model_name )).to (self .device )
98
+ models .append (model )
99
+
100
+ return models
101
+
102
+ def reset_models (self ):
103
+ """
104
+ See ibase_project.py
105
+ """
106
+ self .models = self .__connect_models__ ()
0 commit comments