Fix silence detection #175

Open
wants to merge 24 commits into base: main
Changes from all commits (24 commits)
57256ab
Add SilenceDetectionFilter to improve silence detection logic
Jun 26, 2024
b6cb176
Implement detectSilence in TextDecoder
Jun 26, 2024
f724f89
Add silence detection and segment skipping logic in TranscribeTask
Jun 26, 2024
b5fc115
Update models to support silence detection
Jun 26, 2024
581cf77
Add tests for SilenceDetectionFilter
Jun 26, 2024
fd1bea5
Add silent_audio.mp3 for testing fully silent audio
Jun 26, 2024
efe4f21
Add initial_silence_speech.m4a for testing initial silence detection:…
Jun 26, 2024
e8dad50
Add continuous_speech.m4a for testing continuous speech detection
Jun 26, 2024
d95c952
Update TranscribeTask.swift
aigerimmmm Jun 27, 2024
685b8d9
Update Sources/WhisperKit/Core/TextDecoder.swift
aigerimmmm Jun 30, 2024
f94fc94
Update Sources/WhisperKit/Core/TranscribeTask.swift
aigerimmmm Jun 30, 2024
1e97153
Added silence detection and log probability checks in TranscribeTask
Jul 5, 2024
056992a
Kept noSpeechThreshold as original at 0.6
Jul 5, 2024
ffb19ef
Added checking noSpeechProb at the SOT token and added softmax probab…
Jul 5, 2024
83af6a5
Added top-level function for detectSilence.
Jul 5, 2024
38078f4
Added SilenceLogitsFilter to handle suppression of no speech token
Jul 5, 2024
571cca6
Changed silent audio samples to [Float](repeating: 0.0, count: 16000)…
Jul 5, 2024
3a3b512
Update TextDecoder.swift
aigerimmmm Jul 8, 2024
031d44f
Delete Sources/WhisperKit/Core/LogitsFilter.swift
aigerimmmm Jul 8, 2024
7ff3b0a
Restore LogitsFilter.swift file
Jul 8, 2024
c788c16
Update TextDecoder.swift
aigerimmmm Jul 8, 2024
f4ac60a
Update TranscribeTask.swift
aigerimmmm Jul 8, 2024
8cf2ef4
Update UnitTests.swift
aigerimmmm Jul 8, 2024
dda302f
Update TextDecoder.swift
aigerimmmm Jul 9, 2024
34 changes: 34 additions & 0 deletions Sources/WhisperKit/Core/LogitsFilter.swift
@@ -278,3 +278,37 @@ open class LanguageLogitsFilter: LogitsFiltering {
return indexes
}
}

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
open class SilenceLogitsFilter: LogitsFiltering {
Contributor:

On second look at the source, I'm not sure we need a filter for this, since that will suppress the log probs of the "predicted token" during inference. What you already have here https://github.com/argmaxinc/WhisperKit/pull/175/files#diff-5dd6579fc66020b1085535bce41d2c2cc399a0b2b8f0ba225fc89f39d9ebdbc8R402 is checking the specific no speech index, which does everything you would need already.

Contributor:

One thing that you could add to the filters is suppressing the no speech token in the SupressTokensFilter, similar to this https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/decoding.py#L638-L640 (this also needs the remaining suppress tokens, but that can be fixed later).
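
A minimal sketch of that suggestion, assuming SuppressTokensFilter accepts an array of token indexes and that the tokenizer exposes the no-speech token as specialTokens.noSpeechToken (both are assumptions, not confirmed by this diff):

// Hypothetical sketch (identifier names assumed): suppress the <|nospeech|> token
// during sampling by adding it to the suppressed-token set, as openai/whisper does.
let noSpeechToken = tokenizer.specialTokens.noSpeechToken
logitsFilters.append(SuppressTokensFilter(suppressTokens: [noSpeechToken]))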

let silenceToken: Int
let logitsDim: Int
let sampleBegin: Int
let nonSilenceTokenIndexes: [[NSNumber]]

public init(silenceToken: Int, logitsDim: Int, sampleBegin: Int) {
self.silenceToken = silenceToken
self.logitsDim = logitsDim
self.sampleBegin = sampleBegin
self.nonSilenceTokenIndexes = SilenceLogitsFilter.getNonSilenceTokenIndexes(logitsDim: self.logitsDim, silenceToken: self.silenceToken)
}

/// Retain the logit for the silence token and suppress all non-silence tokens
public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
guard tokens.count == sampleBegin else {
return logits
}
logits.fill(indexes: nonSilenceTokenIndexes, with: -FloatType.infinity)
return logits
}

private static func getNonSilenceTokenIndexes(logitsDim: Int, silenceToken: Int) -> [[NSNumber]] {
var indexes: [[NSNumber]] = []
for i in 0..<logitsDim {
if i != silenceToken {
indexes.append([0, 0, i as NSNumber])
}
}
return indexes
}
}
91 changes: 89 additions & 2 deletions Sources/WhisperKit/Core/TextDecoder.swift
@@ -57,6 +57,14 @@ public protocol TextDecoding {
options: DecodingOptions,
temperature: FloatType
) async throws -> DecodingResult

func detectSilence(
from encoderOutput: MLMultiArray,
using decoderInputs: DecodingInputs,
sampler tokenSampler: TokenSampling,
options: DecodingOptions,
temperature: FloatType
) async throws -> Float

@available(*, deprecated, message: "Subject to removal in a future version. Use `detectLanguage(from:using:sampler:options:temperature:) async throws -> DecodingResult` instead.")
@_disfavoredOverload
@@ -340,6 +348,58 @@ public class TextDecoderContextPrefill: WhisperMLModel {

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
open class TextDecoder: TextDecoding, WhisperMLModel {
func softmax(_ logits: MLMultiArray) -> [Float] {
let count = logits.count
var expValues = [Float](repeating: 0.0, count: count)
var sumExpValues: Float = 0.0

// Subtract the maximum logit before exponentiating to avoid overflow (numerically stable softmax)
var maxLogit: Float = -.infinity
for i in 0..<count {
maxLogit = max(maxLogit, logits[i].floatValue)
}

for i in 0..<count {
let expValue = exp(logits[i].floatValue - maxLogit)
expValues[i] = expValue
sumExpValues += expValue
}

// Normalize to a probability distribution
let softmaxProbs = expValues.map { $0 / sumExpValues }

return softmaxProbs
}

/// Computes the softmax probability of the no-speech token from the given logits
func calculateNoSpeechProb(logits: MLMultiArray, noSpeechTokenIndex: Int) -> Float {
let softmaxProbs = softmax(logits)
let noSpeechProb = softmaxProbs[noSpeechTokenIndex]

return noSpeechProb
}

public func detectSilence(
Contributor:

Good idea for this helper function. Could you also include a top-level function so folks can call WhisperKit.detectSilence, similar to detectLanguage (#146)?

Author:

Thanks, added a top-level function in the WhisperKit.swift file.

from encoderOutput: MLMultiArray,
using decoderInputs: DecodingInputs,
sampler tokenSampler: TokenSampling,
options: DecodingOptions,
temperature: FloatType
) async throws -> Float {
let noSpeechTokenIndex = 50362 // Index of the <|nospeech|> token (50362 in the multilingual vocabulary)

let predictedLogits = try await self.predictLogits(
inputIds: decoderInputs.inputIds,
cacheLength: decoderInputs.cacheLength,
keyCache: decoderInputs.keyCache,
valueCache: decoderInputs.valueCache,
kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask,
encoderOutputEmbeds: encoderOutput,
decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask
)

guard let logitsArray = predictedLogits?.logits else {
throw WhisperError.decodingLogitsFailed("Unable to decode logits")
}

let noSpeechProb = calculateNoSpeechProb(logits: logitsArray, noSpeechTokenIndex: noSpeechTokenIndex)

return noSpeechProb
}

public var model: MLModel?
public var tokenizer: WhisperTokenizer?
public var prefillData: WhisperMLModel?
@@ -549,6 +609,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
var currentTokens: [Int] = decoderInputs.initialPrompt
var nextToken: Int = decoderInputs.initialPrompt.last!
var logProbs: [Float] = Array(repeating: 0, count: currentTokens.count)
var noSpeechProb: Float = 0.0

// Logits filters
var logitsFilters: [any LogitsFiltering] = []
@@ -641,6 +702,34 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
for filter in logitsFilters {
logits = filter.filterLogits(logits, withTokens: currentTokens)
}

if tokenIndex == intialPromptIndex {
// Note: tokenizer.specialTokens.noSpeechToken prints 50257 here
let noSpeechTokenIndex = 50362 // The model's index for the <|nospeech|> token appears to be 50362
noSpeechProb = calculateNoSpeechProb(logits: logits, noSpeechTokenIndex: noSpeechTokenIndex)

let avgLogProb = logProbs.reduce(0, +) / Float(logProbs.count)

if let threshold = options.noSpeechThreshold, noSpeechProb > threshold {
if options.logProbThreshold == nil || avgLogProb < options.logProbThreshold! {
print("Detected silence with noSpeechProb \(noSpeechProb) and avgLogProb \(avgLogProb), skipping segment.")
return DecodingResult(
language: Constants.defaultLanguageCode,
languageProbs: [:],
tokens: [],
tokenLogProbs: [],
text: "",
avgLogProb: avgLogProb,
noSpeechProb: noSpeechProb,
temperature: 0.0,
compressionRatio: 0.0,
cache: nil,
timings: TranscriptionTimings(),
fallback: nil
)
}
}
}

let filteringTime = Date().timeIntervalSince(nonInferenceStartTime)
timings.decodingFiltering += filteringTime
@@ -794,8 +883,6 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
temperature = Float(sampler.temperature).rounded(3)
}

let noSpeechProb: Float = 0 // TODO: implement no speech prob

// If language is still nil here, check whether the language can be inferred from the tokens
var language = options.language ?? Constants.defaultLanguageCode
var languageProbs = [String: Float]()
6 changes: 6 additions & 0 deletions Sources/WhisperKit/Core/TranscribeTask.swift
@@ -158,6 +158,11 @@ final class TranscribeTask {
// Send to decoder to predict text tokens with fallback
let decodingResult = try await decodeWithFallback(encoderSegment: encoderOutput, decodingOptions: options, callback: decodingCallback)

if decodingResult.noSpeechProb > (options.noSpeechThreshold ?? 0.6) && decodingResult.avgLogProb < (options.logProbThreshold ?? -1.0) {
seek += segmentSize
continue
}

// MARK: Windowing

// At this point we have a completed window aka segment
@@ -269,6 +274,7 @@ final class TranscribeTask {
let tokenSampler = GreedyTokenSampler(temperature: temp, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options)

var currentDecodingOptions = options

// For a multilingual model, if language is not passed and detectLanguage is true, detect language and set in options
if textDecoder.isModelMultilingual, options.language == nil, options.detectLanguage {
let languageDecodingResult: DecodingResult? = try? await textDecoder.detectLanguage(
44 changes: 44 additions & 0 deletions Sources/WhisperKit/Core/WhisperKit.swift
@@ -434,6 +434,50 @@ open class WhisperKit {

return (language: languageDecodingResult.language, langProbs: languageDecodingResult.languageProbs)
}


/// Detects silence in the provided audio samples.
///
/// - Parameter audioArray: An array of audio samples.
/// - Returns: The no-speech probability for the audio.
public func detectSilence(audioArray: [Float]) async throws -> Float {
if modelState != .loaded {
try await loadModels()
}

guard let tokenizer else {
throw WhisperError.tokenizerUnavailable()
}

let options = DecodingOptions()
let decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken])
decoderInputs.kvCacheUpdateMask[0] = 1.0
decoderInputs.decoderKeyPaddingMask[0] = 0.0

guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: 0, toLength: WhisperKit.windowSamples) else {
throw WhisperError.transcriptionFailed("Audio samples are nil")
}

guard let melOutput = try await featureExtractor.logMelSpectrogram(fromAudio: audioSamples) else {
throw WhisperError.transcriptionFailed("Mel output is nil")
}

guard let encoderOutput = try await audioEncoder.encodeFeatures(melOutput) else {
throw WhisperError.transcriptionFailed("Encoder output is nil")
}

let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options)
let noSpeechProb = try await textDecoder.detectSilence(
from: encoderOutput,
using: decoderInputs,
sampler: tokenSampler,
options: options,
temperature: 0
)

return noSpeechProb
}


// MARK: - Transcribe multiple audio files

Binary file not shown.
Binary file not shown.
Binary file added Tests/WhisperKitTests/Resources/silent_audio.mp3
Binary file not shown.
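
For usage, a minimal sketch of calling the new top-level detectSilence helper added above in WhisperKit.swift, assuming an already-initialized WhisperKit instance named whisperKit, 16 kHz mono samples in audioArray, and an illustrative threshold of 0.6 (matching the default noSpeechThreshold used in TranscribeTask):

// Query the no-speech probability for a buffer of samples and compare it against
// a caller-chosen threshold; 0.6 here is only an example value.
let noSpeechProb = try await whisperKit.detectSilence(audioArray: audioArray)
if noSpeechProb > 0.6 {
    print("Segment looks silent (noSpeechProb: \(noSpeechProb))")
}
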
97 changes: 95 additions & 2 deletions Tests/WhisperKitTests/UnitTests.swift
@@ -666,7 +666,37 @@ final class UnitTests: XCTestCase {
XCTAssertEqual(result.language, language)
}
}


func testDetectSilenceHelperMethod() async throws {
let whisperKit = try await WhisperKit(
modelFolder: tinyModelPath(),
verbose: true,
logLevel: .debug
)

let silentAudioSamples: [Float] = [Float](repeating: 0.0, count: 16000) // 1 second of silence at 16kHz
let speechAudioSamples = try XCTUnwrap(loadAudioSamples(forResource: "ted_60", withExtension: "m4a"))

let testAudioFiles: [(String, [Float], Bool)] = [
("silent_clip", silentAudioSamples, false), // Not expecting speech
("non_silent_clip", speechAudioSamples, true) // Expecting speech
]

for (audioFileName, audioSamples, expectingSpeech) in testAudioFiles {
let silenceProbability = try await whisperKit.detectSilence(audioArray: audioSamples)

//print("Test case: \(audioFileName), Expecting speech: \(expectingSpeech), Calculated silence probability: \(silenceProbability)")
// calculated noSpeechProb values for silent and non-silent clips are 0.002598221 and 0.26186648.
// Given these values, a threshold of 0.6 might be too high to accurately distinguish between
// silence and speech.Based on the debug values, here I picked a threshold of 0.2
if expectingSpeech {
XCTAssertGreaterThan(silenceProbability, 0.2, "Expected speech, but detected silence for \(audioFileName) with probability \(silenceProbability)")
} else {
XCTAssertLessThanOrEqual(silenceProbability, 0.2, "Expected silence, but detected speech for \(audioFileName) with probability \(silenceProbability)")
}
}
}

func testNoTimestamps() async throws {
let options = DecodingOptions(withoutTimestamps: true)

@@ -708,7 +738,70 @@ final class UnitTests: XCTestCase {

XCTAssertNotNil(result.text)
}


func testSilentAudio() async throws {
let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)

let silentAudioSamples: [Float] = [Float](repeating: 0.0, count: 16000)

let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false)

let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: silentAudioSamples, decodeOptions: options)

XCTAssertTrue(result.first?.segments.isEmpty ?? false, "Expected no segments for silent audio")
}

func testInitialSilenceFollowedBySpeech() async throws {
let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)

let initialSilenceSpeechSamples: [Float] = loadAudioSamples(forResource: "initial_silence_speech", withExtension: "m4a")

let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false, noSpeechThreshold: 0.8)

let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: initialSilenceSpeechSamples, decodeOptions: options)

if let transcription = result.first?.segments.first?.text {
print("Transcription: \(transcription)")
} else {
print("No transcription found.")
}

let transcription = result.first?.segments.first?.text
XCTAssertNotNil(transcription, "Expected transcription for audio with initial silence followed by speech")

XCTAssertTrue(transcription?.contains("Hey") ?? false, "Expected 'Hey' in transcription")
}

func testContinuousSpeechAudio() async throws {
let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)

let continuousSpeechSamples: [Float] = loadAudioSamples(forResource: "continuous_speech", withExtension: "wav")
let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false)

let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: continuousSpeechSamples, decodeOptions: options)

let transcription = result.first?.segments.first?.text
XCTAssertNotNil(transcription, "Expected transcription for continuous speech audio")
XCTAssertFalse(transcription?.isEmpty ?? true, "Expected non-empty transcription for continuous speech audio")
}

// MARK: - Helper Function

func loadAudioSamples(forResource resource: String, withExtension ext: String) -> [Float] {
Contributor:

Good idea 👍

guard let audioFileURL = Bundle.module.url(forResource: resource, withExtension: ext) else {
XCTFail("Audio file not found")
return []
}

do {
let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioFileURL.path)
return AudioProcessor.convertBufferToArray(buffer: audioBuffer)
} catch {
XCTFail("Failed to load audio samples: \(error.localizedDescription)")
return []
}
}

func testSilence() async throws {
let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)
let audioSamples = [Float](repeating: 0.0, count: 30 * 16000)
@@ -920,7 +1013,7 @@ final class UnitTests: XCTestCase {
let result2 = tokensFilter2.filterLogits(logits2, withTokens: [1])
XCTAssertEqual(result2.data(for: 2), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
}

func testTimestampRulesFilter() throws {
// NOTE: for non-multilingual models we suppress tokens immediately
let tokensFilter1 = TimestampRulesFilter(