diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index d885a01b96..a6d25b50bd 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -226,3 +226,6 @@ In chronological order: * Dirreke * [2024-01-16] Add basic support for the CSKY architecture + +* Christopher Daley + * [2024-01-24] Optimize GEMV forwarding on ARM64 systems diff --git a/interface/gemm.c b/interface/gemm.c index 5742d36c4b..576e94593c 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -39,6 +39,7 @@ #include <stdio.h> #include <stdlib.h> +#include <stdbool.h> #include "common.h" #ifdef FUNCTION_PROFILE #include "functable.h" @@ -499,6 +500,15 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS #endif #if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16)) +#if defined(ARCH_ARM64) + // The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c} + // perform poorly in certain circumstances. We use the following boolean + // variable along with the gemv argument values to avoid these inefficient + // gemv cases, see github issue#4951. 
+ bool have_tuned_gemv = false; +#else + bool have_tuned_gemv = true; +#endif // Check if we can convert GEMM -> GEMV if (args.k != 0) { if (args.n == 1) { @@ -518,8 +528,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS if (transb & 1) { inc_x = args.ldb; } - GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y); - return; + bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1)); + if (is_efficient_gemv) { + GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y); + return; + } } if (args.m == 1) { blasint inc_x = args.lda; @@ -538,8 +551,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS m = args.n; n = args.k; } - GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y); - return; + bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1)); + if (is_efficient_gemv) { + GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y); + return; + } } } #endif