Skip to content

Commit

Permalink
Merge pull request #1 from mgates3/gemqrt
Browse files Browse the repository at this point in the history
gemqrt
  • Loading branch information
TeachRaccooon authored Aug 25, 2023
2 parents a692dba + 3ece4b5 commit dc2b5e3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
4 changes: 4 additions & 0 deletions test/lapacke_wrappers.hh
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,8 @@ inline lapack_int LAPACKE_gemqrt(
float* T, lapack_int ldt,
float* C, lapack_int ldc )
{
if (trans == 'C')
trans = 'T';
return LAPACKE_sgemqrt(
LAPACK_COL_MAJOR, side, trans, m, n, k, nb,
V, ldv,
Expand All @@ -1184,6 +1186,8 @@ inline lapack_int LAPACKE_gemqrt(
double* T, lapack_int ldt,
double* C, lapack_int ldc )
{
if (trans == 'C')
trans = 'T';
return LAPACKE_dgemqrt(
LAPACK_COL_MAJOR, side, trans, m, n, k, nb,
V, ldv,
Expand Down
3 changes: 2 additions & 1 deletion test/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,8 @@ def filter_csv( values, csv ):
[ 'orhr_col', gen + dtype_real + align + n + tall ],
[ 'unhr_col', gen + dtype + align + n + tall ],

[ 'gemqrt', gen + dtype + align + n + nb + side + trans ],
[ 'gemqrt', gen + dtype_real + align + n + nb + side + trans ], # real does trans = N, T, C
[ 'gemqrt', gen + dtype_complex + align + n + nb + side + trans_nc ], # complex does trans = N, C, not T

# Triangle-pentagon
[ 'tpqrt', gen + dtype + align + mn + l + nb ],
Expand Down
20 changes: 13 additions & 7 deletions test/test_gemqrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ template< typename scalar_t >
void test_gemqrt_work( Params& params, bool run )
{
using real_t = blas::real_type< scalar_t >;
typedef long long lld;
using blas::min;
using blas::max;

// get & mark input values
lapack::Side side = params.side();
Expand All @@ -30,6 +31,12 @@ void test_gemqrt_work( Params& params, bool run )
int64_t nb = params.nb();
int64_t align = params.align();

// geqrt requires min( m, n ) >= nb; use nb = 0.5 min( m, n ).
if (nb > min( m, n )) {
nb = max( 1, min( m, n ) / 2 );
params.nb() = nb;
}

// mark non-standard output values
params.ref_time();
params.ref_gflops();
Expand All @@ -42,7 +49,8 @@ void test_gemqrt_work( Params& params, bool run )
int64_t ldv;
if (side == lapack::Side::Right) {
ldv = roundup( blas::max( 1, m ), align );
} else {
}
else {
ldv = roundup( blas::max( 1, n ), align );
}
int64_t ldt = roundup( nb, align );
Expand All @@ -65,15 +73,14 @@ void test_gemqrt_work( Params& params, bool run )

// Calling this to set up the matrices
lapack::geqrt( m, n, nb, &V[0], ldv, &T[0], ldt );
lapack::gemqrt( side, trans, m, n, k, nb, &V[0], ldv, &T[0], ldt, &C_tst[0], ldc );
/*

//---------- run test
testsweeper::flush_cache( params.cache() );
double time = testsweeper::get_wtime();
int64_t info_tst = lapack::gemqrt( side, trans, m, n, k, nb, &V[0], ldv, &T[0], ldt, &C_tst[0], ldc );
time = testsweeper::get_wtime() - time;
if (info_tst != 0) {
fprintf( stderr, "lapack::gemqrt returned error %lld\n", (lld) info_tst );
fprintf( stderr, "lapack::gemqrt returned error %lld\n", llong( info_tst ) );
}

params.time() = time;
Expand All @@ -87,7 +94,7 @@ void test_gemqrt_work( Params& params, bool run )
int64_t info_ref = LAPACKE_gemqrt( side2char(side), op2char(trans), m, n, k, nb, &V[0], ldv, &T[0], ldt, &C_ref[0], ldc );
time = testsweeper::get_wtime() - time;
if (info_ref != 0) {
fprintf( stderr, "LAPACKE_gemqrt returned error %lld\n", (lld) info_ref );
fprintf( stderr, "LAPACKE_gemqrt returned error %lld\n", llong( info_ref ) );
}

params.ref_time() = time;
Expand All @@ -102,7 +109,6 @@ void test_gemqrt_work( Params& params, bool run )
params.error() = error;
params.okay() = (error == 0); // expect lapackpp == lapacke
}
*/
}

#endif // LAPACK >= 3.4.0
Expand Down

0 comments on commit dc2b5e3

Please sign in to comment.