Skip to content

Commit adee3f9

Browse files
authored
node : add flash_attn param (#2170)
1 parent 4798be1 commit adee3f9

File tree

3 files changed

+6
-0
lines changed

3 files changed

+6
-0
lines changed

examples/addon.node/__test__/whisper.spec.js

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ const whisperParamsMock = {
1212
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
1313
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
1414
use_gpu: true,
15+
flash_attn: false,
1516
no_prints: true,
1617
comma_in_time: false,
1718
translate: true,

examples/addon.node/addon.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ struct whisper_params {
3939
bool no_timestamps = false;
4040
bool no_prints = false;
4141
bool use_gpu = true;
42+
bool flash_attn = false;
4243
bool comma_in_time = true;
4344

4445
std::string language = "en";
@@ -146,6 +147,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
146147

147148
struct whisper_context_params cparams = whisper_context_default_params();
148149
cparams.use_gpu = params.use_gpu;
150+
cparams.flash_attn = params.flash_attn;
149151
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
150152

151153
if (ctx == nullptr) {
@@ -326,6 +328,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
326328
std::string model = whisper_params.Get("model").As<Napi::String>();
327329
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
328330
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
331+
bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
329332
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
330333
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
331334
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
@@ -346,6 +349,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
346349
params.model = model;
347350
params.fname_inp.emplace_back(input);
348351
params.use_gpu = use_gpu;
352+
params.flash_attn = flash_attn;
349353
params.no_prints = no_prints;
350354
params.no_timestamps = no_timestamps;
351355
params.audio_ctx = audio_ctx;

examples/addon.node/index.js

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ const whisperParams = {
1212
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
1313
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
1414
use_gpu: true,
15+
flash_attn: false,
1516
no_prints: true,
1617
comma_in_time: false,
1718
translate: true,

0 commit comments

Comments
 (0)