Skip to content

Commit

Permalink
Merge pull request #5052 from ye-luo/expand-gemv-ger-tests
Browse files Browse the repository at this point in the history
Expand gemv ger tests
  • Loading branch information
prckent authored Jun 21, 2024
2 parents 2bfdcec + 890db02 commit 935919f
Show file tree
Hide file tree
Showing 9 changed files with 644 additions and 302 deletions.
134 changes: 134 additions & 0 deletions src/Platforms/CUDA/AccelBLAS_CUDA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,80 @@ inline void gemm(BLASHandle<PlatformKind::CUDA>& handle,
"cublasZgemm failed!");
}

inline void gemv(BLASHandle<PlatformKind::CUDA>& handle,
const char trans,
const int m,
const int n,
const float& alpha,
const float* const A,
const int lda,
const float* const x,
const int incx,
const float& beta,
float* const y,
const int incy)
{
cublasErrorCheck(cublasSgemv(handle.h_cublas, cuBLAS::convertOperation(trans), m, n, &alpha, A, lda, x, incx, &beta,
y, incy),
"cublasSgemv failed!");
}

inline void gemv(BLASHandle<PlatformKind::CUDA>& handle,
const char trans,
const int m,
const int n,
const double& alpha,
const double* const A,
const int lda,
const double* const x,
const int incx,
const double& beta,
double* const y,
const int incy)
{
cublasErrorCheck(cublasDgemv(handle.h_cublas, cuBLAS::convertOperation(trans), m, n, &alpha, A, lda, x, incx, &beta,
y, incy),
"cublasDgemv failed!");
}

inline void gemv(BLASHandle<PlatformKind::CUDA>& handle,
const char trans,
const int m,
const int n,
const std::complex<float>& alpha,
const std::complex<float>* A,
const int lda,
const std::complex<float>* x,
const int incx,
const std::complex<float>& beta,
std::complex<float>* y,
const int incy)
{
cublasErrorCheck(cublasCgemv(handle.h_cublas, cuBLAS::convertOperation(trans), m, n, castNativeType(&alpha),
castNativeType(A), lda, castNativeType(x), incx, castNativeType(&beta),
castNativeType(y), incy),
"cublasCgemv failed!");
}

inline void gemv(BLASHandle<PlatformKind::CUDA>& handle,
const char trans,
const int m,
const int n,
const std::complex<double>& alpha,
const std::complex<double>* A,
const int lda,
const std::complex<double>* x,
const int incx,
const std::complex<double>& beta,
std::complex<double>* y,
const int incy)
{
cublasErrorCheck(cublasZgemv(handle.h_cublas, cuBLAS::convertOperation(trans), m, n, castNativeType(&alpha),
castNativeType(A), lda, castNativeType(x), incx, castNativeType(&beta),
castNativeType(y), incy),
"cublasZgemv failed!");
}

template<typename T>
inline void gemv_batched(BLASHandle<PlatformKind::CUDA>& handle,
const char trans,
Expand All @@ -148,6 +222,66 @@ inline void gemv_batched(BLASHandle<PlatformKind::CUDA>& handle,
"cuBLAS_MFs::gemv_batched failed!");
}

inline void ger(BLASHandle<PlatformKind::CUDA>& handle,
const int m,
const int n,
const float& alpha,
const float* const x,
const int incx,
const float* const y,
const int incy,
float* const A,
const int lda)
{
cublasErrorCheck(cublasSger(handle.h_cublas, m, n, &alpha, x, incx, y, incy, A, lda), "cublasSger failed!");
}

inline void ger(BLASHandle<PlatformKind::CUDA>& handle,
const int m,
const int n,
const double& alpha,
const double* const x,
const int incx,
const double* const y,
const int incy,
double* const A,
const int lda)
{
cublasErrorCheck(cublasDger(handle.h_cublas, m, n, &alpha, x, incx, y, incy, A, lda), "cublasDger failed!");
}

inline void ger(BLASHandle<PlatformKind::CUDA>& handle,
const int m,
const int n,
const std::complex<float>& alpha,
const std::complex<float>* x,
const int incx,
const std::complex<float>* y,
const int incy,
std::complex<float>* A,
const int lda)
{
cublasErrorCheck(cublasCgeru(handle.h_cublas, m, n, castNativeType(&alpha), castNativeType(x), incx,
castNativeType(y), incy, castNativeType(A), lda),
"cublasCger failed!");
}

inline void ger(BLASHandle<PlatformKind::CUDA>& handle,
const int m,
const int n,
const std::complex<double>& alpha,
const std::complex<double>* x,
const int incx,
const std::complex<double>* y,
const int incy,
std::complex<double>* A,
const int lda)
{
cublasErrorCheck(cublasZgeru(handle.h_cublas, m, n, castNativeType(&alpha), castNativeType(x), incx,
castNativeType(y), incy, castNativeType(A), lda),
"cublasZger failed!");
}

template<typename T>
inline void ger_batched(BLASHandle<PlatformKind::CUDA>& handle,
const int m,
Expand Down
34 changes: 34 additions & 0 deletions src/Platforms/OMPTarget/AccelBLAS_OMPTarget.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ inline void gemm_batched(BLASHandle<PlatformKind::OMPTARGET>& handle,
}


template<typename T>
inline void gemv(BLASHandle<PlatformKind::OMPTARGET>& handle,
const char trans,
const int m,
const int n,
const T& alpha,
const T* const A,
const int lda,
const T* const x,
const int incx,
const T& beta,
T* const y,
const int incy)
{
if (ompBLAS::gemv(handle.h_ompblas, trans, m, n, alpha, A, lda, x, incx, beta, y, incy) != 0)
throw std::runtime_error("ompBLAS::gemv_batched failed!");
}

template<typename T>
inline void gemv_batched(BLASHandle<PlatformKind::OMPTARGET>& handle,
const char trans,
Expand All @@ -92,6 +110,22 @@ inline void gemv_batched(BLASHandle<PlatformKind::OMPTARGET>& handle,
throw std::runtime_error("ompBLAS::gemv_batched failed!");
}

template<typename T>
inline void ger(BLASHandle<PlatformKind::OMPTARGET>& handle,
const int m,
const int n,
const T& alpha,
const T* const x,
const int incx,
const T* const y,
const int incy,
T* const A,
const int lda)
{
if (ompBLAS::ger(handle.h_ompblas, m, n, alpha, x, incx, y, incy, A, lda) != 0)
throw std::runtime_error("ompBLAS::ger_batched failed!");
}

template<typename T>
inline void ger_batched(BLASHandle<PlatformKind::OMPTARGET>& handle,
const int m,
Expand Down
8 changes: 8 additions & 0 deletions src/Platforms/ROCm/cuda2hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,29 @@
#define cublasGetStream hipblasGetStream
#define cublasOperation_t hipblasOperation_t
#define cublasCgeam hipblasCgeam
#define cublasCgemv hipblasCgemv
#define cublasCgeru hipblasCgeru
#define cublasCgemm hipblasCgemm
#define cublasCgemmBatched hipblasCgemmBatched
#define cublasCgetrfBatched hipblasCgetrfBatched_
#define cublasCgetriBatched hipblasCgetriBatched_
#define cublasDgeam hipblasDgeam
#define cublasDgemv hipblasDgemv
#define cublasDger hipblasDger
#define cublasDgemm hipblasDgemm
#define cublasDgemmBatched hipblasDgemmBatched
#define cublasDgetrfBatched hipblasDgetrfBatched_
#define cublasDgetriBatched hipblasDgetriBatched_
#define cublasSgeam hipblasSgeam
#define cublasSgemv hipblasSgemv
#define cublasSger hipblasSger
#define cublasSgemm hipblasSgemm
#define cublasSgemmBatched hipblasSgemmBatched
#define cublasSgetrfBatched hipblasSgetrfBatched_
#define cublasSgetriBatched hipblasSgetriBatched_
#define cublasZgeam hipblasZgeam
#define cublasZgemv hipblasZgemv
#define cublasZgeru hipblasZgeru
#define cublasZgemm hipblasZgemm
#define cublasZgemmBatched hipblasZgemmBatched
#define cublasZgetrfBatched hipblasZgetrfBatched_
Expand Down
56 changes: 52 additions & 4 deletions src/Platforms/SYCL/AccelBLAS_SYCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,31 @@ inline void gemm(BLASHandle<PlatformKind::SYCL>& handle,
}
}

template<typename T>
inline void gemv(BLASHandle<PlatformKind::SYCL>& handle,
const char trans,
const int m,
const int n,
const T& alpha,
const T* const A,
const int lda,
const T* const x,
const int incx,
const T& beta,
T* const y,
const int incy)
{
try
{
oneapi::mkl::blas::gemv(handle.queue_, syclBLAS::convertTransEnum(trans), m, n, alpha, A, lda, x, incx, beta, y,
incy);
}
catch (oneapi::mkl::exception& e)
{
throw std::runtime_error(std::string("AccelBLAS::gemv exception: ") + e.what());
}
}

template<typename T>
inline void gemv_batched(BLASHandle<PlatformKind::SYCL>& handle,
const char trans,
Expand All @@ -81,6 +106,28 @@ inline void gemv_batched(BLASHandle<PlatformKind::SYCL>& handle,
}
}

template<typename T>
inline void ger(BLASHandle<PlatformKind::SYCL>& handle,
const int m,
const int n,
const T& alpha,
const T* const x,
const int incx,
const T* const y,
const int incy,
T* const A,
const int lda)
{
try
{
oneapi::mkl::blas::ger(handle.queue_, m, n, alpha, x, incx, y, incy, A, lda);
}
catch (oneapi::mkl::exception& e)
{
throw std::runtime_error(std::string("AccelBLAS::ger exception: ") + e.what());
}
}

template<typename T>
inline void ger_batched(BLASHandle<PlatformKind::SYCL>& handle,
const int m,
Expand Down Expand Up @@ -116,7 +163,8 @@ inline void copy_batched(BLASHandle<PlatformKind::SYCL>& handle,
try
{
syclBLAS::syclBLAS_int bc = batch_count;
oneapi::mkl::blas::copy_batch(handle.queue_, &n, const_cast<const T**>(in), &incx, const_cast<T**>(out), &incy, 1, &bc);
oneapi::mkl::blas::copy_batch(handle.queue_, &n, const_cast<const T**>(in), &incx, const_cast<T**>(out), &incy, 1,
&bc);
}
catch (oneapi::mkl::exception& e)
{
Expand Down Expand Up @@ -154,9 +202,9 @@ inline void gemm_batched(BLASHandle<PlatformKind::SYCL>& handle,
oneapi::mkl::blas::gemm_batch(handle.queue_, sycl::span{&trans_a, 1}, sycl::span{&trans_b, 1}, sycl::span{&m, 1},
sycl::span{&n, 1}, sycl::span{&k, 1}, alpha_span,
sycl::span{const_cast<const T**>(A), batch_count}, sycl::span{&lda, 1},
sycl::span{const_cast<const T**>(B), batch_count}, sycl::span{&ldb, 1},
beta_span, sycl::span{const_cast<T**>(C), batch_count},
sycl::span{&ldc, 1}, 1, sycl::span{const_cast<size_t*>(&batch_count), 1});
sycl::span{const_cast<const T**>(B), batch_count}, sycl::span{&ldb, 1}, beta_span,
sycl::span{const_cast<T**>(C), batch_count}, sycl::span{&ldc, 1}, 1,
sycl::span{const_cast<size_t*>(&batch_count), 1});
sycl::free(alpha_span.data(), handle.queue_);
sycl::free(beta_span.data(), handle.queue_);
#else
Expand Down
2 changes: 1 addition & 1 deletion src/Platforms/tests/OMPTarget/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
set(UTEST_EXE test_omptarget)
set(UTEST_NAME deterministic-unit_${UTEST_EXE})

add_executable(${UTEST_EXE} test_vector.cpp test_math.cpp test_deep_copy.cpp test_class_member.cpp test_runtime_mem.cpp)
add_executable(${UTEST_EXE} test_math.cpp test_deep_copy.cpp test_class_member.cpp test_runtime_mem.cpp)
target_link_libraries(${UTEST_EXE} platform_runtime catch_main)

add_unit_test(${UTEST_NAME} 1 1 $<TARGET_FILE:${UTEST_EXE}>)
Expand Down
24 changes: 12 additions & 12 deletions src/Platforms/tests/OMPTarget/test_ompBLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ void test_gemm(const int M, const int N, const int K, const char transa, const c
// transa/transb == 'T'/'T': C[N,M] = A[M,K] * B[K,N]; C = B^t * A^t

// alpha 0.5, beta 0
ompBLAS::gemm(handle, transa, transb, M, N, K, alpha_half, A.device_data(), a1, B.device_data(), b1, beta, C.device_data(),
M);
ompBLAS::gemm(handle, transa, transb, M, N, K, alpha_half, A.device_data(), a1, B.device_data(), b1, beta,
C.device_data(), M);
// alpha 0.5, beta 1
ompBLAS::gemm(handle, transa, transb, M, N, K, alpha_half, A.device_data(), a1, B.device_data(), b1, beta1, C.device_data(),
M);
ompBLAS::gemm(handle, transa, transb, M, N, K, alpha_half, A.device_data(), a1, B.device_data(), b1, beta1,
C.device_data(), M);
C.updateFrom();

BLAS::gemm(transa, transb, M, N, K, alpha, A.data(), a1, B.data(), b1, beta, D.data(), M);
Expand Down Expand Up @@ -118,11 +118,11 @@ void test_gemm(const int M, const int N, const int K, const char transa, const c
Carr.updateTo();

// alpha 0.5, beta 0
ompBLAS::gemm_batched(handle, transa, transb, M, N, K, alpha_half, Aarr.device_data(), a1, Barr.device_data(), b1, beta,
Carr.device_data(), M, 2);
ompBLAS::gemm_batched(handle, transa, transb, M, N, K, alpha_half, Aarr.device_data(), a1, Barr.device_data(), b1,
beta, Carr.device_data(), M, 2);
// alpha 0.5, beta 1
ompBLAS::gemm_batched(handle, transa, transb, M, N, K, alpha_half, Aarr.device_data(), a1, Barr.device_data(), b1, beta1,
Carr.device_data(), M, 2);
ompBLAS::gemm_batched(handle, transa, transb, M, N, K, alpha_half, Aarr.device_data(), a1, Barr.device_data(), b1,
beta1, Carr.device_data(), M, 2);
C.updateFrom();
C2.updateFrom();

Expand All @@ -143,7 +143,7 @@ void test_gemm(const int M, const int N, const int K, const char transa, const c
}
}

TEST_CASE("OmpBLAS gemm", "[OMP]")
TEST_CASE("ompBLAS gemm", "[OMP]")
{
const int M = 37;
const int N = 71;
Expand Down Expand Up @@ -341,7 +341,7 @@ void test_gemv_batched(const int M_b, const int N_b, const char trans, const int
}
}

TEST_CASE("OmpBLAS gemv", "[OMP]")
TEST_CASE("ompBLAS gemv", "[OMP]")
{
const int M = 137;
const int N = 79;
Expand All @@ -365,7 +365,7 @@ TEST_CASE("OmpBLAS gemv", "[OMP]")
#endif
}

TEST_CASE("OmpBLAS gemv notrans", "[OMP]")
TEST_CASE("ompBLAS gemv notrans", "[OMP]")
{
const int M = 137;
const int N = 79;
Expand Down Expand Up @@ -528,7 +528,7 @@ void test_ger_batched(const int M, const int N, const int batch_count)
}
}

TEST_CASE("OmpBLAS ger", "[OMP]")
TEST_CASE("ompBLAS ger", "[OMP]")
{
const int M = 137;
const int N = 79;
Expand Down
Loading

0 comments on commit 935919f

Please sign in to comment.