Fix silence detection #175
base: main
Changes from all commits
57256ab
b6cb176
f724f89
b5fc115
581cf77
fd1bea5
efe4f21
e8dad50
d95c952
685b8d9
f94fc94
1e97153
056992a
ffb19ef
83af6a5
38078f4
571cca6
3a3b512
031d44f
7ff3b0a
c788c16
f4ac60a
8cf2ef4
dda302f
@@ -57,6 +57,14 @@ public protocol TextDecoding {
        options: DecodingOptions,
        temperature: FloatType
    ) async throws -> DecodingResult

    func detectSilence(
        from encoderOutput: MLMultiArray,
        using decoderInputs: DecodingInputs,
        sampler tokenSampler: TokenSampling,
        options: DecodingOptions,
        temperature: FloatType
    ) async throws -> Float

    @available(*, deprecated, message: "Subject to removal in a future version. Use `detectLanguage(from:using:sampler:options:temperature:) async throws -> DecodingResult` instead.")
    @_disfavoredOverload
@@ -340,6 +348,58 @@ public class TextDecoderContextPrefill: WhisperMLModel {

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
open class TextDecoder: TextDecoding, WhisperMLModel {
    func softmax(_ logits: MLMultiArray) -> [Float] {
        let count = logits.count
        var expValues = [Float](repeating: 0.0, count: count)
        var sumExpValues: Float = 0.0

        for i in 0..<count {
            let expValue = exp(logits[i].floatValue)
            expValues[i] = expValue
            sumExpValues += expValue
        }
        let softmaxProbs = expValues.map { $0 / sumExpValues }

        return softmaxProbs
    }

    func calculateNoSpeechProb(logits: MLMultiArray, noSpeechTokenIndex: Int) -> Float {
        let softmaxProbs = softmax(logits)
        let noSpeechProb = softmaxProbs[noSpeechTokenIndex]

        return noSpeechProb
    }

    public func detectSilence(
        from encoderOutput: MLMultiArray,
        using decoderInputs: DecodingInputs,
        sampler tokenSampler: TokenSampling,
        options: DecodingOptions,
        temperature: FloatType
    ) async throws -> Float {
        let noSpeechTokenIndex = 50362

        let predictedLogits = try await self.predictLogits(
            inputIds: decoderInputs.inputIds,
            cacheLength: decoderInputs.cacheLength,
            keyCache: decoderInputs.keyCache,
            valueCache: decoderInputs.valueCache,
            kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask,
            encoderOutputEmbeds: encoderOutput,
            decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask
        )

        guard let logitsArray = predictedLogits?.logits else {
            throw WhisperError.decodingLogitsFailed("Unable to decode logits")
        }

        let noSpeechProb = calculateNoSpeechProb(logits: logitsArray, noSpeechTokenIndex: noSpeechTokenIndex)

        return noSpeechProb
    }

    public var model: MLModel?
    public var tokenizer: WhisperTokenizer?
    public var prefillData: WhisperMLModel?

Review comment on the new detectSilence function: Good idea for this helper function. Could you also include a top-level function so folks can call WhisperKit.detectSilence, similar to detectLanguage? #146

Author reply: Thanks, added a top-level function in the WhisperKit.swift file.
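The WhisperKit.swift change the author mentions is not included in this excerpt. As a rough sketch only (not the actual diff), a top-level wrapper could mirror the existing detectLanguage(audioArray:) pipeline: pad or trim the audio, extract log-mel features, run the audio encoder, and delegate to the new TextDecoding.detectSilence. The helpers used below (AudioProcessor.padOrTrimAudio, logMelSpectrogram(fromAudio:), encodeFeatures(_:), prepareDecoderInputs(withPrompt:), GreedyTokenSampler) are assumptions based on that pipeline, and their exact signatures should be checked against the current WhisperKit source.

    // Hypothetical sketch of a top-level WhisperKit.detectSilence, not the code in this PR
    open func detectSilence(audioArray: [Float]) async throws -> Float {
        guard let tokenizer else {
            throw WhisperError.tokenizerUnavailable()
        }

        // The encoder expects a fixed-length window: pad or trim to 30 s at 16 kHz
        guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: 0, toLength: 30 * 16_000) else {
            throw WhisperError.transcriptionFailed("Audio samples are nil")
        }

        // Log-mel features, then audio encoder output
        guard let melOutput = try await featureExtractor.logMelSpectrogram(fromAudio: audioSamples) else {
            throw WhisperError.transcriptionFailed("Mel output is nil")
        }
        guard let encoderOutput = try await audioEncoder.encodeFeatures(melOutput) else {
            throw WhisperError.transcriptionFailed("Encoder output is nil")
        }

        // Minimal decoder state: just the start-of-transcript token and a greedy sampler
        // (assumes prepareDecoderInputs returns an optional DecodingInputs)
        let options = DecodingOptions()
        let sampler = GreedyTokenSampler(temperature: 0, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options)
        guard let decoderInputs = textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken]) else {
            throw WhisperError.prefillFailed()
        }

        // Delegate to the TextDecoding.detectSilence method added in this PR
        return try await textDecoder.detectSilence(
            from: encoderOutput,
            using: decoderInputs,
            sampler: sampler,
            options: options,
            temperature: 0
        )
    }

This is the shape the testDetectSilenceHelperMethod test below relies on: it calls whisperKit.detectSilence(audioArray:) and receives the no-speech probability as a Float.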
@@ -549,6 +609,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
        var currentTokens: [Int] = decoderInputs.initialPrompt
        var nextToken: Int = decoderInputs.initialPrompt.last!
        var logProbs: [Float] = Array(repeating: 0, count: currentTokens.count)
        var noSpeechProb: Float = 0.0

        // Logits filters
        var logitsFilters: [any LogitsFiltering] = []
@@ -641,6 +702,34 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
            for filter in logitsFilters {
                logits = filter.filterLogits(logits, withTokens: currentTokens)
            }

            if tokenIndex == intialPromptIndex {
                // print(tokenizer.specialTokens.noSpeechToken) // prints 50257
                let noSpeechTokenIndex = 50362 // The "no speech" token index in the model's vocabulary appears to be 50362
                noSpeechProb = calculateNoSpeechProb(logits: logits, noSpeechTokenIndex: noSpeechTokenIndex)

                let avgLogProb = logProbs.reduce(0, +) / Float(logProbs.count)

                if let threshold = options.noSpeechThreshold, noSpeechProb > threshold {
                    if options.logProbThreshold == nil || avgLogProb < options.logProbThreshold! {
                        print("Detected silence with noSpeechProb \(noSpeechProb) and avgLogProb \(avgLogProb), skipping segment.")
                        return DecodingResult(
                            language: Constants.defaultLanguageCode,
                            languageProbs: [:],
                            tokens: [],
                            tokenLogProbs: [],
                            text: "",
                            avgLogProb: avgLogProb,
                            noSpeechProb: noSpeechProb,
                            temperature: 0.0,
                            compressionRatio: 0.0,
                            cache: nil,
                            timings: TranscriptionTimings(),
                            fallback: nil
                        )
                    }
                }
            }

            let filteringTime = Date().timeIntervalSince(nonInferenceStartTime)
            timings.decodingFiltering += filteringTime
@@ -794,8 +883,6 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
                temperature = Float(sampler.temperature).rounded(3)
            }

            // Removed by this PR: let noSpeechProb: Float = 0 // TODO: implement no speech prob

            // If language is still nil here, check language can be inferred from tokens
            var language = options.language ?? Constants.defaultLanguageCode
            var languageProbs = [String: Float]()
@@ -666,7 +666,37 @@ final class UnitTests: XCTestCase {
            XCTAssertEqual(result.language, language)
        }
    }

    func testDetectSilenceHelperMethod() async throws {
        let whisperKit = try await WhisperKit(
            modelFolder: tinyModelPath(),
            verbose: true,
            logLevel: .debug
        )

        let silentAudioSamples: [Float] = [Float](repeating: 0.0, count: 16000) // 1 second of silence at 16 kHz
        let jfkAudioSamples = try XCTUnwrap(loadAudioSamples(forResource: "ted_60", withExtension: "m4a"))

        let testAudioFiles: [(String, [Float], Bool)] = [
            ("silent_clip", silentAudioSamples, false), // Not expecting speech
            ("non_silent_clip", jfkAudioSamples, true) // Expecting speech
        ]

        for (audioFileName, audioSamples, expectingSpeech) in testAudioFiles {
            let silenceProbability = try await whisperKit.detectSilence(audioArray: audioSamples)

            // print("Test case: \(audioFileName), Expecting speech: \(expectingSpeech), Calculated silence probability: \(silenceProbability)")
            // Calculated noSpeechProb values for the silent and non-silent clips are 0.002598221 and 0.26186648.
            // Given these values, a threshold of 0.6 might be too high to accurately distinguish between
            // silence and speech. Based on the debug values, a threshold of 0.2 is used here.
            if expectingSpeech {
                XCTAssertGreaterThan(silenceProbability, 0.2, "Expected speech, but detected silence for \(audioFileName) with probability \(silenceProbability)")
            } else {
                XCTAssertLessThanOrEqual(silenceProbability, 0.2, "Expected silence, but detected speech for \(audioFileName) with probability \(silenceProbability)")
            }
        }
    }

    func testNoTimestamps() async throws {
        let options = DecodingOptions(withoutTimestamps: true)
@@ -708,7 +738,70 @@ final class UnitTests: XCTestCase {

        XCTAssertNotNil(result.text)
    }

    func testSilentAudio() async throws {
        let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)

        let silentAudioSamples: [Float] = [Float](repeating: 0.0, count: 16000)

        let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false)

        let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: silentAudioSamples, decodeOptions: options)

        XCTAssertTrue(result.first?.segments.isEmpty ?? false, "Expected no segments for silent audio")
    }

    func testInitialSilenceFollowedBySpeech() async throws {
        let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)

        let initialSilenceSpeechSamples: [Float] = loadAudioSamples(forResource: "initial_silence_speech", withExtension: "m4a")

        let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false, noSpeechThreshold: 0.8)

        let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: initialSilenceSpeechSamples, decodeOptions: options)

        if let transcription = result.first?.segments.first?.text {
            print("Transcription: \(transcription)")
        } else {
            print("No transcription found.")
        }

        let transcription = result.first?.segments.first?.text
        XCTAssertNotNil(transcription, "Expected transcription for audio with initial silence followed by speech")

        XCTAssertTrue(transcription?.contains("Hey") ?? false, "Expected 'Hey' in transcription")
    }

    func testContinuousSpeechAudio() async throws {
        let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)

        let continuousSpeechSamples: [Float] = loadAudioSamples(forResource: "continuous_speech", withExtension: "wav")
        let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false)

        let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: continuousSpeechSamples, decodeOptions: options)

        let transcription = result.first?.segments.first?.text
        XCTAssertNotNil(transcription, "Expected transcription for continuous speech audio")
        XCTAssertFalse(transcription?.isEmpty ?? true, "Expected non-empty transcription for continuous speech audio")
    }

    // MARK: - Helper Function

    func loadAudioSamples(forResource resource: String, withExtension ext: String) -> [Float] {
        guard let audioFileURL = Bundle.module.url(forResource: resource, withExtension: ext) else {
            XCTFail("Audio file not found")
            return []
        }

        do {
            let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioFileURL.path)
            return AudioProcessor.convertBufferToArray(buffer: audioBuffer)
        } catch {
            XCTFail("Failed to load audio samples: \(error.localizedDescription)")
            return []
        }
    }

Review comment on the loadAudioSamples helper: Good idea 👍

    func testSilence() async throws {
        let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)
        let audioSamples = [Float](repeating: 0.0, count: 30 * 16000)
@@ -920,7 +1013,7 @@ final class UnitTests: XCTestCase {
        let result2 = tokensFilter2.filterLogits(logits2, withTokens: [1])
        XCTAssertEqual(result2.data(for: 2), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
    }

    func testTimestampRulesFilter() throws {
        // NOTE: for non-multilingual models we suppress tokens immediately
        let tokensFilter1 = TimestampRulesFilter(
Review comment: On second look at the source, I'm not sure we need a filter for this, since that would suppress the log probs of the "predicted token" during inference. What you already have here https://github.com/argmaxinc/WhisperKit/pull/175/files#diff-5dd6579fc66020b1085535bce41d2c2cc399a0b2b8f0ba225fc89f39d9ebdbc8R402 is checking the specific no speech index, which does everything you would need already.

Review comment: One thing that you could add to the filters is suppressing the no speech token in the SupressTokensFilter, similar to this: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/decoding.py#L638-L640 (it also needs the remaining suppress tokens, but that can be fixed later).
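As a rough sketch of that suggestion (not part of this PR), a LogitsFiltering implementation could mask the no-speech token before sampling, in the spirit of the SuppressTokens filter in openai/whisper. It assumes the logits are shaped [1, 1, vocabSize] and that the fill(indexes:with:) MLMultiArray helper used by WhisperKit's existing filters is available; both should be verified against the current source.

    // Hypothetical sketch, mirroring the structure of WhisperKit's existing logits filters
    @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
    open class NoSpeechSuppressionFilter: LogitsFiltering {
        // Indexes into the [1, 1, vocabSize] logits array for every token to suppress
        let suppressTokenIndexes: [[NSNumber]]

        public init(noSpeechToken: Int, additionalSuppressTokens: [Int] = []) {
            self.suppressTokenIndexes = ([noSpeechToken] + additionalSuppressTokens).map { [0, 0, $0 as NSNumber] }
        }

        public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
            // Setting the logit to -inf gives these tokens zero probability after softmax,
            // so the sampler can never emit them
            logits.fill(indexes: suppressTokenIndexes, with: -FloatType.infinity)
            return logits
        }
    }

Such a filter could then be appended to the logitsFilters array next to the existing ones, for example constructed with the tokenizer's specialTokens.noSpeechToken as the suppressed index.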