Skip to content

Commit

Permalink
whisper : use flash attention (#2152)
Browse files (browse the repository at this point in the history)
* whisper : use flash attention in the encoder

* whisper : add kv_pad

* whisper : remove extra backend instance (huh?)

* whisper : use FA for cross-attention

* whisper : use FA for self-attention

* whisper : simplify encoder FA

* whisper : add flash_attn runtime parameter

* scripts : add bench log

* scripts : add M1 Pro bench log
  • Loading branch information
ggerganov authored May 15, 2024
1 parent 9d5771a commit 7094ea5
Show file tree
Hide file tree
Showing 13 changed files with 658 additions and 173 deletions.
17 changes: 11 additions & 6 deletions examples/bench/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ struct whisper_params {

std::string model = "models/ggml-base.en.bin";

bool use_gpu = true;
bool use_gpu = true;
bool flash_attn = false;
};

void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
Expand All @@ -25,10 +26,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
Expand All @@ -49,6 +51,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " %-7s 0 - whisper\n", "");
fprintf(stderr, " %-7s 1 - memcpy\n", "");
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
Expand All @@ -59,7 +62,9 @@ int whisper_bench_full(const whisper_params & params) {
// whisper init

struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

Expand Down
7 changes: 6 additions & 1 deletion examples/command/command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct whisper_params {
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
bool flash_attn = false;

std::string language = "en";
std::string model = "models/ggml-base.en.bin";
Expand Down Expand Up @@ -80,6 +81,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
Expand Down Expand Up @@ -118,6 +120,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
Expand Down Expand Up @@ -696,7 +699,9 @@ int main(int argc, char ** argv) {
// whisper init

struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

Expand Down
8 changes: 7 additions & 1 deletion examples/lsp/lsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct whisper_params {
bool print_special = false;
bool print_energy = false;
bool use_gpu = true;
bool flash_attn = false;

std::string language = "en";
std::string model = "models/ggml-base.en.bin";
Expand Down Expand Up @@ -74,6 +75,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else {
Expand Down Expand Up @@ -105,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, "\n");
Expand Down Expand Up @@ -436,7 +439,10 @@ int main(int argc, char ** argv) {

// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
// init audio

Expand Down
9 changes: 7 additions & 2 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct whisper_params {
bool no_timestamps = false;
bool log_score = false;
bool use_gpu = true;
bool flash_attn = false;

std::string language = "en";
std::string prompt;
Expand Down Expand Up @@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
Expand Down Expand Up @@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
Expand Down Expand Up @@ -977,7 +980,9 @@ int main(int argc, char ** argv) {
// whisper init

struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

if (!params.dtw.empty()) {
cparams.dtw_token_timestamps = true;
Expand Down
7 changes: 6 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ struct whisper_params {
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;
bool flash_attn = false;

std::string language = "en";
std::string prompt = "";
Expand Down Expand Up @@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
Expand Down Expand Up @@ -502,7 +504,10 @@ int main(int argc, char ** argv) {
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

if (!params.dtw.empty()) {
cparams.dtw_token_timestamps = true;
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
Expand Down
7 changes: 6 additions & 1 deletion examples/stream/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct whisper_params {
bool tinydiarize = false;
bool save_audio = false; // save audio to wav file
bool use_gpu = true;
bool flash_attn = false;

std::string language = "en";
std::string model = "models/ggml-base.en.bin";
Expand Down Expand Up @@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }

else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
Expand Down Expand Up @@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false");
fprintf(stderr, "\n");
}

Expand Down Expand Up @@ -153,7 +156,9 @@ int main(int argc, char ** argv) {
}

struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

Expand Down
9 changes: 7 additions & 2 deletions examples/talk-llama/talk-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ struct whisper_params {
bool no_timestamps = true;
bool verbose_prompt = false;
bool use_gpu = true;
bool flash_attn = false;

std::string person = "Georgi";
std::string bot_name = "LLaMA";
Expand Down Expand Up @@ -105,6 +106,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
else if (arg == "--session") { params.path_session = argv[++i]; }
Expand All @@ -123,7 +125,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
}
}
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
Expand Down Expand Up @@ -154,6 +155,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
Expand Down Expand Up @@ -285,7 +287,9 @@ int main(int argc, char ** argv) {
// whisper init

struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
if (!ctx_wsp) {
Expand Down Expand Up @@ -316,6 +320,7 @@ int main(int argc, char ** argv) {
lcparams.n_ctx = 2048;
lcparams.seed = 1;
lcparams.n_threads = params.n_threads;
lcparams.flash_attn = params.flash_attn;

struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);

Expand Down
7 changes: 6 additions & 1 deletion examples/talk/talk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct whisper_params {
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
bool flash_attn = false;

std::string person = "Santa";
std::string language = "en";
Expand Down Expand Up @@ -64,6 +65,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
Expand Down Expand Up @@ -99,6 +101,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
Expand Down Expand Up @@ -188,7 +191,9 @@ int main(int argc, char ** argv) {

// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);

Expand Down
Loading

0 comments on commit 7094ea5

Please sign in to comment.