Skip to content

Commit

Permalink
optimize gemv forwarding on ARM64 systems
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Daley committed Oct 25, 2024
1 parent 72461f1 commit cb48505
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,6 @@ In chronological order:

* Dirreke <https://github.com/Dirreke>
* [2024-01-16] Add basic support for the CSKY architecture

* Christopher Daley <https://github.com/cdaley>
* [2024-01-24] Optimize GEMV forwarding on ARM64 systems
24 changes: 20 additions & 4 deletions interface/gemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include "common.h"
#ifdef FUNCTION_PROFILE
#include "functable.h"
Expand Down Expand Up @@ -499,6 +500,15 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
#endif

#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
#if defined(ARCH_ARM64)
// The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
// perform poorly in certain circumstances. We use the following boolean
// variable along with the gemv argument values to avoid these inefficient
// gemv cases; see GitHub issue #4951.
bool have_tuned_gemv = false;
#else
bool have_tuned_gemv = true;
#endif
// Check if we can convert GEMM -> GEMV
if (args.k != 0) {
if (args.n == 1) {
Expand All @@ -518,8 +528,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
if (transb & 1) {
inc_x = args.ldb;
}
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
return;
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1));
if (is_efficient_gemv) {
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
return;
}
}
if (args.m == 1) {
blasint inc_x = args.lda;
Expand All @@ -538,8 +551,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
m = args.n;
n = args.k;
}
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
return;
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1));
if (is_efficient_gemv) {
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
return;
}
}
}
#endif
Expand Down

0 comments on commit cb48505

Please sign in to comment.