Skip to content

Commit

Permalink
ocl: Fixed OPENCL_LIBSMM_VALIDATE_SMM
Browse files Browse the repository at this point in the history
* Read back the parameter stack from the device (rather than using host_param_stack).
* Cleanup.
  • Loading branch information
hfp committed Oct 11, 2023
1 parent 3da1815 commit 81c6213
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/acc/opencl/smm/opencl_libsmm.c
Original file line number Diff line number Diff line change
Expand Up @@ -919,9 +919,9 @@ int libsmm_acc_transpose(const int* dev_trs_stack, int offset, int stack_size, v
imat = (char*)LIBXSMM_UP2((uintptr_t)stack + sizeof(int) * offset_stack_size, LIBXSMM_ALIGNMENT);
omat = (char*)LIBXSMM_UP2((uintptr_t)imat + data_size, LIBXSMM_ALIGNMENT);
gold = (char*)LIBXSMM_UP2((uintptr_t)omat + data_size, LIBXSMM_ALIGNMENT);
ACC_OPENCL_CHECK(
c_dbcsr_acc_memcpy_d2h(dev_trs_stack, stack, sizeof(int) * offset_stack_size, stream), "transfer debug stack", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_data, imat, data_size, stream), "transfer debug input", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_trs_stack, stack, sizeof(int) * offset_stack_size, stream),
"transfer validation stack", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_data, imat, data_size, stream), "transfer validation input", result);
}
else result = EXIT_FAILURE;
}
Expand Down Expand Up @@ -984,7 +984,7 @@ int libsmm_acc_transpose(const int* dev_trs_stack, int offset, int stack_size, v
LIBXSMM_ATOMIC_RELEASE(lock, LIBXSMM_ATOMIC_RELAXED);
}
# if defined(OPENCL_LIBSMM_VALIDATE_TRANS)
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_data, omat, data_size, stream), "transfer debug test", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_data, omat, data_size, stream), "transfer validation test", result);
# endif
# if defined(OPENCL_LIBSMM_VALIDATE_TRANS)
ACC_OPENCL_CHECK(c_dbcsr_acc_stream_sync(stream), "sync stream", result);
Expand Down Expand Up @@ -1653,33 +1653,38 @@ int libsmm_acc_process(const int* host_param_stack, const int* dev_param_stack,
size_t work_size;
# if defined(OPENCL_LIBSMM_VALIDATE_SMM)
/* validate result (implies readback from device and performance penalty) */
int* pinp = NULL;
char *ainp = NULL, *binp = NULL, *test = NULL, *gold = NULL, *btrn = NULL;
const libxsmm_datatype precision =
(dbcsr_type_real_8 == datatype ? LIBXSMM_DATATYPE_F64
: (dbcsr_type_real_4 == datatype ? LIBXSMM_DATATYPE_F32 : LIBXSMM_DATATYPE_UNSUPPORTED));
const int typesize = OPENCL_LIBSMM_TYPESIZE(datatype);
libxsmm_xmmfunction kernel_cpu = {NULL};
size_t asize, bsize, csize;
size_t psize, asize, bsize, csize;
void* scratch = NULL;
if (CL_SUCCESS == clGetMemObjectInfo(*ACC_OPENCL_MEM(dev_a_data), CL_MEM_SIZE, sizeof(size_t), &asize, NULL) &&
if (CL_SUCCESS == clGetMemObjectInfo(*ACC_OPENCL_MEM(dev_param_stack), CL_MEM_SIZE, sizeof(size_t), &psize, NULL) &&
CL_SUCCESS == clGetMemObjectInfo(*ACC_OPENCL_MEM(dev_a_data), CL_MEM_SIZE, sizeof(size_t), &asize, NULL) &&
CL_SUCCESS == clGetMemObjectInfo(*ACC_OPENCL_MEM(dev_b_data), CL_MEM_SIZE, sizeof(size_t), &bsize, NULL) &&
CL_SUCCESS == clGetMemObjectInfo(*ACC_OPENCL_MEM(dev_c_data), CL_MEM_SIZE, sizeof(size_t), &csize, NULL))
{
libxsmm_descriptor_blob blob;
libxsmm_gemm_descriptor* const desc = OPENCL_LIBSMM_DESCINIT(
&blob, precision, m_max, n_max, k_max, m_max, k_max, m_max, LIBXSMM_GEMM_FLAG_NONE, LIBXSMM_PREFETCH_NONE);
const size_t scratch_size = asize + bsize + csize + csize + k_max * n_max * typesize +
4 * (LIBXSMM_ALIGNMENT - 1) /*alignments*/;
scratch = libxsmm_aligned_scratch(scratch_size, LIBXSMM_ALIGNMENT);
const size_t scratch_size = psize + asize + bsize + csize + csize + k_max * n_max * typesize +
5 * (LIBXSMM_ALIGNMENT - 1) /*alignments*/;
scratch = libxsmm_aligned_malloc(scratch_size, LIBXSMM_ALIGNMENT);
if (NULL != desc && NULL != scratch) {
ainp = (char*)scratch;
pinp = (int*)scratch;
ainp = (char*)LIBXSMM_UP2((uintptr_t)pinp + psize, LIBXSMM_ALIGNMENT);
binp = (char*)LIBXSMM_UP2((uintptr_t)ainp + asize, LIBXSMM_ALIGNMENT);
test = (char*)LIBXSMM_UP2((uintptr_t)binp + bsize, LIBXSMM_ALIGNMENT);
gold = (char*)LIBXSMM_UP2((uintptr_t)test + csize, LIBXSMM_ALIGNMENT);
btrn = (char*)LIBXSMM_UP2((uintptr_t)gold + csize, LIBXSMM_ALIGNMENT);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_a_data, ainp, asize, stream), "transfer debug a-data", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_b_data, binp, bsize, stream), "transfer debug b-data", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_c_data, gold, csize, stream), "transfer debug c-data", result);
ACC_OPENCL_CHECK(
c_dbcsr_acc_memcpy_d2h(dev_param_stack, pinp, psize, stream), "transfer validation param-data", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_a_data, ainp, asize, stream), "transfer validation a-data", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_b_data, binp, bsize, stream), "transfer validation b-data", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_c_data, gold, csize, stream), "transfer validation c-data", result);
kernel_cpu = libxsmm_xmmdispatch(desc);
assert(NULL != kernel_cpu.xmm);
}
Expand Down Expand Up @@ -1754,12 +1759,12 @@ int libsmm_acc_process(const int* host_param_stack, const int* dev_param_stack,
}
# endif
# if defined(OPENCL_LIBSMM_VALIDATE_SMM)
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_c_data, test, csize, stream), "transfer debug test", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_memcpy_d2h(dev_c_data, test, csize, stream), "transfer validation test", result);
ACC_OPENCL_CHECK(c_dbcsr_acc_stream_sync(stream), "sync stream", result);
if (EXIT_SUCCESS == result) {
const char* const env_tol = getenv("OPENCL_LIBSMM_SMM_TOLERANCE");
const double tolerance = ((NULL == env_tol || '\0' == *env_tol) ? 1E-3 : atof(env_tol));
const int* const params = host_param_stack + (4 <= nparams ? (nparams - 4) : 0);
const int* const params = pinp + (4 <= nparams ? (nparams - 4) : 0);
size_t i;
LIBXSMM_STDIO_ACQUIRE();
if (0 != c_dbcsr_acc_opencl_config.verbosity) {
Expand Down

0 comments on commit 81c6213

Please sign in to comment.