@@ -36,6 +36,7 @@ class ThreadData(TypedDict):
     thread: Thread
     event: Event
     transcription: str
+    joined: bool
 
 
 class VoiceRecognitionAgent(BaseAgent):
@@ -78,7 +79,7 @@ def __init__(
         self.sample_buffer_lock = Lock()
         self.active_thread = ""
         self.transcription_threads: dict[str, ThreadData] = {}
-        self.buffer_reminders: dict[str, list[NDArray]] = {}
+        self.transcription_buffers: dict[str, list[NDArray]] = {}
 
     def __call__(self):
         self.run()
@@ -106,12 +107,13 @@ def stop(self):
         self.logger.info("Stopping voice agent")
         self.running = False
         self.connectors["microphone"].terminate_action(self.listener_handle)
-        to_finish = len(list(self.transcription_threads.keys()))
-        while to_finish > 0:
+        while not all(
+            [thread["joined"] for thread in self.transcription_threads.values()]
+        ):
             for thread_id in self.transcription_threads:
                 if self.transcription_threads[thread_id]["event"].is_set():
                     self.transcription_threads[thread_id]["thread"].join()
-                    to_finish -= 1
+                    self.transcription_threads[thread_id]["joined"] = True
                 else:
                     self.logger.info(
                         f"Waiting for transcription of {thread_id} to finish..."
@@ -125,6 +127,12 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
         if not self.recording_started and len(self.sample_buffer) > 5:
             self.sample_buffer = self.sample_buffer[-5:]
 
+        # attempt to join finished threads:
+        for thread_id in self.transcription_threads:
+            if self.transcription_threads[thread_id]["event"].is_set():
+                self.transcription_threads[thread_id]["thread"].join()
+                self.transcription_threads[thread_id]["joined"] = True
+
         voice_detected, output_parameters = self.vad.detected(indata, {})
         should_record = False
         # TODO: second condition is temporary
@@ -141,11 +149,11 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
             )
             transcription_finished = Event()
             self.active_thread = thread_id
-            transcription_thread.start()
             self.transcription_threads[thread_id] = {
                 "thread": transcription_thread,
                 "event": transcription_finished,
                 "transcription": "",
+                "joined": False,
             }
 
         if voice_detected:
@@ -156,12 +164,15 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
             self.recording_started
             and sample_time - self.grace_period_start > self.grace_period
         ):
-            self.logger.info("Grace period ended... stopping recording")
+            self.logger.info(
+                "Grace period ended... stopping recording, starting transcription"
+            )
             self.recording_started = False
             self.grace_period_start = 0
             with self.sample_buffer_lock:
-                self.buffer_reminders[self.active_thread] = self.sample_buffer
+                self.transcription_buffers[self.active_thread] = self.sample_buffer
                 self.sample_buffer = []
+            self.transcription_threads[self.active_thread]["thread"].start()
             self.active_thread = ""
 
     def should_record(
@@ -175,31 +186,46 @@ def should_record(
 
     def transcription_thread(self, identifier: str):
         self.logger.info(f"transcription thread {identifier} started")
-        with self.transcription_lock:
-            while self.active_thread == identifier:
-                with self.sample_buffer_lock:
-                    if len(self.sample_buffer) == 0:
-                        continue
-                    audio_data = self.sample_buffer.copy()
-                    self.sample_buffer = []
-                audio_data = np.concatenate(audio_data)
-                self.transcription_model.transcribe(audio_data)
-
-            # transciption of the reminder of the buffer
-            with self.sample_buffer_lock:
-                if identifier in self.buffer_reminders:
-                    audio_data = self.buffer_reminders[identifier]
-                    audio_data = np.concatenate(audio_data)
-                    self.transcription_model.transcribe(audio_data)
-                    del self.buffer_reminders[identifier]
-            # self.transcription_model.save_wav(f"{identifier}.wav")
-            transcription = self.transcription_model.consume_transcription()
-            print("Transcription: ", transcription)
-            self.connectors["ros2"].send_message(
-                ROS2ARIMessage(
-                    {"data": transcription}, {"msg_type": "std_msgs/msg/String"}
-                ),
-                "/from_human",
-            )
-            self.transcription_threads[identifier]["transcription"] = transcription
-            self.transcription_threads[identifier]["event"].set()
+        audio_data = np.concatenate(self.transcription_buffers[identifier])
+        with self.transcription_lock:  # this is only necessary for the local model... TODO: fix this somehow
+            transcription = self.transcription_model.transcribe(audio_data)
+        self.connectors["ros2"].send_message(
+            ROS2ARIMessage(
+                {"data": transcription}, {"msg_type": "std_msgs/msg/String"}
+            ),
+            "/from_human",
+        )
+        self.transcription_threads[identifier]["transcription"] = transcription
+        self.transcription_threads[identifier]["event"].set()
+
+        # with self.transcription_lock:
+        #     while self.active_thread == identifier:
+        #         with self.sample_buffer_lock:
+        #             if len(self.sample_buffer) == 0:
+        #                 continue
+        #             audio_data = self.sample_buffer.copy()
+        #             self.sample_buffer = []
+        #         audio_data = np.concatenate(audio_data)
+        #         with self.transcription_lock:
+        #             self.transcription_model.transcribe(audio_data)
+
+        #     # transciption of the reminder of the buffer
+        #     with self.sample_buffer_lock:
+        #         if identifier in self.transcription_buffers:
+        #             audio_data = self.transcription_buffers[identifier]
+        #             audio_data = np.concatenate(audio_data)
+        #             with self.transcription_lock:
+        #                 self.transcription_model.transcribe(audio_data)
+        #             del self.transcription_buffers[identifier]
+        #     # self.transcription_model.save_wav(f"{identifier}.wav")
+        #     with self.transcription_lock:
+        #         transcription = self.transcription_model.consume_transcription()
+        #     self.logger.info(f"Transcription: {transcription}")
+        #     self.connectors["ros2"].send_message(
+        #         ROS2ARIMessage(
+        #             {"data": transcription}, {"msg_type": "std_msgs/msg/String"}
+        #         ),
+        #         "/from_human",
+        #     )
+        #     self.transcription_threads[identifier]["transcription"] = transcription
+        #     self.transcription_threads[identifier]["event"].set()
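As a rough illustration of the new one-shot flow in transcription_thread (audio handed off per utterance, transcribed once under a shared lock, result published, completion signalled), here is a small sketch; the transcribe stub and the publish callback are placeholders, not the real transcription model or the ROS 2 connector:

    from threading import Event, Lock, Thread

    import numpy as np
    from numpy.typing import NDArray

    model_lock = Lock()  # serialises access to a non-thread-safe local model


    def transcribe(audio: NDArray) -> str:  # placeholder for the real model
        return f"<{audio.shape[0]} samples>"


    def publish(text: str) -> None:  # placeholder for connectors["ros2"].send_message
        print("to /from_human:", text)


    def transcription_thread(chunks: list[NDArray], done: Event) -> None:
        audio = np.concatenate(chunks)  # the utterance was buffered as chunks
        with model_lock:                # only needed for a local model
            text = transcribe(audio)
        publish(text)
        done.set()                      # lets the owner join this thread


    chunks = [np.zeros(1600, dtype=np.int16) for _ in range(3)]
    done = Event()
    t = Thread(target=transcription_thread, args=(chunks, done))
    t.start()
    t.join()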