Skip to content

Commit f07b778

Browse files
fabienric (Fabien Ric) authored
Audio transcriptions improvement (tjardoo#133)
* add chunking_strategy, stream and extra_body parameters to the audio transcriptions endpoint
* apply formatter
* fix clippy issue
* fix code style

Co-authored-by: Fabien Ric <[email protected]>
1 parent 82c5914 commit f07b778

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

openai_dive/src/v1/endpoints/audio.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::v1::resources::audio::{AudioTranscriptionParameters, AudioTranslation
99
use futures::Stream;
1010
#[cfg(feature = "stream")]
1111
use futures::StreamExt;
12+
use serde_json::Value;
1213
#[cfg(feature = "stream")]
1314
use std::pin::Pin;
1415

@@ -55,10 +56,18 @@ impl Audio<'_> {
5556
form = form.text("language", language.to_string());
5657
}
5758

59+
if let Some(chunking_strategy) = parameters.chunking_strategy {
60+
form = form.text("chunking_strategy", chunking_strategy.to_string());
61+
}
62+
5863
if let Some(response_format) = parameters.response_format {
5964
form = form.text("response_format", response_format.to_string());
6065
}
6166

67+
if let Some(stream) = parameters.stream {
68+
form = form.text("stream", stream.to_string());
69+
}
70+
6271
if let Some(temperature) = parameters.temperature {
6372
form = form.text("temperature", temperature.to_string());
6473
}
@@ -74,6 +83,21 @@ impl Audio<'_> {
7483
);
7584
}
7685

86+
if let Some(extra_body) = parameters.extra_body {
87+
match extra_body {
88+
Value::Object(map) => {
89+
for (key, value) in map {
90+
form = form.text(key, value.to_string());
91+
}
92+
}
93+
_ => {
94+
return Err(APIError::BadRequestError(
95+
"extra_body must be formatted as a map of key: value".to_string(),
96+
));
97+
}
98+
}
99+
}
100+
77101
let response = self
78102
.client
79103
.post_with_form("/audio/transcriptions", form)

openai_dive/src/v1/resources/audio.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::v1::resources::shared::FileUpload;
44
use bytes::Bytes;
55
use derive_builder::Builder;
66
use serde::{Deserialize, Serialize};
7+
use serde_json::Value;
78
use std::fmt::Display;
89
#[cfg(feature = "tokio")]
910
use std::path::Path;
@@ -40,12 +41,18 @@ pub struct AudioTranscriptionParameters {
4041
/// The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency.
4142
#[serde(skip_serializing_if = "Option::is_none")]
4243
pub language: Option<String>,
44+
/// Controls how the audio is cut into chunks. When set to "auto", the server first normalizes loudness and then uses voice activity detection (VAD) to choose boundaries. server_vad object can be provided to tweak VAD detection parameters manually. If unset, the audio is transcribed as a single block.
45+
#[serde(skip_serializing_if = "Option::is_none")]
46+
pub chunking_strategy: Option<TranscriptionChunkingStrategy>,
4347
/// An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language.
4448
#[serde(skip_serializing_if = "Option::is_none")]
4549
pub prompt: Option<String>,
4650
/// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
4751
#[serde(skip_serializing_if = "Option::is_none")]
4852
pub response_format: Option<AudioOutputFormat>,
53+
/// If set to true, the model response data will be streamed to the client as it is generated using server-sent events. Note: Streaming is not supported for the whisper-1 model and will be ignored.
54+
#[serde(skip_serializing_if = "Option::is_none")]
55+
pub stream: Option<bool>,
4956
/// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
5057
/// while lower values like 0.2 will make it more focused and deterministic.
5158
/// If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
@@ -55,6 +62,10 @@ pub struct AudioTranscriptionParameters {
5562
/// Either or both of these options are supported: word, or segment.
5663
#[serde(skip_serializing_if = "Option::is_none")]
5764
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
65+
/// Allows to pass arbitrary json as an extra_body parameter, for specific features/openai-compatible endpoints.
66+
#[serde(flatten)]
67+
#[serde(skip_serializing_if = "Option::is_none")]
68+
pub extra_body: Option<Value>,
5869
}
5970

6071
#[derive(Serialize, Deserialize, Debug, Default, Builder, Clone, PartialEq)]
@@ -150,6 +161,32 @@ pub enum TimestampGranularity {
150161
Segment,
151162
}
152163

164+
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
165+
#[serde(rename_all = "snake_case")]
166+
pub enum TranscriptionChunkingStrategy {
167+
Auto,
168+
#[serde(untagged)]
169+
VadConfig(VadConfig),
170+
}
171+
172+
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
173+
pub struct VadConfig {
174+
/// Must be set to "server_vad" to enable manual chunking using server side VAD.
175+
pub r#type: VadConfigType,
176+
/// Amount of audio to include before the VAD detected speech (in milliseconds).
177+
pub prefix_padding_ms: Option<usize>,
178+
/// Duration of silence to detect speech stop (in milliseconds). With shorter values the model will respond more quickly, but may jump in on short pauses from the user.
179+
pub silence_duration_ms: Option<usize>,
180+
/// Sensitivity threshold (0.0 to 1.0) for voice activity detection. A higher threshold will require louder audio to activate the model, and thus might perform better in noisy environments.
181+
pub threshold: Option<f32>,
182+
}
183+
184+
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
185+
#[serde(rename_all = "snake_case")]
186+
pub enum VadConfigType {
187+
ServerVad,
188+
}
189+
153190
impl Display for AudioOutputFormat {
154191
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155192
write!(
@@ -179,6 +216,22 @@ impl Display for TimestampGranularity {
179216
}
180217
}
181218

219+
impl Display for TranscriptionChunkingStrategy {
220+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221+
match self {
222+
TranscriptionChunkingStrategy::Auto => "auto".fmt(f),
223+
TranscriptionChunkingStrategy::VadConfig(vad_config) => vad_config.fmt(f),
224+
}
225+
}
226+
}
227+
228+
impl Display for VadConfig {
229+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230+
let s = serde_json::to_string(self).map_err(|_| std::fmt::Error)?;
231+
write!(f, "{}", s)
232+
}
233+
}
234+
182235
impl AudioSpeechResponse {
183236
#[cfg(feature = "tokio")]
184237
pub async fn save<P: AsRef<Path>>(&self, file_path: P) -> Result<(), APIError> {
@@ -203,3 +256,62 @@ impl AudioSpeechResponse {
203256
Ok(())
204257
}
205258
}
259+
260+
#[cfg(test)]
mod tests {
    use crate::v1::resources::audio::{
        AudioTranscriptionParameters, AudioTranscriptionParametersBuilder,
        TranscriptionChunkingStrategy, VadConfig, VadConfigType,
    };
    use crate::v1::resources::shared::FileUpload;

    /// `Auto` must round-trip through serde as the plain string `"auto"`.
    #[test]
    fn test_audio_transcription_chunking_strategy_auto_serialization_deserialization() {
        let chunking_strategy = TranscriptionChunkingStrategy::Auto;

        let serialized = serde_json::to_string(&chunking_strategy).unwrap();
        assert_eq!(serialized, "\"auto\"");

        let deserialized: TranscriptionChunkingStrategy =
            serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, chunking_strategy)
    }

    /// The untagged `VadConfig` variant must round-trip as a bare JSON object.
    #[test]
    fn test_audio_transcription_chunking_strategy_vad_config_serialization_deserialization() {
        let chunking_strategy = TranscriptionChunkingStrategy::VadConfig(VadConfig {
            r#type: VadConfigType::ServerVad,
            prefix_padding_ms: Some(10),
            silence_duration_ms: Some(20),
            threshold: Some(0.5),
        });

        let serialized = serde_json::to_string(&chunking_strategy).unwrap();
        assert_eq!(serialized, "{\"type\":\"server_vad\",\"prefix_padding_ms\":10,\"silence_duration_ms\":20,\"threshold\":0.5}");

        let deserialized: TranscriptionChunkingStrategy =
            serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, chunking_strategy)
    }

    /// `extra_body` is `#[serde(flatten)]`ed, so its keys must appear at the
    /// top level of the serialized parameters and round-trip intact.
    #[test]
    fn test_audio_transcription_extra_body_serialization_deserialization() {
        // Idiomatic derive_builder usage: a single method chain instead of
        // rebinding a `&mut` reference to a temporary builder.
        let params: AudioTranscriptionParameters = AudioTranscriptionParametersBuilder::default()
            .file(FileUpload::File("test.wav".to_string()))
            .model("test")
            .extra_body(serde_json::json!({
                "enable_my_feature": true,
                "my_param": 10
            }))
            .build()
            .unwrap();

        let serialized = serde_json::to_string(&params).unwrap();
        assert_eq!(serialized, "{\"file\":{\"File\":\"test.wav\"},\"model\":\"test\",\"enable_my_feature\":true,\"my_param\":10}");

        let deserialized: AudioTranscriptionParameters =
            serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, params)
    }
}

0 commit comments

Comments (0)