13
13
# limitations under the License.
14
14
15
15
16
+ import logging
16
17
import time
17
18
from threading import Event , Lock , Thread
18
- from typing import Any , List , TypedDict
19
+ from typing import Any , List , Optional , TypedDict
19
20
from uuid import uuid4
20
21
21
22
import numpy as np
@@ -40,7 +41,12 @@ def __init__(
40
41
transcription_model : BaseTranscriptionModel ,
41
42
vad : BaseVoiceDetectionModel ,
42
43
grace_period : float = 1.0 ,
44
+ logger : Optional [logging .Logger ] = None ,
43
45
):
46
+ if logger is None :
47
+ self .logger = logging .getLogger (__name__ )
48
+ else :
49
+ self .logger = logger
44
50
microphone = StreamingAudioInputDevice ()
45
51
microphone .configure_device (
46
52
target = str (microphone_device_id ), config = microphone_config
@@ -87,16 +93,20 @@ def run(self):
87
93
)
88
94
89
95
def stop (self ):
96
+ self .logger .info ("Stopping voice agent" )
90
97
self .running = False
91
98
self .connectors ["microphone" ].terminate_action (self .listener_handle )
92
- to_finish = list (self .transcription_threads .keys ())
93
- while len ( to_finish ) > 0 :
99
+ to_finish = len ( list (self .transcription_threads .keys () ))
100
+ while to_finish > 0 :
94
101
for thread_id in self .transcription_threads :
95
102
if self .transcription_threads [thread_id ]["event" ].is_set ():
96
103
self .transcription_threads [thread_id ]["thread" ].join ()
97
- to_finish . remove ( thread_id )
104
+ to_finish -= 1
98
105
else :
99
- print (f"Waiting for transcription of { thread_id } to finish" )
106
+ self .logger .info (
107
+ f"Waiting for transcription of { thread_id } to finish..."
108
+ )
109
+ self .logger .info ("Voice agent stopped" )
100
110
101
111
def on_new_sample (self , indata : np .ndarray , status_flags : dict [str , Any ]):
102
112
sample_time = time .time ()
@@ -112,7 +122,7 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
112
122
should_record = self .should_record (indata , output_parameters )
113
123
114
124
if should_record :
115
- print ( "Start recording" )
125
+ self . logger . info ( "starting recording... " )
116
126
self .recording_started = True
117
127
thread_id = str (uuid4 ())[0 :8 ]
118
128
transcription_thread = Thread (
@@ -129,13 +139,14 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
129
139
}
130
140
131
141
if voice_detected :
142
+ self .logger .debug ("Voice detected... resetting grace period" )
132
143
self .grace_period_start = sample_time
133
144
134
145
if (
135
146
self .recording_started
136
147
and sample_time - self .grace_period_start > self .grace_period
137
148
):
138
- print ( "Stop recording" )
149
+ self . logger . info ( "Grace period ended... stopping recording" )
139
150
self .recording_started = False
140
151
self .grace_period_start = 0
141
152
with self .sample_buffer_lock :
@@ -148,12 +159,12 @@ def should_record(
148
159
) -> bool :
149
160
for model in self .should_record_pipeline :
150
161
detected , output = model .detected (audio_data , input_parameters )
151
- print (f"Detected: { detected } : { output } " )
152
162
if detected :
153
163
return True
154
164
return False
155
165
156
166
def transcription_thread (self , identifier : str ):
167
+ self .logger .info (f"transcription thread { identifier } started" )
157
168
with self .transcription_lock :
158
169
while self .active_thread == identifier :
159
170
with self .sample_buffer_lock :
@@ -171,7 +182,10 @@ def transcription_thread(self, identifier: str):
171
182
audio_data = np .concatenate (audio_data )
172
183
self .transcription_model .transcribe (audio_data )
173
184
del self .buffer_reminders [identifier ]
185
+ # self.transcription_model.save_wav(f"{identifier}.wav")
174
186
transcription = self .transcription_model .consume_transcription ()
175
187
self .transcription_threads [identifier ]["transcription" ] = transcription
176
188
self .transcription_threads [identifier ]["event" ].set ()
177
- # TODO: sending the transcription
189
+ # TODO: sending the transcription once https://github.com/RobotecAI/rai/pull/360 is merged
190
+ self .logger .info (f"transcription thread { identifier } finished" )
191
+ print (transcription )
0 commit comments