13
13
# limitations under the License.
14
14
15
15
16
+ import time
16
17
from threading import Lock , Thread
17
- from typing import Any , List , Tuple
18
+ from typing import Any , List , cast
18
19
19
20
import numpy as np
20
21
from numpy .typing import NDArray
21
22
22
23
from rai .agents .base import BaseAgent
23
24
from rai .communication import AudioInputDeviceConfig , StreamingAudioInputDevice
24
- from rai_asr .models .base import BaseVoiceDetectionModel
25
+ from rai_asr .models .base import BaseTranscriptionModel , BaseVoiceDetectionModel
25
26
26
27
27
28
class VoiceRecognitionAgent (BaseAgent ):
@@ -38,34 +39,50 @@ def __call__(self):
38
39
self .run ()
39
40
40
41
def setup (
41
- self , microphone_device_id : int , microphone_config : AudioInputDeviceConfig
42
+ self ,
43
+ microphone_device_id : int , # TODO: Change to name based instead of id based identification
44
+ microphone_config : AudioInputDeviceConfig ,
45
+ transcription_model : BaseTranscriptionModel ,
42
46
):
43
- assert isinstance (self .connectors ["microphone" ], StreamingAudioInputDevice )
47
+ self .connectors ["microphone" ] = cast (
48
+ StreamingAudioInputDevice , self .connectors ["microphone" ]
49
+ )
44
50
self .microphone_device_id = str (microphone_device_id )
45
51
self .connectors ["microphone" ].configure_device (
46
52
target = self .microphone_device_id , config = microphone_config
47
53
)
54
+ self .transcription_model = transcription_model
48
55
self .ran_setup = True
56
+ self .running = False
57
+
58
+ def add_detection_model (
59
+ self , model : BaseVoiceDetectionModel , pipeline : str = "record"
60
+ ):
61
+ if pipeline == "record" :
62
+ self .should_record_pipeline .append (model )
63
+ elif pipeline == "stop" :
64
+ self .should_stop_pipeline .append (model )
65
+ else :
66
+ raise ValueError ("Pipeline should be either 'record' or 'stop'" )
49
67
50
68
def run (self ):
69
+ self .running = True
51
70
self .listener_handle = self .connectors ["microphone" ].start_action (
52
71
self .microphone_device_id , self .on_new_sample
53
72
)
54
73
self .transcription_thread = Thread (target = self ._transcription_function )
55
74
self .transcription_thread .start ()
56
75
57
76
def stop (self ):
77
+ self .running = False
58
78
self .connectors ["microphone" ].terminate_action (self .listener_handle )
59
79
self .transcription_thread .join ()
60
80
61
81
def on_new_sample (self , indata : np .ndarray , status_flags : dict [str , Any ]):
62
- should_stop , should_cancel = self .should_stop_recording (indata )
63
- print (indata )
64
- if should_cancel :
65
- self .cancel_task ()
66
- if (self .recording_started and not should_stop ) or (
67
- self .should_start_recording (indata )
68
- ):
82
+ should_stop = self .should_stop_recording (indata )
83
+ if self .should_start_recording (indata ):
84
+ self .recording_started = True
85
+ if self .recording_started and not should_stop :
69
86
with self .transcription_lock :
70
87
self .shared_samples .extend (indata )
71
88
@@ -75,23 +92,27 @@ def should_start_recording(self, audio_data: NDArray[np.int16]) -> bool:
75
92
should_listen , output_parameters = model .detected (
76
93
audio_data , output_parameters
77
94
)
95
+ print (should_listen , output_parameters )
78
96
if not should_listen :
79
97
return False
80
98
return True
81
99
82
- def should_stop_recording (self , audio_data : NDArray [np .int16 ]) -> Tuple [ bool , bool ] :
100
+ def should_stop_recording (self , audio_data : NDArray [np .int16 ]) -> bool :
83
101
output_parameters = {}
84
102
for model in self .should_stop_pipeline :
85
103
should_listen , output_parameters = model .detected (
86
104
audio_data , output_parameters
87
105
)
88
- # TODO: Add handling output parametrs for checking if should cancel
89
106
if should_listen :
90
- return False , False
91
- return True , False
107
+ return True
108
+ return False
92
109
93
110
def _transcription_function (self ):
94
- with self .transcription_lock :
95
- samples = np .array (self .shared_samples )
96
- print (samples )
97
- self .shared_samples = []
111
+ while self .running :
112
+ time .sleep (0.1 )
113
+ # critical section for samples
114
+ with self .transcription_lock :
115
+ samples = np .array (self .shared_samples )
116
+ self .shared_samples = []
117
+ # end critical section for samples
118
+ self .transcription_model .add_samples (samples )
0 commit comments