From be5919c269c438d12484f3fbd60d3c2146d65a51 Mon Sep 17 00:00:00 2001 From: yasahi-hpc <57478230+yasahi-hpc@users.noreply.github.com> Date: Tue, 10 Sep 2024 13:51:20 +0200 Subject: [PATCH] Apply check functions to fft functions (#130) * Improve assertions in fft functions * format * use is_complex_v * fix: typo * using string_view and remove maybe_unused from assertion helper --------- Co-authored-by: Yuuichi Asahi --- common/src/KokkosFFT_asserts.hpp | 54 ++++++++++++++++++++++++++ common/src/KokkosFFT_utils.hpp | 37 +----------------- fft/src/KokkosFFT_Cuda_plans.hpp | 24 +++++++----- fft/src/KokkosFFT_Cuda_transform.hpp | 21 ++++------ fft/src/KokkosFFT_HIP_plans.hpp | 31 +++++++-------- fft/src/KokkosFFT_HIP_transform.hpp | 21 ++++------ fft/src/KokkosFFT_Plans.hpp | 58 ++++++++++++++-------------- fft/src/KokkosFFT_ROCM_plans.hpp | 35 +++++++++-------- fft/src/KokkosFFT_ROCM_transform.hpp | 27 ++++++------- fft/src/KokkosFFT_Transform.hpp | 40 +++++++++++++------ fft/unit_test/Test_Transform.cpp | 4 +- 11 files changed, 189 insertions(+), 163 deletions(-) create mode 100644 common/src/KokkosFFT_asserts.hpp diff --git a/common/src/KokkosFFT_asserts.hpp b/common/src/KokkosFFT_asserts.hpp new file mode 100644 index 00000000..3328e4eb --- /dev/null +++ b/common/src/KokkosFFT_asserts.hpp @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file +// +// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception + +#ifndef KOKKOSFFT_ASSERTS_HPP +#define KOKKOSFFT_ASSERTS_HPP + +#include +#include +#include + +#if defined(__cpp_lib_source_location) && __cpp_lib_source_location >= 201907L +#include +#define KOKKOSFFT_EXPECTS(expression, msg) \ + KokkosFFT::Impl::check_precondition( \ + (expression), msg, std::source_location::current().file_name(), \ + std::source_location::current().line(), \ + std::source_location::current().function_name(), \ + std::source_location::current().column()) +#else +#include +#define KOKKOSFFT_EXPECTS(expression, msg) \ + KokkosFFT::Impl::check_precondition((expression), msg, __FILE__, __LINE__, \ + __FUNCTION__) +#endif + +namespace KokkosFFT { +namespace Impl { + +inline void check_precondition(const bool expression, + const std::string_view& msg, + const char* file_name, int line, + const char* function_name, + const int column = -1) { + // Quick return if possible + if (expression) return; + + std::stringstream ss("file: "); + if (column == -1) { + // For C++ 17 + ss << file_name << '(' << line << ") `" << function_name << "`: " << msg + << '\n'; + } else { + // For C++ 20 and later + ss << file_name << '(' << line << ':' << column << ") `" << function_name + << "`: " << msg << '\n'; + } + throw std::runtime_error(ss.str()); +} + +} // namespace Impl +} // namespace KokkosFFT + +#endif diff --git a/common/src/KokkosFFT_utils.hpp b/common/src/KokkosFFT_utils.hpp index a0fae657..5e56a1ae 100644 --- a/common/src/KokkosFFT_utils.hpp +++ b/common/src/KokkosFFT_utils.hpp @@ -10,48 +10,13 @@ #include #include #include +#include "KokkosFFT_asserts.hpp" #include "KokkosFFT_traits.hpp" #include "KokkosFFT_common_types.hpp" -#if defined(__cpp_lib_source_location) && __cpp_lib_source_location >= 201907L -#include -#define KOKKOSFFT_EXPECTS(expression, msg) \ - KokkosFFT::Impl::check_precondition( \ - (expression), msg, std::source_location::current().file_name(), \ - std::source_location::current().line(), \ - std::source_location::current().function_name(), \ - std::source_location::current().column()) -#else -#include -#define KOKKOSFFT_EXPECTS(expression, msg) \ - KokkosFFT::Impl::check_precondition((expression), msg, __FILE__, __LINE__, \ - __FUNCTION__) -#endif - namespace KokkosFFT { namespace Impl { -inline void check_precondition(const bool expression, - [[maybe_unused]] const std::string& msg, - [[maybe_unused]] const char* file_name, int line, - [[maybe_unused]] const char* function_name, - [[maybe_unused]] const int column = -1) { - // Quick return if possible - if (expression) return; - - std::stringstream ss("file: "); - if (column == -1) { - // For C++ 17 - ss << file_name << '(' << line << ") `" << function_name << "`: " << msg - << '\n'; - } else { - // For C++ 20 and later - ss << file_name << '(' << line << ':' << column << ") `" << function_name - << "`: " << msg << '\n'; - } - throw std::runtime_error(ss.str()); -} - template auto convert_negative_axis(ViewType, int _axis = -1) { static_assert(Kokkos::is_view_v, diff --git a/fft/src/KokkosFFT_Cuda_plans.hpp b/fft/src/KokkosFFT_Cuda_plans.hpp index 935be540..5d9531b1 100644 --- a/fft/src/KokkosFFT_Cuda_plans.hpp +++ b/fft/src/KokkosFFT_Cuda_plans.hpp @@ -8,6 +8,7 @@ #include #include "KokkosFFT_Cuda_types.hpp" #include "KokkosFFT_layouts.hpp" +#include "KokkosFFT_asserts.hpp" namespace KokkosFFT { namespace Impl { @@ -30,7 +31,7 @@ auto create_plan(const ExecutionSpace& exec_space, plan = std::make_unique(); cufftResult cufft_rt = cufftCreate(&(*plan)); - if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed"); cudaStream_t stream = exec_space.cuda_stream(); cufftSetStream((*plan), stream); @@ -44,7 +45,8 @@ auto create_plan(const ExecutionSpace& exec_space, std::multiplies<>()); cufft_rt = cufftPlan1d(&(*plan), nx, type, howmany); - if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan1d failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan1d failed"); + return fft_size; } @@ -67,7 +69,7 @@ auto create_plan(const ExecutionSpace& exec_space, plan = std::make_unique(); cufftResult cufft_rt = cufftCreate(&(*plan)); - if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed"); cudaStream_t stream = exec_space.cuda_stream(); cufftSetStream((*plan), stream); @@ -81,7 +83,8 @@ auto create_plan(const ExecutionSpace& exec_space, std::multiplies<>()); cufft_rt = cufftPlan2d(&(*plan), nx, ny, type); - if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan2d failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan2d failed"); + return fft_size; } @@ -104,7 +107,7 @@ auto create_plan(const ExecutionSpace& exec_space, plan = std::make_unique(); cufftResult cufft_rt = cufftCreate(&(*plan)); - if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed"); cudaStream_t stream = exec_space.cuda_stream(); cufftSetStream((*plan), stream); @@ -120,7 +123,8 @@ auto create_plan(const ExecutionSpace& exec_space, std::multiplies<>()); cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type); - if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan3d failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan3d failed"); + return fft_size; } @@ -163,7 +167,7 @@ auto create_plan(const ExecutionSpace& exec_space, plan = std::make_unique(); cufftResult cufft_rt = cufftCreate(&(*plan)); - if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed"); cudaStream_t stream = exec_space.cuda_stream(); cufftSetStream((*plan), stream); @@ -171,8 +175,8 @@ auto create_plan(const ExecutionSpace& exec_space, cufft_rt = cufftPlanMany(&(*plan), rank, fft_extents.data(), in_extents.data(), istride, idist, out_extents.data(), ostride, odist, type, howmany); - if (cufft_rt != CUFFT_SUCCESS) - throw std::runtime_error("cufftPlanMany failed"); + + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlanMany failed"); return fft_size; } @@ -186,4 +190,4 @@ void destroy_plan_and_info(std::unique_ptr& plan, InfoType&) { } // namespace Impl } // namespace KokkosFFT -#endif \ No newline at end of file +#endif diff --git a/fft/src/KokkosFFT_Cuda_transform.hpp b/fft/src/KokkosFFT_Cuda_transform.hpp index 8878039c..18fd818f 100644 --- a/fft/src/KokkosFFT_Cuda_transform.hpp +++ b/fft/src/KokkosFFT_Cuda_transform.hpp @@ -6,6 +6,7 @@ #define KOKKOSFFT_CUDA_TRANSFORM_HPP #include +#include "KokkosFFT_asserts.hpp" namespace KokkosFFT { namespace Impl { @@ -13,50 +14,44 @@ template inline void exec_plan(cufftHandle& plan, cufftReal* idata, cufftComplex* odata, int /*direction*/, Args...) { cufftResult cufft_rt = cufftExecR2C(plan, idata, odata); - if (cufft_rt != CUFFT_SUCCESS) - throw std::runtime_error("cufftExecR2C failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecR2C failed"); } template inline void exec_plan(cufftHandle& plan, cufftDoubleReal* idata, cufftDoubleComplex* odata, int /*direction*/, Args...) { cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata); - if (cufft_rt != CUFFT_SUCCESS) - throw std::runtime_error("cufftExecD2Z failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecD2Z failed"); } template inline void exec_plan(cufftHandle& plan, cufftComplex* idata, cufftReal* odata, int /*direction*/, Args...) { cufftResult cufft_rt = cufftExecC2R(plan, idata, odata); - if (cufft_rt != CUFFT_SUCCESS) - throw std::runtime_error("cufftExecC2R failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecC2R failed"); } template inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata, cufftDoubleReal* odata, int /*direction*/, Args...) { cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata); - if (cufft_rt != CUFFT_SUCCESS) - throw std::runtime_error("cufftExecZ2D failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecZ2D failed"); } template inline void exec_plan(cufftHandle& plan, cufftComplex* idata, cufftComplex* odata, int direction, Args...) { cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction); - if (cufft_rt != CUFFT_SUCCESS) - throw std::runtime_error("cufftExecC2C failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecC2C failed"); } template inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata, cufftDoubleComplex* odata, int direction, Args...) { cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction); - if (cufft_rt != CUFFT_SUCCESS) - throw std::runtime_error("cufftExecZ2Z failed"); + KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecZ2Z failed"); } } // namespace Impl } // namespace KokkosFFT -#endif \ No newline at end of file +#endif diff --git a/fft/src/KokkosFFT_HIP_plans.hpp b/fft/src/KokkosFFT_HIP_plans.hpp index bc2b3386..1c9405c4 100644 --- a/fft/src/KokkosFFT_HIP_plans.hpp +++ b/fft/src/KokkosFFT_HIP_plans.hpp @@ -8,6 +8,7 @@ #include #include "KokkosFFT_HIP_types.hpp" #include "KokkosFFT_layouts.hpp" +#include "KokkosFFT_asserts.hpp" namespace KokkosFFT { namespace Impl { @@ -30,8 +31,7 @@ auto create_plan(const ExecutionSpace& exec_space, plan = std::make_unique(); hipfftResult hipfft_rt = hipfftCreate(&(*plan)); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftCreate failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftCreate failed"); hipStream_t stream = exec_space.hip_stream(); hipfftSetStream((*plan), stream); @@ -45,8 +45,8 @@ auto create_plan(const ExecutionSpace& exec_space, std::multiplies<>()); hipfft_rt = hipfftPlan1d(&(*plan), nx, type, howmany); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftPlan1d failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftPlan1d failed"); + return fft_size; } @@ -69,8 +69,7 @@ auto create_plan(const ExecutionSpace& exec_space, plan = std::make_unique(); hipfftResult hipfft_rt = hipfftCreate(&(*plan)); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftCreate failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftCreate failed"); hipStream_t stream = exec_space.hip_stream(); hipfftSetStream((*plan), stream); @@ -84,8 +83,8 @@ auto create_plan(const ExecutionSpace& exec_space, std::multiplies<>()); hipfft_rt = hipfftPlan2d(&(*plan), nx, ny, type); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftPlan2d failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftPlan2d failed"); + return fft_size; } @@ -108,8 +107,7 @@ auto create_plan(const ExecutionSpace& exec_space, plan = std::make_unique(); hipfftResult hipfft_rt = hipfftCreate(&(*plan)); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftCreate failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftCreate failed"); hipStream_t stream = exec_space.hip_stream(); hipfftSetStream((*plan), stream); @@ -125,8 +123,8 @@ auto create_plan(const ExecutionSpace& exec_space, std::multiplies<>()); hipfft_rt = hipfftPlan3d(&(*plan), nx, ny, nz, type); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftPlan3d failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftPlan3d failed"); + return fft_size; } @@ -169,8 +167,7 @@ auto create_plan(const ExecutionSpace& exec_space, plan = std::make_unique(); hipfftResult hipfft_rt = hipfftCreate(&(*plan)); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftCreate failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftCreate failed"); hipStream_t stream = exec_space.hip_stream(); hipfftSetStream((*plan), stream); @@ -179,8 +176,8 @@ auto create_plan(const ExecutionSpace& exec_space, in_extents.data(), istride, idist, out_extents.data(), ostride, odist, type, howmany); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftPlan failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftPlanMany failed"); + return fft_size; } @@ -193,4 +190,4 @@ void destroy_plan_and_info(std::unique_ptr& plan, InfoType&) { } // namespace Impl } // namespace KokkosFFT -#endif \ No newline at end of file +#endif diff --git a/fft/src/KokkosFFT_HIP_transform.hpp b/fft/src/KokkosFFT_HIP_transform.hpp index 8102df49..aa5dffac 100644 --- a/fft/src/KokkosFFT_HIP_transform.hpp +++ b/fft/src/KokkosFFT_HIP_transform.hpp @@ -6,6 +6,7 @@ #define KOKKOSFFT_HIP_TRANSFORM_HPP #include +#include "KokkosFFT_asserts.hpp" namespace KokkosFFT { namespace Impl { @@ -13,50 +14,44 @@ template inline void exec_plan(hipfftHandle& plan, hipfftReal* idata, hipfftComplex* odata, int /*direction*/, Args...) { hipfftResult hipfft_rt = hipfftExecR2C(plan, idata, odata); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftExecR2C failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftExecR2C failed"); } template inline void exec_plan(hipfftHandle& plan, hipfftDoubleReal* idata, hipfftDoubleComplex* odata, int /*direction*/, Args...) { hipfftResult hipfft_rt = hipfftExecD2Z(plan, idata, odata); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftExecD2Z failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftExecD2Z failed"); } template inline void exec_plan(hipfftHandle& plan, hipfftComplex* idata, hipfftReal* odata, int /*direction*/, Args...) { hipfftResult hipfft_rt = hipfftExecC2R(plan, idata, odata); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftExecC2R failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftExecC2R failed"); } template inline void exec_plan(hipfftHandle& plan, hipfftDoubleComplex* idata, hipfftDoubleReal* odata, int /*direction*/, Args...) { hipfftResult hipfft_rt = hipfftExecZ2D(plan, idata, odata); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftExecZ2D failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftExecZ2D failed"); } template inline void exec_plan(hipfftHandle& plan, hipfftComplex* idata, hipfftComplex* odata, int direction, Args...) { hipfftResult hipfft_rt = hipfftExecC2C(plan, idata, odata, direction); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftExecC2C failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftExecC2C failed"); } template inline void exec_plan(hipfftHandle& plan, hipfftDoubleComplex* idata, hipfftDoubleComplex* odata, int direction, Args...) { hipfftResult hipfft_rt = hipfftExecZ2Z(plan, idata, odata, direction); - if (hipfft_rt != HIPFFT_SUCCESS) - throw std::runtime_error("hipfftExecZ2Z failed"); + KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftExecZ2Z failed"); } } // namespace Impl } // namespace KokkosFFT -#endif \ No newline at end of file +#endif diff --git a/fft/src/KokkosFFT_Plans.hpp b/fft/src/KokkosFFT_Plans.hpp index caeeb2ca..17487a63 100644 --- a/fft/src/KokkosFFT_Plans.hpp +++ b/fft/src/KokkosFFT_Plans.hpp @@ -172,18 +172,19 @@ class Plan { static_assert(InViewType::rank() >= 1, "Plan::Plan: View rank must be larger than or equal to 1"); - if (KokkosFFT::Impl::is_real_v && - m_direction != KokkosFFT::Direction::forward) { - throw std::runtime_error( - "Plan::Plan: real to complex transform is constrcuted with backward " - "direction."); + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, m_axes), + "axes are invalid for in/out views"); + + if constexpr (KokkosFFT::Impl::is_real_v) { + KOKKOSFFT_EXPECTS( + m_direction == KokkosFFT::Direction::forward, + "real to complex transform is constructed with backward direction."); } - if (KokkosFFT::Impl::is_real_v && - m_direction != KokkosFFT::Direction::backward) { - throw std::runtime_error( - "Plan::Plan: complex to real transform is constrcuted with forward " - "direction."); + if constexpr (KokkosFFT::Impl::is_real_v) { + KOKKOSFFT_EXPECTS( + m_direction == KokkosFFT::Direction::backward, + "complex to real transform is constructed with forward direction."); } shape_type<1> s = {0}; @@ -233,18 +234,18 @@ class Plan { "Plan::Plan: View rank must be larger than or equal to the " "Rank of FFT axes"); - if (KokkosFFT::Impl::is_real_v && - m_direction != KokkosFFT::Direction::forward) { - throw std::runtime_error( - "Plan::Plan: real to complex transform is constrcuted with backward " - "direction."); + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, m_axes), + "axes are invalid for in/out views"); + if constexpr (KokkosFFT::Impl::is_real_v) { + KOKKOSFFT_EXPECTS( + m_direction == KokkosFFT::Direction::forward, + "real to complex transform is constructed with backward direction."); } - if (KokkosFFT::Impl::is_real_v && - m_direction != KokkosFFT::Direction::backward) { - throw std::runtime_error( - "Plan::Plan: complex to real transform is constrcuted with forward " - "direction."); + if constexpr (KokkosFFT::Impl::is_real_v) { + KOKKOSFFT_EXPECTS( + m_direction == KokkosFFT::Direction::backward, + "complex to real transform is constructed with forward direction."); } m_in_extents = KokkosFFT::Impl::extract_extents(in); @@ -288,17 +289,14 @@ class Plan { auto in_extents = KokkosFFT::Impl::extract_extents(in); auto out_extents = KokkosFFT::Impl::extract_extents(out); - if (in_extents != m_in_extents) { - throw std::runtime_error( - "Plan::good: extents of input View for plan and execution are " - "not identical."); - } - if (out_extents != m_out_extents) { - throw std::runtime_error( - "Plan::good: extents of output View for plan and execution are " - "not identical."); - } + KOKKOSFFT_EXPECTS( + in_extents == m_in_extents, + "extents of input View for plan and execution are not identical."); + + KOKKOSFFT_EXPECTS( + out_extents == m_out_extents, + "extents of output View for plan and execution are not identical."); } /// \brief Return the execution space diff --git a/fft/src/KokkosFFT_ROCM_plans.hpp b/fft/src/KokkosFFT_ROCM_plans.hpp index 2d44ba59..d051c0b9 100644 --- a/fft/src/KokkosFFT_ROCM_plans.hpp +++ b/fft/src/KokkosFFT_ROCM_plans.hpp @@ -9,6 +9,7 @@ #include #include "KokkosFFT_ROCM_types.hpp" #include "KokkosFFT_layouts.hpp" +#include "KokkosFFT_asserts.hpp" namespace KokkosFFT { namespace Impl { @@ -124,8 +125,8 @@ auto create_plan(const ExecutionSpace& exec_space, // Create the description rocfft_plan_description description; rocfft_status status = rocfft_plan_description_create(&description); - if (status != rocfft_status_success) - std::runtime_error("rocfft_plan_description_create failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_plan_description_create failed"); auto [in_array_type, out_array_type, fft_direction] = get_in_out_array_type(type, direction); @@ -143,8 +144,8 @@ auto create_plan(const ExecutionSpace& exec_space, out_strides.size(), // output stride length out_strides.data(), // output stride data odist); // output batch distance - if (status != rocfft_status_success) - std::runtime_error("rocfft_plan_description_set_data_layout failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_plan_description_set_data_layout failed"); // Out-of-place transform const rocfft_result_placement place = rocfft_placement_notinplace; @@ -157,38 +158,38 @@ auto create_plan(const ExecutionSpace& exec_space, howmany, // Number of transforms description // Description ); - if (status != rocfft_status_success) - std::runtime_error("rocfft_plan_create failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_plan_create failed"); // Prepare workbuffer and set execution information status = rocfft_execution_info_create(&execution_info); - if (status != rocfft_status_success) - std::runtime_error("rocfft_execution_info_create failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execution_info_create failed"); // set stream // NOTE: The stream must be of type hipStream_t. // It is an error to pass the address of a hipStream_t object. hipStream_t stream = exec_space.hip_stream(); status = rocfft_execution_info_set_stream(execution_info, stream); - if (status != rocfft_status_success) - throw std::runtime_error("rocfft_execution_info_set_stream failed."); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execution_info_set_stream failed"); std::size_t workbuffersize = 0; status = rocfft_plan_get_work_buffer_size(*plan, &workbuffersize); - if (status != rocfft_status_success) - std::runtime_error("rocfft_plan_get_work_buffer_size failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_plan_get_work_buffer_size failed"); if (workbuffersize > 0) { buffer = BufferViewType("work_buffer", workbuffersize); status = rocfft_execution_info_set_work_buffer( execution_info, (void*)buffer.data(), workbuffersize); - if (status != rocfft_status_success) - std::runtime_error("rocfft_execution_info_set_work_buffer failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execution_info_set_work_buffer failed"); } status = rocfft_plan_description_destroy(description); - if (status != rocfft_status_success) - std::runtime_error("rocfft_plan_description_destroy failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_plan_description_destroy failed"); return fft_size; } @@ -204,4 +205,4 @@ void destroy_plan_and_info(std::unique_ptr& plan, } // namespace Impl } // namespace KokkosFFT -#endif \ No newline at end of file +#endif diff --git a/fft/src/KokkosFFT_ROCM_transform.hpp b/fft/src/KokkosFFT_ROCM_transform.hpp index 5c177d66..c1735038 100644 --- a/fft/src/KokkosFFT_ROCM_transform.hpp +++ b/fft/src/KokkosFFT_ROCM_transform.hpp @@ -7,6 +7,7 @@ #include #include +#include "KokkosFFT_asserts.hpp" namespace KokkosFFT { namespace Impl { @@ -15,8 +16,8 @@ inline void exec_plan(rocfft_plan& plan, float* idata, const rocfft_execution_info& execution_info) { rocfft_status status = rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); - if (status != rocfft_status_success) - throw std::runtime_error("rocfft_execute for R2C failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execute for R2C failed"); } inline void exec_plan(rocfft_plan& plan, double* idata, @@ -24,8 +25,8 @@ inline void exec_plan(rocfft_plan& plan, double* idata, const rocfft_execution_info& execution_info) { rocfft_status status = rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); - if (status != rocfft_status_success) - throw std::runtime_error("rocfft_execute for D2Z failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execute for D2Z failed"); } inline void exec_plan(rocfft_plan& plan, std::complex* idata, @@ -33,8 +34,8 @@ inline void exec_plan(rocfft_plan& plan, std::complex* idata, const rocfft_execution_info& execution_info) { rocfft_status status = rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); - if (status != rocfft_status_success) - throw std::runtime_error("rocfft_execute for C2R failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execute for C2R failed"); } inline void exec_plan(rocfft_plan& plan, std::complex* idata, @@ -42,8 +43,8 @@ inline void exec_plan(rocfft_plan& plan, std::complex* idata, const rocfft_execution_info& execution_info) { rocfft_status status = rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); - if (status != rocfft_status_success) - throw std::runtime_error("rocfft_execute for Z2D failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execute for Z2D failed"); } inline void exec_plan(rocfft_plan& plan, std::complex* idata, @@ -51,8 +52,8 @@ inline void exec_plan(rocfft_plan& plan, std::complex* idata, const rocfft_execution_info& execution_info) { rocfft_status status = rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); - if (status != rocfft_status_success) - throw std::runtime_error("rocfft_execute for C2C failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execute for C2C failed"); } inline void exec_plan(rocfft_plan& plan, std::complex* idata, @@ -60,11 +61,11 @@ inline void exec_plan(rocfft_plan& plan, std::complex* idata, const rocfft_execution_info& execution_info) { rocfft_status status = rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); - if (status != rocfft_status_success) - throw std::runtime_error("rocfft_execute for Z2Z failed"); + KOKKOSFFT_EXPECTS(status == rocfft_status_success, + "rocfft_execute for Z2Z failed"); } } // namespace Impl } // namespace KokkosFFT -#endif \ No newline at end of file +#endif diff --git a/fft/src/KokkosFFT_Transform.hpp b/fft/src/KokkosFFT_Transform.hpp index 23dcd003..8a83677a 100644 --- a/fft/src/KokkosFFT_Transform.hpp +++ b/fft/src/KokkosFFT_Transform.hpp @@ -141,7 +141,8 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in, "and OutViewType."); static_assert(InViewType::rank() >= 1, "fft: View rank must be larger than or equal to 1"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axis_type<1>({axis})), + "axes are invalid for in/out views"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axis, n); KokkosFFT::Impl::fft_exec_impl(plan, in, out, norm); @@ -169,7 +170,8 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in, "and OutViewType."); static_assert(InViewType::rank() >= 1, "ifft: View rank must be larger than or equal to 1"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axis_type<1>({axis})), + "axes are invalid for in/out views"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward, axis, n); KokkosFFT::Impl::fft_exec_impl(plan, in, out, norm); @@ -205,7 +207,8 @@ void rfft(const ExecutionSpace& exec_space, const InViewType& in, "rfft: InViewType must be real"); static_assert(KokkosFFT::Impl::is_complex_v, "rfft: OutViewType must be complex"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axis_type<1>({axis})), + "axes are invalid for in/out views"); fft(exec_space, in, out, norm, axis, n); } @@ -240,6 +243,8 @@ void irfft(const ExecutionSpace& exec_space, const InViewType& in, "irfft: InViewType must be complex"); static_assert(KokkosFFT::Impl::is_real_v, "irfft: OutViewType must be real"); + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axis_type<1>({axis})), + "axes are invalid for in/out views"); ifft(exec_space, in, out, norm, axis, n); } @@ -275,6 +280,8 @@ void hfft(const ExecutionSpace& exec_space, const InViewType& in, "hfft: InViewType must be complex"); static_assert(KokkosFFT::Impl::is_real_v, "hfft: OutViewType must be real"); + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axis_type<1>({axis})), + "axes are invalid for in/out views"); auto new_norm = KokkosFFT::Impl::swap_direction(norm); // using ComplexViewType = typename // KokkosFFT::Impl::complex_view_type::type; @@ -314,7 +321,8 @@ void ihfft(const ExecutionSpace& exec_space, const InViewType& in, "ihfft: InViewType must be real"); static_assert(KokkosFFT::Impl::is_complex_v, "ihfft: OutViewType must be complex"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axis_type<1>({axis})), + "axes are invalid for in/out views"); auto new_norm = KokkosFFT::Impl::swap_direction(norm); OutViewType out_conj; rfft(exec_space, in, out, new_norm, axis, n); @@ -346,7 +354,8 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in, "and OutViewType."); static_assert(InViewType::rank() >= 2, "fft2: View rank must be larger than or equal to 2"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes), + "axes are invalid for in/out views"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axes, s); KokkosFFT::Impl::fft_exec_impl(plan, in, out, norm); @@ -375,7 +384,8 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in, "and OutViewType."); static_assert(InViewType::rank() >= 2, "ifft2: View rank must be larger than or equal to 2"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes), + "axes are invalid for in/out views"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward, axes, s); KokkosFFT::Impl::fft_exec_impl(plan, in, out, norm); @@ -412,7 +422,8 @@ void rfft2(const ExecutionSpace& exec_space, const InViewType& in, "rfft2: InViewType must be real"); static_assert(KokkosFFT::Impl::is_complex_v, "rfft2: OutViewType must be complex"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes), + "axes are invalid for in/out views"); fft2(exec_space, in, out, norm, axes, s); } @@ -447,7 +458,8 @@ void irfft2(const ExecutionSpace& exec_space, const InViewType& in, "irfft2: InViewType must be complex"); static_assert(KokkosFFT::Impl::is_real_v, "irfft2: OutViewType must be real"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes), + "axes are invalid for in/out views"); ifft2(exec_space, in, out, norm, axes, s); } @@ -479,7 +491,8 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in, static_assert( InViewType::rank() >= DIM, "fftn: View rank must be larger than or equal to the Rank of FFT axes"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes), + "axes are invalid for in/out views"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axes, s); KokkosFFT::Impl::fft_exec_impl(plan, in, out, norm); @@ -513,7 +526,8 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in, static_assert( InViewType::rank() >= DIM, "ifftn: View rank must be larger than or equal to the Rank of FFT axes"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes), + "axes are invalid for in/out views"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward, axes, s); KokkosFFT::Impl::fft_exec_impl(plan, in, out, norm); @@ -555,7 +569,8 @@ void rfftn(const ExecutionSpace& exec_space, const InViewType& in, "rfftn: InViewType must be real"); static_assert(KokkosFFT::Impl::is_complex_v, "rfftn: OutViewType must be complex"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes), + "axes are invalid for in/out views"); fftn(exec_space, in, out, axes, norm, s); } @@ -595,7 +610,8 @@ void irfftn(const ExecutionSpace& exec_space, const InViewType& in, "irfftn: InViewType must be complex"); static_assert(KokkosFFT::Impl::is_real_v, "irfftn: OutViewType must be real"); - + KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes), + "axes are invalid for in/out views"); ifftn(exec_space, in, out, axes, norm, s); } diff --git a/fft/unit_test/Test_Transform.cpp b/fft/unit_test/Test_Transform.cpp index 4bcb5f49..ceb5cf63 100644 --- a/fft/unit_test/Test_Transform.cpp +++ b/fft/unit_test/Test_Transform.cpp @@ -24,7 +24,7 @@ void fft1(ViewType& in, ViewType& out) { using value_type = typename ViewType::non_const_value_type; using real_value_type = KokkosFFT::Impl::base_floating_point_type; - static_assert(KokkosFFT::Impl::is_complex::value, + static_assert(KokkosFFT::Impl::is_complex_v, "fft1: ViewType must be complex"); const value_type I(0.0, 1.0); @@ -64,7 +64,7 @@ void ifft1(ViewType& in, ViewType& out) { using value_type = typename ViewType::non_const_value_type; using real_value_type = KokkosFFT::Impl::base_floating_point_type; - static_assert(KokkosFFT::Impl::is_complex::value, + static_assert(KokkosFFT::Impl::is_complex_v, "ifft1: ViewType must be complex"); const value_type I(0.0, 1.0);