Skip to content

Commit f07a81a

Browse files
ruby : bug fix on callbacks and no_speech_prob (#2656)
* Don't generate documentation on test * Move .startup to TestBase class * Extract new_segment_callback as a function * Extract progress_callback as a function * Extract abort_callback as a function * Extract register_callbacks as a function * Call callbacks in Whiser::Context#full and #full_parallel * Fix README * Care about the cases content-size is nil and TTY is not available * Add tests for no_speech_prob * Add Whisper::Context#full_get_segment_no_speech_prob and Whisper::Segment#no_speech_prob
1 parent 4183517 commit f07a81a

File tree

7 files changed

+167
-124
lines changed

7 files changed

+167
-124
lines changed

bindings/ruby/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ whisper = Whisper::Context.new("base.en")
6363
You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
6464

6565
```ruby
66-
puts Whisper::Model.preconverted_model_names
66+
puts Whisper::Model.preconverted_models.keys
6767
# tiny
6868
# tiny.en
6969
# tiny-q5_1
@@ -220,7 +220,7 @@ whisper.each_segment do |segment|
220220
end
221221
```
222222

223-
The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
223+
The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
224224

225225
License
226226
-------

bindings/ruby/ext/ruby_whisper.cpp

+114-76
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ static ID id_pre_converted_models;
5353

5454
static bool is_log_callback_finalized = false;
5555

56+
// High level API
57+
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
58+
5659
/*
5760
* call-seq:
5861
* lang_max_id -> Integer
@@ -187,6 +190,69 @@ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate(
187190
return container;
188191
}
189192

193+
static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) {
194+
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
195+
196+
// Currently, doesn't support state because
197+
// those require to resolve GC-related problems.
198+
if (!NIL_P(container->callback)) {
199+
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
200+
}
201+
const long callbacks_len = RARRAY_LEN(container->callbacks);
202+
if (0 == callbacks_len) {
203+
return;
204+
}
205+
const int n_segments = whisper_full_n_segments_from_state(state);
206+
for (int i = n_new; i > 0; i--) {
207+
int i_segment = n_segments - i;
208+
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
209+
for (int j = 0; j < callbacks_len; j++) {
210+
VALUE cb = rb_ary_entry(container->callbacks, j);
211+
rb_funcall(cb, id_call, 1, segment);
212+
}
213+
}
214+
}
215+
216+
static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) {
217+
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
218+
const VALUE progress = INT2NUM(progress_cur);
219+
// Currently, doesn't support state because
220+
// those require to resolve GC-related problems.
221+
if (!NIL_P(container->callback)) {
222+
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
223+
}
224+
const long callbacks_len = RARRAY_LEN(container->callbacks);
225+
if (0 == callbacks_len) {
226+
return;
227+
}
228+
for (int j = 0; j < callbacks_len; j++) {
229+
VALUE cb = rb_ary_entry(container->callbacks, j);
230+
rb_funcall(cb, id_call, 1, progress);
231+
}
232+
}
233+
234+
static bool abort_callback(void * user_data) {
235+
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
236+
if (!NIL_P(container->callback)) {
237+
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
238+
if (!NIL_P(result) && Qfalse != result) {
239+
return true;
240+
}
241+
}
242+
const long callbacks_len = RARRAY_LEN(container->callbacks);
243+
if (0 == callbacks_len) {
244+
return false;
245+
}
246+
for (int j = 0; j < callbacks_len; j++) {
247+
VALUE cb = rb_ary_entry(container->callbacks, j);
248+
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
249+
if (!NIL_P(result) && Qfalse != result) {
250+
return true;
251+
}
252+
}
253+
return false;
254+
}
255+
190256
static VALUE ruby_whisper_params_allocate(VALUE klass) {
191257
ruby_whisper_params *rwp;
192258
rwp = ALLOC(ruby_whisper_params);
@@ -230,8 +296,25 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
230296
return self;
231297
}
232298

233-
// High level API
234-
static VALUE rb_whisper_segment_initialize(VALUE context, int index);
299+
static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) {
300+
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
301+
rwp->new_segment_callback_container->context = self;
302+
rwp->params.new_segment_callback = new_segment_callback;
303+
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
304+
}
305+
306+
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
307+
rwp->progress_callback_container->context = self;
308+
rwp->params.progress_callback = progress_callback;
309+
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
310+
}
311+
312+
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
313+
rwp->abort_callback_container->context = self;
314+
rwp->params.abort_callback = abort_callback;
315+
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
316+
}
317+
}
235318

236319
/*
237320
* transcribe a single file
@@ -353,80 +436,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
353436
rwp->params.encoder_begin_callback_user_data = &is_aborted;
354437
}
355438

356-
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
357-
rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
358-
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
359-
360-
// Currently, doesn't support state because
361-
// those require to resolve GC-related problems.
362-
if (!NIL_P(container->callback)) {
363-
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
364-
}
365-
const long callbacks_len = RARRAY_LEN(container->callbacks);
366-
if (0 == callbacks_len) {
367-
return;
368-
}
369-
const int n_segments = whisper_full_n_segments_from_state(state);
370-
for (int i = n_new; i > 0; i--) {
371-
int i_segment = n_segments - i;
372-
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
373-
for (int j = 0; j < callbacks_len; j++) {
374-
VALUE cb = rb_ary_entry(container->callbacks, j);
375-
rb_funcall(cb, id_call, 1, segment);
376-
}
377-
}
378-
};
379-
rwp->new_segment_callback_container->context = &self;
380-
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
381-
}
382-
383-
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
384-
rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
385-
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
386-
const VALUE progress = INT2NUM(progress_cur);
387-
// Currently, doesn't support state because
388-
// those require to resolve GC-related problems.
389-
if (!NIL_P(container->callback)) {
390-
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
391-
}
392-
const long callbacks_len = RARRAY_LEN(container->callbacks);
393-
if (0 == callbacks_len) {
394-
return;
395-
}
396-
for (int j = 0; j < callbacks_len; j++) {
397-
VALUE cb = rb_ary_entry(container->callbacks, j);
398-
rb_funcall(cb, id_call, 1, progress);
399-
}
400-
};
401-
rwp->progress_callback_container->context = &self;
402-
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
403-
}
404-
405-
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
406-
rwp->params.abort_callback = [](void * user_data) {
407-
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
408-
if (!NIL_P(container->callback)) {
409-
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
410-
if (!NIL_P(result) && Qfalse != result) {
411-
return true;
412-
}
413-
}
414-
const long callbacks_len = RARRAY_LEN(container->callbacks);
415-
if (0 == callbacks_len) {
416-
return false;
417-
}
418-
for (int j = 0; j < callbacks_len; j++) {
419-
VALUE cb = rb_ary_entry(container->callbacks, j);
420-
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
421-
if (!NIL_P(result) && Qfalse != result) {
422-
return true;
423-
}
424-
}
425-
return false;
426-
};
427-
rwp->abort_callback_container->context = &self;
428-
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
429-
}
439+
register_callbacks(rwp, &self);
430440

431441
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
432442
fprintf(stderr, "failed to process audio\n");
@@ -631,6 +641,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
631641
}
632642
}
633643
}
644+
register_callbacks(rwp, &self);
634645
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
635646
if (0 == result) {
636647
return Qnil;
@@ -719,6 +730,7 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
719730
}
720731
}
721732
}
733+
register_callbacks(rwp, &self);
722734
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
723735
if (0 == result) {
724736
return Qnil;
@@ -823,6 +835,18 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
823835
return rb_str_new2(text);
824836
}
825837

838+
/*
839+
* call-seq:
840+
* full_get_segment_no_speech_prob -> Float
841+
*/
842+
static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) {
843+
ruby_whisper *rw;
844+
Data_Get_Struct(self, ruby_whisper, rw);
845+
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
846+
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
847+
return DBL2NUM(no_speech_prob);
848+
}
849+
826850
/*
827851
* params.language = "auto" | "en", etc...
828852
*
@@ -1547,6 +1571,18 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) {
15471571
return rb_str_new2(text);
15481572
}
15491573

1574+
/*
1575+
* call-seq:
1576+
* no_speech_prob -> Float
1577+
*/
1578+
static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) {
1579+
ruby_whisper_segment *rws;
1580+
Data_Get_Struct(self, ruby_whisper_segment, rws);
1581+
ruby_whisper *rw;
1582+
Data_Get_Struct(rws->context, ruby_whisper, rw);
1583+
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
1584+
}
1585+
15501586
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
15511587
rb_gc_mark(rwm->context);
15521588
}
@@ -1809,6 +1845,7 @@ void Init_whisper() {
18091845
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
18101846
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
18111847
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
1848+
rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
18121849
rb_define_method(cContext, "full", ruby_whisper_full, -1);
18131850
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
18141851

@@ -1887,6 +1924,7 @@ void Init_whisper() {
18871924
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
18881925
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
18891926
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
1927+
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
18901928

18911929
cModel = rb_define_class_under(mWhisper, "Model", rb_cObject);
18921930
rb_define_alloc_func(cModel, ruby_whisper_model_allocate);

bindings/ruby/lib/whisper/model/uri.rb

+19-13
Original file line numberDiff line numberDiff line change
@@ -79,30 +79,36 @@ def download(response)
7979
downloaded += chunk.bytesize
8080
show_progress downloaded, size
8181
end
82+
$stderr.puts
8283
end
8384
downloading_path.rename path
8485
end
8586

8687
def show_progress(current, size)
87-
return unless $stderr.tty?
88-
return unless size
88+
progress_rate_available = size && $stderr.tty?
8989

9090
unless @prev
9191
@prev = Time.now
92-
$stderr.puts "Downloading #{@uri}"
92+
$stderr.puts "Downloading #{@uri} to #{cache_path}"
9393
end
9494

9595
now = Time.now
96-
return if now - @prev < 1 && current < size
97-
98-
progress_width = 20
99-
progress = current.to_f / size
100-
arrow_length = progress * progress_width
101-
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
102-
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
103-
padding = ' ' * ($stderr.winsize[1] - line.size)
104-
$stderr.print "\r#{line}#{padding}"
105-
$stderr.puts if current >= size
96+
97+
if progress_rate_available
98+
return if now - @prev < 1 && current < size
99+
100+
progress_width = 20
101+
progress = current.to_f / size
102+
arrow_length = progress * progress_width
103+
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
104+
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
105+
padding = ' ' * ($stderr.winsize[1] - line.size)
106+
$stderr.print "\r#{line}#{padding}"
107+
else
108+
return if now - @prev < 1
109+
110+
$stderr.print "."
111+
end
106112
@prev = now
107113
end
108114

bindings/ruby/tests/helper.rb

+17
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,21 @@
44

55
class TestBase < Test::Unit::TestCase
66
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
7+
8+
class << self
9+
attr_reader :whisper
10+
11+
def startup
12+
@whisper = Whisper::Context.new("base.en")
13+
params = Whisper::Params.new
14+
params.print_timestamps = false
15+
@whisper.transcribe(TestBase::AUDIO, params)
16+
end
17+
end
18+
19+
private
20+
21+
def whisper
22+
self.class.whisper
23+
end
724
end

bindings/ruby/tests/test_package.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_install
2323
version = match_data[2]
2424
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
2525
Dir.mktmpdir do |dir|
26-
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
26+
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
2727
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
2828
end
2929
end

0 commit comments

Comments
 (0)