Skip to content

Commit 928a9ff

Browse files
committed
feat: working streaming ASR
1 parent 7dc4b43 commit 928a9ff

File tree

4 files changed

+73
-60
lines changed

4 files changed

+73
-60
lines changed

src/rai/rai/agents/voice_agent.py

+61-35
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class ThreadData(TypedDict):
3636
thread: Thread
3737
event: Event
3838
transcription: str
39+
joined: bool
3940

4041

4142
class VoiceRecognitionAgent(BaseAgent):
@@ -78,7 +79,7 @@ def __init__(
7879
self.sample_buffer_lock = Lock()
7980
self.active_thread = ""
8081
self.transcription_threads: dict[str, ThreadData] = {}
81-
self.buffer_reminders: dict[str, list[NDArray]] = {}
82+
self.transcription_buffers: dict[str, list[NDArray]] = {}
8283

8384
def __call__(self):
8485
self.run()
@@ -106,12 +107,13 @@ def stop(self):
106107
self.logger.info("Stopping voice agent")
107108
self.running = False
108109
self.connectors["microphone"].terminate_action(self.listener_handle)
109-
to_finish = len(list(self.transcription_threads.keys()))
110-
while to_finish > 0:
110+
while not all(
111+
[thread["joined"] for thread in self.transcription_threads.values()]
112+
):
111113
for thread_id in self.transcription_threads:
112114
if self.transcription_threads[thread_id]["event"].is_set():
113115
self.transcription_threads[thread_id]["thread"].join()
114-
to_finish -= 1
116+
self.transcription_threads[thread_id]["joined"] = True
115117
else:
116118
self.logger.info(
117119
f"Waiting for transcription of {thread_id} to finish..."
@@ -125,6 +127,12 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
125127
if not self.recording_started and len(self.sample_buffer) > 5:
126128
self.sample_buffer = self.sample_buffer[-5:]
127129

130+
# attempt to join finished threads:
131+
for thread_id in self.transcription_threads:
132+
if self.transcription_threads[thread_id]["event"].is_set():
133+
self.transcription_threads[thread_id]["thread"].join()
134+
self.transcription_threads[thread_id]["joined"] = True
135+
128136
voice_detected, output_parameters = self.vad.detected(indata, {})
129137
should_record = False
130138
# TODO: second condition is temporary
@@ -141,11 +149,11 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
141149
)
142150
transcription_finished = Event()
143151
self.active_thread = thread_id
144-
transcription_thread.start()
145152
self.transcription_threads[thread_id] = {
146153
"thread": transcription_thread,
147154
"event": transcription_finished,
148155
"transcription": "",
156+
"joined": False,
149157
}
150158

151159
if voice_detected:
@@ -156,12 +164,15 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
156164
self.recording_started
157165
and sample_time - self.grace_period_start > self.grace_period
158166
):
159-
self.logger.info("Grace period ended... stopping recording")
167+
self.logger.info(
168+
"Grace period ended... stopping recording, starting transcription"
169+
)
160170
self.recording_started = False
161171
self.grace_period_start = 0
162172
with self.sample_buffer_lock:
163-
self.buffer_reminders[self.active_thread] = self.sample_buffer
173+
self.transcription_buffers[self.active_thread] = self.sample_buffer
164174
self.sample_buffer = []
175+
self.transcription_threads[self.active_thread]["thread"].start()
165176
self.active_thread = ""
166177

167178
def should_record(
@@ -175,31 +186,46 @@ def should_record(
175186

176187
def transcription_thread(self, identifier: str):
177188
self.logger.info(f"transcription thread {identifier} started")
178-
with self.transcription_lock:
179-
while self.active_thread == identifier:
180-
with self.sample_buffer_lock:
181-
if len(self.sample_buffer) == 0:
182-
continue
183-
audio_data = self.sample_buffer.copy()
184-
self.sample_buffer = []
185-
audio_data = np.concatenate(audio_data)
186-
self.transcription_model.transcribe(audio_data)
187-
188-
# transciption of the reminder of the buffer
189-
with self.sample_buffer_lock:
190-
if identifier in self.buffer_reminders:
191-
audio_data = self.buffer_reminders[identifier]
192-
audio_data = np.concatenate(audio_data)
193-
self.transcription_model.transcribe(audio_data)
194-
del self.buffer_reminders[identifier]
195-
# self.transcription_model.save_wav(f"{identifier}.wav")
196-
transcription = self.transcription_model.consume_transcription()
197-
print("Transcription: ", transcription)
198-
self.connectors["ros2"].send_message(
199-
ROS2ARIMessage(
200-
{"data": transcription}, {"msg_type": "std_msgs/msg/String"}
201-
),
202-
"/from_human",
203-
)
204-
self.transcription_threads[identifier]["transcription"] = transcription
205-
self.transcription_threads[identifier]["event"].set()
189+
audio_data = np.concatenate(self.transcription_buffers[identifier])
190+
with self.transcription_lock: # this is only necessary for the local model... TODO: fix this somehow
191+
transcription = self.transcription_model.transcribe(audio_data)
192+
self.connectors["ros2"].send_message(
193+
ROS2ARIMessage(
194+
{"data": transcription}, {"msg_type": "std_msgs/msg/String"}
195+
),
196+
"/from_human",
197+
)
198+
self.transcription_threads[identifier]["transcription"] = transcription
199+
self.transcription_threads[identifier]["event"].set()
200+
201+
# with self.transcription_lock:
202+
# while self.active_thread == identifier:
203+
# with self.sample_buffer_lock:
204+
# if len(self.sample_buffer) == 0:
205+
# continue
206+
# audio_data = self.sample_buffer.copy()
207+
# self.sample_buffer = []
208+
# audio_data = np.concatenate(audio_data)
209+
# with self.transcription_lock:
210+
# self.transcription_model.transcribe(audio_data)
211+
212+
# # transciption of the reminder of the buffer
213+
# with self.sample_buffer_lock:
214+
# if identifier in self.transcription_buffers:
215+
# audio_data = self.transcription_buffers[identifier]
216+
# audio_data = np.concatenate(audio_data)
217+
# with self.transcription_lock:
218+
# self.transcription_model.transcribe(audio_data)
219+
# del self.transcription_buffers[identifier]
220+
# # self.transcription_model.save_wav(f"{identifier}.wav")
221+
# with self.transcription_lock:
222+
# transcription = self.transcription_model.consume_transcription()
223+
# self.logger.info(f"Transcription: {transcription}")
224+
# self.connectors["ros2"].send_message(
225+
# ROS2ARIMessage(
226+
# {"data": transcription}, {"msg_type": "std_msgs/msg/String"}
227+
# ),
228+
# "/from_human",
229+
# )
230+
# self.transcription_threads[identifier]["transcription"] = transcription
231+
# self.transcription_threads[identifier]["event"].set()

src/rai_asr/rai_asr/models/base.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
3737

3838
self.latest_transcription = ""
3939

40-
def consume_transcription(self) -> str:
41-
ret = self.latest_transcription
42-
self.latest_transcription = ""
43-
return ret
44-
4540
@abstractmethod
46-
def transcribe(self, data: NDArray[np.int16]):
41+
def transcribe(self, data: NDArray[np.int16]) -> str:
4742
pass

src/rai_asr/rai_asr/models/local_whisper.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
1516
from typing import cast
1617

1718
import numpy as np
@@ -30,14 +31,10 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
3031
else:
3132
self.whisper = whisper.load_model(self.model_name)
3233

34+
self.logger = logging.getLogger(__name__)
3335
# TODO: remove sample storage before PR is merged, this is just to enable saving wav files for debugging
3436
# self.samples = None
3537

36-
def consume_transcription(self) -> str:
37-
ret = super().consume_transcription()
38-
# self.samples = None
39-
return ret
40-
4138
# def save_wav(self, output_filename: str):
4239
# assert self.samples is not None, "No samples to save"
4340
# combined_samples = self.samples
@@ -55,14 +52,13 @@ def consume_transcription(self) -> str:
5552
# wav_file.setframerate(self.sample_rate)
5653
# wav_file.writeframes(combined_samples.tobytes())
5754

58-
def transcribe(self, data: NDArray[np.int16]):
59-
# self.samples = (
60-
# np.concatenate((self.samples, data)) if self.samples is not None else data
61-
# )
55+
def transcribe(self, data: NDArray[np.int16]) -> str:
6256
normalized_data = data.astype(np.float32) / 32768.0
6357
result = whisper.transcribe(
6458
self.whisper, normalized_data
6559
) # TODO: handling of additional transcribe arguments (perhaps in model init)
6660
transcription = result["text"]
61+
self.logger.info("transcription: %s", transcription)
6762
transcription = cast(str, transcription)
68-
self.latest_transcription += transcription
63+
self.latest_transcription = transcription
64+
return transcription

src/rai_asr/rai_asr/models/open_ai_whisper.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import io
16+
import logging
1617
import os
1718
from functools import partial
1819

@@ -36,21 +37,16 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
3637
self.openai_client.audio.transcriptions.create,
3738
model=self.model_name,
3839
)
40+
self.logger = logging.getLogger(__name__)
3941
self.samples = []
4042

41-
def add_samples(self, data: NDArray[np.int16]):
43+
def transcribe(self, data: NDArray[np.int16]) -> str:
4244
normalized_data = data.astype(np.float32) / 32768.0
43-
self.samples = (
44-
np.concatenate([self.samples, normalized_data])
45-
if self.samples is not None
46-
else data
47-
)
48-
49-
def transcribe(self) -> str:
5045
with io.BytesIO() as temp_wav_buffer:
51-
wavfile.write(temp_wav_buffer, self.sample_rate, self.samples)
46+
wavfile.write(temp_wav_buffer, self.sample_rate, normalized_data)
5247
temp_wav_buffer.seek(0)
5348
temp_wav_buffer.name = "temp.wav"
5449
response = self.model(file=temp_wav_buffer, language=self.language)
5550
transcription = response.text
51+
self.logger.info("transcription: %s", transcription)
5652
return transcription

0 commit comments

Comments
 (0)