Skip to content

Commit e35b2a8

Browse files
authored
Merge pull request #18 from uug-ai/fix_bug
Fix occasional crash bug due to different video resolutions
2 parents 43103c3 + 4627371 commit e35b2a8

File tree

5 files changed

+74
-50
lines changed

5 files changed

+74
-50
lines changed

projects/base_project.py

+32
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
)
77
from projects.ibase_project import IBaseProject
88
from utils.VariableClass import VariableClass
9+
from ultralytics import YOLO
10+
911
import yaml
12+
import os
13+
import torch
1014

1115

1216
class BaseProject(IBaseProject):
@@ -24,6 +28,7 @@ def __init__(self):
2428
self.proj_dir = None
2529
self.mapping = None
2630
self.device = None
31+
self.models = []
2732

2833
def condition_func(self, total_results):
2934
"""
@@ -59,6 +64,9 @@ def connect_models(self):
5964
raise NotImplemented('Should override this!!!')
6065

6166
def __read_config__(self, path):
67+
"""
68+
See ibase_project.py
69+
"""
6270
with open(path, 'r') as file:
6371
config = yaml.safe_load(file)
6472

@@ -72,3 +80,27 @@ def __read_config__(self, path):
7280

7381
raise TypeError('Error while reading configuration file, '
7482
'make sure models and allowed_classes have the same size')
83+
84+
def __connect_models__(self):
85+
"""
86+
See ibase_project.py
87+
"""
88+
_cur_dir = os.getcwd()
89+
# initialise the yolo model, additionally use the device parameter to specify the device to run the model on.
90+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
91+
_cur_dir = pdirname(pabspath(__file__))
92+
model_dir = pjoin(_cur_dir, f'../models')
93+
model_dir = pabspath(model_dir) # normalise the link
94+
95+
models = []
96+
for model_name in self._config.get('models'):
97+
model = YOLO(pjoin(model_dir, model_name)).to(self.device)
98+
models.append(model)
99+
100+
return models
101+
102+
def reset_models(self):
103+
"""
104+
See ibase_project.py
105+
"""
106+
self.models = self.__connect_models__()

projects/helmet/helmet_project.py

+3-24
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,6 @@
1-
from os.path import (
2-
join as pjoin,
3-
dirname as pdirname,
4-
abspath as pabspath
5-
)
6-
7-
from ultralytics import YOLO
8-
91
from projects.base_project import BaseProject
102
from projects.helmet.ihelmet_project import IHelmetProject
113

12-
import os
13-
import torch
14-
154
config_path = './projects/helmet/helmet_config.yaml'
165

176

@@ -117,24 +106,14 @@ def connect_models(self):
117106
Initializes the YOLO models and connects them to the appropriate device (CPU or GPU).
118107
119108
Returns:
120-
tuple: A tuple containing two YOLO models.
109+
models: A tuple containing two YOLO models.
110+
models_allowed_classes: List of corresponding allowed classes for each model.
121111
122112
Raises:
123113
ModuleNotFoundError: If the models cannot be loaded.
124114
"""
125115

126-
_cur_dir = os.getcwd()
127-
# initialise the yolo model, additionally use the device parameter to specify the device to run the model on.
128-
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
129-
_cur_dir = pdirname(pabspath(__file__))
130-
model_dir = pjoin(_cur_dir, f'../../models')
131-
model_dir = pabspath(model_dir) # normalise the link
132-
133-
models = []
134-
for model_name in self._config.get('models'):
135-
model = YOLO(pjoin(model_dir, model_name)).to(self.device)
136-
models.append(model)
137-
116+
models = self.__connect_models__()
138117
models_allowed_classes = self._config.get('allowed_classes')
139118

140119
if not models:

projects/ibase_project.py

+34
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,37 @@ def create_proj_save_dir(self):
4040
Create project save directory after initializing the project.
4141
"""
4242
pass
43+
44+
@abstractmethod
45+
def __read_config__(self, path):
46+
"""
47+
Read project's configuration file.
48+
49+
Returns:
50+
tuple: Configuration file in dictionary format.
51+
52+
Raises:
53+
TypeError: If the models cannot be loaded.
54+
"""
55+
pass
56+
57+
@abstractmethod
58+
def __connect_models__(self):
59+
"""
60+
Initializes the YOLO models and connects them to the appropriate device (CPU or GPU).
61+
62+
Returns:
63+
tuple: A tuple containing two YOLO models.
64+
65+
Raises:
66+
ModuleNotFoundError: If the models cannot be loaded.
67+
"""
68+
pass
69+
70+
@abstractmethod
71+
def reset_models(self):
72+
"""
73+
Reset model after processing video to avoid memory allocation error when the upcoming video comes in with
74+
different resolution.
75+
"""
76+
pass

projects/person/person_project.py

+4-26
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,6 @@
1-
from os.path import (
2-
join as pjoin,
3-
dirname as pdirname,
4-
abspath as pabspath
5-
)
6-
7-
from ultralytics import YOLO
8-
91
from projects.base_project import BaseProject
102
from projects.person.iperson_project import IPersonProject
113

12-
import os
13-
import torch
14-
154
config_path = './projects/person/person_config.yaml'
165

176

@@ -111,30 +100,19 @@ def connect_models(self):
111100
Initializes the YOLO models and connects them to the appropriate device (CPU or GPU).
112101
113102
Returns:
114-
tuple: A tuple containing two YOLO models.
103+
models: A tuple containing two YOLO models.
104+
models_allowed_classes: List of corresponding allowed classes for each model.
115105
116106
Raises:
117107
ModuleNotFoundError: If the models cannot be loaded.
118108
"""
119109

120-
_cur_dir = os.getcwd()
121-
# initialise the yolo model, additionally use the device parameter to specify the device to run the model on.
122-
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
123-
_cur_dir = pdirname(pabspath(__file__))
124-
model_dir = pjoin(_cur_dir, f'../../models')
125-
model_dir = pabspath(model_dir) # normalise the link
126-
127-
models = []
128-
for model_name in self._config.get('models'):
129-
model = YOLO(pjoin(model_dir, model_name)).to(self.device)
130-
models.append(model)
131-
110+
models = self.__connect_models__()
132111
models_allowed_classes = self._config.get('allowed_classes')
133112

134113
if not models:
135114
raise ModuleNotFoundError('Model not found!')
136115

137116
print(f'1. Using device: {self.device}')
138-
print(
139-
f"2. Using {len(models)} models: {[model_name for model_name in self._config.get('models')]}")
117+
print(f"2. Using {len(models)} models: {[model_name for model_name in self._config.get('models')]}")
140118
return models, models_allowed_classes

services/harvest_service.py

+1
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def evaluate(self, video):
190190
frame,
191191
skip_frames_counter)
192192
# Free all resources
193+
self.project.reset_models()
193194
cv2.destroyAllWindows()
194195

195196
return self.export.result_dir_path

0 commit comments

Comments
 (0)