14
14
15
15
16
16
import time
17
- from threading import Lock , Thread
18
- from typing import Any , List , cast
17
+ from threading import Event , Lock , Thread
18
+ from typing import Any , List , TypedDict
19
+ from uuid import uuid4
19
20
20
21
import numpy as np
21
22
from numpy .typing import NDArray
22
23
23
24
from rai .agents .base import BaseAgent
24
25
from rai .communication import AudioInputDeviceConfig , StreamingAudioInputDevice
25
- from rai_asr .models . base import BaseTranscriptionModel , BaseVoiceDetectionModel
26
+ from rai_asr .models import BaseTranscriptionModel , BaseVoiceDetectionModel
26
27
27
28
28
- class VoiceRecognitionAgent (BaseAgent ):
29
- def __init__ (self ):
30
- super ().__init__ (connectors = {"microphone" : StreamingAudioInputDevice ()})
31
- self .should_record_pipeline : List [BaseVoiceDetectionModel ] = []
32
- self .should_stop_pipeline : List [BaseVoiceDetectionModel ] = []
33
- self .transcription_lock = Lock ()
34
- self .shared_samples = []
35
- self .recording_started = False
36
- self .ran_setup = False
29
class ThreadData(TypedDict):
    """Bookkeeping record for one transcription worker thread."""

    # The worker thread itself.
    thread: Thread
    # Set by the worker once its transcription result has been stored.
    event: Event
    # Transcribed text; empty until the worker finishes.
    transcription: str
37
33
38
- def __call__ (self ):
39
- self .run ()
40
34
41
- def setup (
35
+ class VoiceRecognitionAgent (BaseAgent ):
36
+ def __init__ (
42
37
self ,
43
38
microphone_device_id : int , # TODO: Change to name based instead of id based identification
44
39
microphone_config : AudioInputDeviceConfig ,
45
40
transcription_model : BaseTranscriptionModel ,
41
+ vad : BaseVoiceDetectionModel ,
42
+ grace_period : float = 1.0 ,
46
43
):
47
- self .connectors ["microphone" ] = cast (
48
- StreamingAudioInputDevice , self .connectors ["microphone" ]
44
+ microphone = StreamingAudioInputDevice ()
45
+ microphone .configure_device (
46
+ target = str (microphone_device_id ), config = microphone_config
49
47
)
48
+ super ().__init__ (connectors = {"microphone" : microphone })
50
49
self .microphone_device_id = str (microphone_device_id )
51
- self .connectors [ "microphone" ]. configure_device (
52
- target = self .microphone_device_id , config = microphone_config
53
- )
50
+ self .should_record_pipeline : List [ BaseVoiceDetectionModel ] = []
51
+ self .should_stop_pipeline : List [ BaseVoiceDetectionModel ] = []
52
+
54
53
self .transcription_model = transcription_model
55
- self .ran_setup = True
56
- self .running = False
54
+ self .transcription_lock = Lock ()
55
+
56
+ self .vad : BaseVoiceDetectionModel = vad
57
+
58
+ self .grace_period = grace_period
59
+ self .grace_period_start = 0
60
+
61
+ self .recording_started = False
62
+ self .ran_setup = False
63
+
64
+ self .sample_buffer = []
65
+ self .sample_buffer_lock = Lock ()
66
+ self .active_thread = ""
67
+ self .transcription_threads : dict [str , ThreadData ] = {}
68
+ self .buffer_reminders : dict [str , list [NDArray ]] = {}
69
+
70
+ def __call__ (self ):
71
+ self .run ()
57
72
58
73
def add_detection_model (
59
74
self , model : BaseVoiceDetectionModel , pipeline : str = "record"
@@ -70,49 +85,93 @@ def run(self):
70
85
self .listener_handle = self .connectors ["microphone" ].start_action (
71
86
self .microphone_device_id , self .on_new_sample
72
87
)
73
- self .transcription_thread = Thread (target = self ._transcription_function )
74
- self .transcription_thread .start ()
75
88
76
89
def stop (self ):
77
90
self .running = False
78
91
self .connectors ["microphone" ].terminate_action (self .listener_handle )
79
- self .transcription_thread .join ()
92
+ to_finish = list (self .transcription_threads .keys ())
93
+ while len (to_finish ) > 0 :
94
+ for thread_id in self .transcription_threads :
95
+ if self .transcription_threads [thread_id ]["event" ].is_set ():
96
+ self .transcription_threads [thread_id ]["thread" ].join ()
97
+ to_finish .remove (thread_id )
98
+ else :
99
+ print (f"Waiting for transcription of { thread_id } to finish" )
80
100
81
101
def on_new_sample (self , indata : np .ndarray , status_flags : dict [str , Any ]):
82
- should_stop = self .should_stop_recording (indata )
83
- if self .should_start_recording (indata ):
102
+ sample_time = time .time ()
103
+ with self .sample_buffer_lock :
104
+ self .sample_buffer .append (indata )
105
+ if not self .recording_started and len (self .sample_buffer ) > 5 :
106
+ self .sample_buffer = self .sample_buffer [- 5 :]
107
+
108
+ voice_detected , output_parameters = self .vad .detected (indata , {})
109
+ should_record = False
110
+ # TODO: second condition is temporary
111
+ if voice_detected and not self .recording_started :
112
+ should_record = self .should_record (indata , output_parameters )
113
+
114
+ if should_record :
115
+ print ("Start recording" )
84
116
self .recording_started = True
85
- if self .recording_started and not should_stop :
86
- with self .transcription_lock :
87
- self .shared_samples .extend (indata )
117
+ thread_id = str (uuid4 ())[0 :8 ]
118
+ transcription_thread = Thread (
119
+ target = self .transcription_thread ,
120
+ args = [thread_id ],
121
+ )
122
+ transcription_finished = Event ()
123
+ self .active_thread = thread_id
124
+ transcription_thread .start ()
125
+ self .transcription_threads [thread_id ] = {
126
+ "thread" : transcription_thread ,
127
+ "event" : transcription_finished ,
128
+ "transcription" : "" ,
129
+ }
130
+
131
+ if voice_detected :
132
+ self .grace_period_start = sample_time
88
133
89
- def should_start_recording (self , audio_data : NDArray [np .int16 ]) -> bool :
90
- output_parameters = {}
134
+ if (
135
+ self .recording_started
136
+ and sample_time - self .grace_period_start > self .grace_period
137
+ ):
138
+ print ("Stop recording" )
139
+ self .recording_started = False
140
+ self .grace_period_start = 0
141
+ with self .sample_buffer_lock :
142
+ self .buffer_reminders [self .active_thread ] = self .sample_buffer
143
+ self .sample_buffer = []
144
+ self .active_thread = ""
145
+
146
+ def should_record (
147
+ self , audio_data : NDArray , input_parameters : dict [str , Any ]
148
+ ) -> bool :
91
149
for model in self .should_record_pipeline :
92
- should_listen , output_parameters = model .detected (
93
- audio_data , output_parameters
94
- )
95
- print (should_listen , output_parameters )
96
- if not should_listen :
97
- return False
98
- return True
99
-
100
- def should_stop_recording (self , audio_data : NDArray [np .int16 ]) -> bool :
101
- output_parameters = {}
102
- for model in self .should_stop_pipeline :
103
- should_listen , output_parameters = model .detected (
104
- audio_data , output_parameters
105
- )
106
- if should_listen :
150
+ detected , output = model .detected (audio_data , input_parameters )
151
+ print (f"Detected: { detected } : { output } " )
152
+ if detected :
107
153
return True
108
154
return False
109
155
110
- def _transcription_function (self ):
111
- while self .running :
112
- time .sleep (0.1 )
113
- # critical section for samples
114
- with self .transcription_lock :
115
- samples = np .array (self .shared_samples )
116
- self .shared_samples = []
117
- # end critical section for samples
118
- self .transcription_model .add_samples (samples )
156
+ def transcription_thread (self , identifier : str ):
157
+ with self .transcription_lock :
158
+ while self .active_thread == identifier :
159
+ with self .sample_buffer_lock :
160
+ if len (self .sample_buffer ) == 0 :
161
+ continue
162
+ audio_data = self .sample_buffer .copy ()
163
+ self .sample_buffer = []
164
+ audio_data = np .concatenate (audio_data )
165
+ self .transcription_model .transcribe (audio_data )
166
+
167
+ # transciption of the reminder of the buffer
168
+ with self .sample_buffer_lock :
169
+ if identifier in self .buffer_reminders :
170
+ audio_data = self .buffer_reminders [identifier ]
171
+ audio_data = np .concatenate (audio_data )
172
+ self .transcription_model .transcribe (audio_data )
173
+ del self .buffer_reminders [identifier ]
174
+ transcription = self .transcription_model .consume_transcription ()
175
+ self .transcription_threads [identifier ]["transcription" ] = transcription
176
+ self .transcription_threads [identifier ]["event" ].set ()
177
+ # TODO: sending the transcription
0 commit comments