Skip to content

Commit d7cceea

Browse files
committed
refactor: move transcription models for consistency
1 parent 2ee0ec4 commit d7cceea

File tree

5 files changed

+107
-2
lines changed

5 files changed

+107
-2
lines changed

src/rai_asr/rai_asr/asr_clients.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from scipy.io import wavfile
2525
from whisper.transcribe import transcribe
2626

27+
# WARN: This file is going to be removed in favour of rai_asr.models
28+
2729

2830
class ASRModel:
2931
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):

src/rai_asr/rai_asr/models/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from rai_asr.models.base import BaseVoiceDetectionModel
15+
from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel
16+
from rai_asr.models.local_whisper import LocalWhisper
17+
from rai_asr.models.open_ai_whisper import OpenAIWhisper
1618
from rai_asr.models.open_wake_word import OpenWakeWord
1719
from rai_asr.models.silero_vad import SileroVAD
1820

19-
__all__ = ["BaseVoiceDetectionModel", "SileroVAD", "OpenWakeWord"]
21+
__all__ = [
22+
"BaseVoiceDetectionModel",
23+
"SileroVAD",
24+
"OpenWakeWord",
25+
"BaseTranscriptionModel",
26+
"LocalWhisper",
27+
"OpenAIWhisper",
28+
]

src/rai_asr/rai_asr/models/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from abc import ABC, abstractmethod
1717
from typing import Any, Tuple
1818

19+
import numpy as np
1920
from numpy._typing import NDArray
2021

2122

@@ -26,3 +27,17 @@ def detected(
2627
self, audio_data: NDArray, input_parameters: dict[str, Any]
2728
) -> Tuple[bool, dict[str, Any]]:
2829
pass
30+
31+
32+
class BaseTranscriptionModel(ABC):
33+
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
34+
self.model_name = model_name
35+
self.sample_rate = sample_rate
36+
self.language = language
37+
38+
@abstractmethod
39+
def transcribe(self, data: NDArray[np.int16]) -> str:
40+
pass
41+
42+
def __call__(self, data: NDArray[np.int16]) -> str:
43+
return self.transcribe(data)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import whisper
17+
from numpy._typing import NDArray
18+
19+
from rai_asr.models.base import BaseTranscriptionModel
20+
21+
22+
class LocalWhisper(BaseTranscriptionModel):
23+
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
24+
super().__init__(model_name, sample_rate, language)
25+
self.whisper = whisper.load_model(self.model_name)
26+
27+
def transcribe(self, data: NDArray[np.int16]) -> str:
28+
result = whisper.transcribe(self.whisper, data.astype(np.float32) / 32768.0)
29+
transcription = result["text"]
30+
# NOTE: this is only for type enforcement, doesn't need to work on runtime
31+
assert isinstance(transcription, str)
32+
return transcription
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import io
16+
import os
17+
from functools import partial
18+
19+
import numpy as np
20+
from numpy.typing import NDArray
21+
from openai import OpenAI
22+
from scipy.io import wavfile
23+
24+
from rai_asr.models.base import BaseTranscriptionModel
25+
26+
27+
class OpenAIWhisper(BaseTranscriptionModel):
28+
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
29+
super().__init__(model_name, sample_rate, language)
30+
api_key = os.getenv("OPENAI_API_KEY")
31+
if api_key is None:
32+
raise ValueError("OPENAI_API_KEY environment variable is not set.")
33+
self.api_key = api_key
34+
self.openai_client = OpenAI()
35+
self.model = partial(
36+
self.openai_client.audio.transcriptions.create,
37+
model=self.model_name,
38+
)
39+
40+
def transcribe(self, data: NDArray[np.int16]) -> str:
41+
with io.BytesIO() as temp_wav_buffer:
42+
wavfile.write(temp_wav_buffer, self.sample_rate, data)
43+
temp_wav_buffer.seek(0)
44+
temp_wav_buffer.name = "temp.wav"
45+
response = self.model(file=temp_wav_buffer, language=self.language)
46+
transcription = response.text
47+
return transcription

0 commit comments

Comments
 (0)