Skip to content

Commit eb8b4fb

Browse files
committed
feat: add silero vad model
1 parent 9d08988 commit eb8b4fb

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ repos:
4444
rev: 7.1.0
4545
hooks:
4646
- id: flake8
47-
args: ["--ignore=E501,E731,W503,W504"]
47+
args: ["--ignore=E501,E731,W503,W504,E203"]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Tuple
16+
17+
import numpy as np
18+
import torch
19+
from numpy.typing import NDArray
20+
21+
from rai_asr.models import BaseVoiceDetectionModel
22+
23+
24+
class SileroVAD(BaseVoiceDetectionModel):
25+
def __init__(self, sampling_rate=16000, threshold=0.5):
26+
super(SileroVAD, self).__init__()
27+
self.model_name = "silero_vad"
28+
self.model, _ = torch.hub.load(
29+
repo_or_dir="snakers4/silero-vad",
30+
model=self.model_name,
31+
) # type: ignore
32+
# NOTE: See silero vad implementation: https://github.com/snakers4/silero-vad/blob/9060f664f20eabb66328e4002a41479ff288f14c/src/silero_vad/utils_vad.py#L61
33+
if sampling_rate == 16000:
34+
self.sampling_rate = 16000
35+
self.window_size = 512
36+
elif sampling_rate == 8000:
37+
self.sampling_rate = 8000
38+
self.window_size = 256
39+
else:
40+
raise ValueError(
41+
"Only 8000 and 16000 sampling rates are supported"
42+
) # TODO: consider if this should be a ValueError or something else
43+
self.threshold = threshold
44+
45+
def int2float(self, sound: NDArray[np.int16]):
46+
converted_sound = sound.astype("float32")
47+
converted_sound *= 1 / 32768
48+
converted_sound = converted_sound.squeeze()
49+
return converted_sound
50+
51+
def detect(
52+
self, audio_data: NDArray, input_parameters: dict[str, Any]
53+
) -> Tuple[bool, dict[str, Any]]:
54+
vad_confidence = self.model(
55+
torch.tensor(self.int2float(audio_data[-self.window_size :])),
56+
self.sampling_rate,
57+
).item()
58+
ret = input_parameters.copy()
59+
ret.update({self.model_name: {"vad_confidence": vad_confidence}})
60+
61+
return vad_confidence > self.threshold, ret

0 commit comments

Comments
 (0)