Skip to content

Commit 47fb148

Browse files
committed
feat: add open wake word model
1 parent eb8b4fb commit 47fb148

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

src/rai_asr/rai_asr/models/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .base import BaseVoiceDetectionModel
15+
from rai_asr.models.base import BaseVoiceDetectionModel
16+
from rai_asr.models.open_wake_word import OpenWakeWord
17+
from rai_asr.models.silero_vad import SileroVAD
1618

17-
__all__ = ["BaseVoiceDetectionModel"]
19+
__all__ = ["BaseVoiceDetectionModel", "SileroVAD", "OpenWakeWord"]
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Tuple
16+
17+
from numpy.typing import NDArray
18+
from openwakeword.model import Model as OWWModel
19+
from openwakeword.utils import download_models
20+
21+
from rai_asr.models import BaseVoiceDetectionModel
22+
23+
24+
class OpenWakeWord(BaseVoiceDetectionModel):
    """Wake-word detector backed by an openWakeWord model run through ONNX.

    A chunk of audio counts as a detection when at least one score returned
    by the underlying model is strictly greater than ``threshold``.
    """

    def __init__(self, wake_word_model_path: str, threshold: float = 0.5):
        """Create the underlying openWakeWord model.

        Args:
            wake_word_model_path: Passed to openWakeWord as the only entry of
                ``wakeword_models``.
            threshold: Score a prediction must strictly exceed to be treated
                as a detection.
        """
        super().__init__()
        self.model_name = "open_wake_word"
        # NOTE(review): this runs on every construction; presumably it fetches
        # the shared openWakeWord resources and is cheap once they are cached
        # locally -- confirm against the openWakeWord documentation.
        download_models()
        self.model = OWWModel(
            wakeword_models=[wake_word_model_path],
            inference_framework="onnx",
        )
        self.threshold = threshold

    def detected(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        """Run the wake-word model on one chunk of audio.

        Args:
            audio_data: Audio samples handed unchanged to the model's
                ``predict`` method.
            input_parameters: Parameters accumulated so far; not modified.

        Returns:
            A ``(detected, parameters)`` tuple. ``parameters`` is a shallow
            copy of ``input_parameters`` with the raw predictions stored under
            ``self.model_name`` as ``{"predictions": ...}``.
        """
        predictions = self.model.predict(audio_data)
        ret = input_parameters.copy()
        ret.update({self.model_name: {"predictions": predictions}})
        if any(score > self.threshold for score in predictions.values()):
            # Reset the model after a hit (presumably so already-buffered
            # audio cannot trigger a second detection -- TODO confirm).
            self.model.reset()
            return True, ret
        return False, ret

src/rai_asr/rai_asr/models/silero_vad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def int2float(self, sound: NDArray[np.int16]):
4848
converted_sound = converted_sound.squeeze()
4949
return converted_sound
5050

51-
def detect(
51+
def detected(
5252
self, audio_data: NDArray, input_parameters: dict[str, Any]
5353
) -> Tuple[bool, dict[str, Any]]:
5454
vad_confidence = self.model(
@@ -57,5 +57,6 @@ def detect(
5757
).item()
5858
ret = input_parameters.copy()
5959
ret.update({self.model_name: {"vad_confidence": vad_confidence}})
60+
self.model.reset_states() # NOTE: see streaming example at the bottom https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies
6061

6162
return vad_confidence > self.threshold, ret

0 commit comments

Comments
 (0)