Skip to content

Commit

Permalink
FFT OpenBC Solver: more optimization (#4232)
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang authored Nov 14, 2024
1 parent 294b6fe commit 0165b67
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 27 deletions.
2 changes: 2 additions & 0 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
}
}
}

m_r2c.prepare_openbc();
}

template <typename T>
Expand Down
97 changes: 70 additions & 27 deletions Src/FFT/AMReX_FFT_R2C.H
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ public:
template <typename F>
void post_forward_doit (F const& post_forward);

void prepare_openbc ();

private:

static std::pair<Plan<T>,Plan<T>> make_c2c_plans (cMF& inout);
Expand All @@ -176,6 +178,8 @@ private:
Plan<T> m_fft_bwd_y{};
Plan<T> m_fft_fwd_z{};
Plan<T> m_fft_bwd_z{};
Plan<T> m_fft_fwd_x_half{};
Plan<T> m_fft_bwd_x_half{};

// Comm meta-data. In the forward phase, we start with (x,y,z),
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we
Expand Down Expand Up @@ -394,6 +398,60 @@ R2C<T,D,S>::~R2C<T,D,S> ()
m_fft_fwd_x.destroy();
m_fft_fwd_y.destroy();
m_fft_fwd_z.destroy();
if (m_fft_bwd_x_half.plan != m_fft_fwd_x_half.plan) {
m_fft_bwd_x_half.destroy();
}
m_fft_fwd_x_half.destroy();
}

template <typename T, Direction D, DomainStrategy S>
void R2C<T,D,S>::prepare_openbc ()
{
    // Builds the extra FFT plans and communication meta-data that operate on
    // only the "bottom half" of the domain. The *_half plans and *_half comm
    // meta-data created here are the ones selected elsewhere in this class
    // when m_openbc_half is set. Nothing is done unless AMREX_SPACEDIM == 3.
#if (AMREX_SPACEDIM == 3)
    using VendorComplex = typename Plan<T>::VendorComplex;

    // Returns a copy of `bx` whose high side in direction `idim` has been
    // pulled in by half of the box's length in that direction.
    auto const lower_half = [] (Box bx, int idim) -> Box {
        bx.growHi(idim, -bx.length(idim)/2);
        return bx;
    };

    // Half-domain plans are only set up for the slab decomposition; the
    // other case is still a todo.
    // NOTE(review): unlike the comm meta-data below, plan creation is not
    // guarded against repeated calls -- confirm that calling init_r2c again
    // on an already initialized plan is safe.
    if (m_slab_decomp) {
        if (auto* rfab = detail::get_fab(m_rx)) {
            // Restrict the local real-space box to the bottom half of the
            // real domain in direction 2.
            Box half_box = rfab->box() & lower_half(m_real_domain, 2);
            if (half_box.ok()) {
                auto* real_ptr = rfab->dataPtr();
                auto* cplx_ptr = reinterpret_cast<VendorComplex*>
                    (detail::get_fab(m_cx)->dataPtr());
#ifdef AMREX_USE_SYCL
                // With SYCL the backward plan is a copy of the forward plan;
                // the destructor checks for this sharing before destroying.
                m_fft_fwd_x_half.template init_r2c<Direction::forward>
                    (half_box, real_ptr, cplx_ptr, m_slab_decomp);
                m_fft_bwd_x_half = m_fft_fwd_x_half;
#else
                if constexpr (D == Direction::forward || D == Direction::both) {
                    m_fft_fwd_x_half.template init_r2c<Direction::forward>
                        (half_box, real_ptr, cplx_ptr, m_slab_decomp);
                }
                if constexpr (D == Direction::backward || D == Direction::both) {
                    m_fft_bwd_x_half.template init_r2c<Direction::backward>
                        (half_box, real_ptr, cplx_ptr, m_slab_decomp);
                }
#endif
            }
        }
    }

    // x -> z comm meta-data covering only the bottom half. It is built once,
    // and only when the full x -> z meta-data exists.
    bool const need_x2z_half = m_cmd_x2z && ! m_cmd_x2z_half;
    if (need_x2z_half) {
        // The z-direction has index 0 here because z is the unit-stride
        // direction of m_spectral_domain_z.
        m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
            (m_cz, lower_half(m_spectral_domain_z, 0), m_cx, IntVect(0), m_dtos_x2z);
    }

    // z -> x comm meta-data covering only the bottom half, built once.
    bool const need_z2x_half = m_cmd_z2x && ! m_cmd_z2x_half;
    if (need_z2x_half) {
        m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
            (m_cx, lower_half(m_spectral_domain_x, 2), m_cz, IntVect(0), m_dtos_z2x);
    }
#endif
}

template <typename T, Direction D, DomainStrategy S>
Expand All @@ -406,7 +464,8 @@ void R2C<T,D,S>::forward (MF const& inmf)
if (&m_rx != &inmf) {
m_rx.ParallelCopy(inmf, 0, 0, 1);
}
m_fft_fwd_x.template compute_r2c<Direction::forward>();
auto& fft_x = m_openbc_half ? m_fft_fwd_x_half : m_fft_fwd_x;
fft_x.template compute_r2c<Direction::forward>();

if ( m_cmd_x2y) {
ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y);
Expand All @@ -419,19 +478,16 @@ void R2C<T,D,S>::forward (MF const& inmf)
#if (AMREX_SPACEDIM == 3)
else if ( m_cmd_x2z) {
if (m_openbc_half) {
Box upper_half = m_spectral_domain_z;
// Note that z-direction's index is 0 because z is the unit-stride direction here.
upper_half.growLo (0,-m_spectral_domain_z.length(0)/2);
if (! m_cmd_x2z_half) {
Box bottom_half = m_spectral_domain_z;
bottom_half.growHi(0,-m_spectral_domain_z.length(0)/2);
m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
(m_cz, bottom_half, m_cx, IntVect(0), m_dtos_x2z);
}
NonLocalBC::ApplyDtosAndProjectionOnReciever packing
{NonLocalBC::PackComponents{}, m_dtos_x2z};
auto handler = ParallelCopy_nowait(m_cz, m_cx, *m_cmd_x2z_half, packing);

Box upper_half = m_spectral_domain_z;
// Note that z-direction's index is 0 because z is the
// unit-stride direction here.
upper_half.growLo (0,-m_spectral_domain_z.length(0)/2);
m_cz.setVal(0, upper_half, 0, 1);

ParallelCopy_finish(m_cz, std::move(handler), *m_cmd_x2z_half, packing);
} else {
ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, 1, m_dtos_x2z);
Expand Down Expand Up @@ -459,22 +515,8 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
}
#if (AMREX_SPACEDIM == 3)
else if ( m_cmd_z2x) {
if (m_openbc_half) {
Box upper_half = m_spectral_domain_x;
upper_half.growLo (2,-m_spectral_domain_x.length(2)/2);
if (! m_cmd_z2x_half) {
Box bottom_half = m_spectral_domain_x;
bottom_half.growHi(2,-m_spectral_domain_x.length(2)/2);
m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
(m_cx, bottom_half, m_cz, IntVect(0), m_dtos_z2x);
}
NonLocalBC::ApplyDtosAndProjectionOnReciever packing
{NonLocalBC::PackComponents{}, m_dtos_z2x};
auto handler = ParallelCopy_nowait(m_cx, m_cz, *m_cmd_z2x_half, packing);
ParallelCopy_finish(m_cx, std::move(handler), *m_cmd_z2x_half, packing);
} else {
ParallelCopy(m_cx, m_cz, *m_cmd_z2x, 0, 0, 1, m_dtos_z2x);
}
auto const& cmd = m_openbc_half ? m_cmd_z2x_half : m_cmd_z2x;
ParallelCopy(m_cx, m_cz, *cmd, 0, 0, 1, m_dtos_z2x);
}
#endif

Expand All @@ -483,7 +525,8 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x);
}

m_fft_bwd_x.template compute_r2c<Direction::backward>();
auto& fft_x = m_openbc_half ? m_fft_bwd_x_half : m_fft_bwd_x;
fft_x.template compute_r2c<Direction::backward>();
outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout);
}

Expand Down

0 comments on commit 0165b67

Please sign in to comment.