Skip to content

Commit

Permalink
FFT: Add new domain decomposition strategy
Browse files Browse the repository at this point in the history
Instead of pencil, it has the option of doing slab decomposition. This
allows the x and y directions to be processed together without MPI
communication.
  • Loading branch information
WeiqunZhang committed Nov 8, 2024
1 parent 8e7bb00 commit 5278f52
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 71 deletions.
67 changes: 48 additions & 19 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ namespace amrex::FFT

enum struct Direction { forward, backward, both, none };

enum struct DomainStrategy { slab, pencil };

AMREX_ENUM( Boundary, periodic, even, odd );

enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,
Expand Down Expand Up @@ -172,18 +174,33 @@ struct Plan
}

template <Direction D>
void init_r2c (Box const& box, T* pr, VendorComplex* pc)
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
{
static_assert(D == Direction::forward || D == Direction::backward);

int rank = is_2d_transform ? 2 : 1;

kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b;
defined = true;
pf = (void*)pr;
pb = (void*)pc;

n = box.length(0);
int nc = (n/2) + 1;
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
int len[2] = {};
if (rank == 1) {
len[0] = box.length(0);
} else {
len[0] = box.length(1);
len[1] = box.length(0);
}
int nr = (rank == 1) ? len[0] : len[0]*len[1];
n = nr;
int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0];
#if (AMREX_SPACEDIM == 1)
howmany = 1;
#else
howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
: AMREX_D_TERM(1, *1 , *box.length(2));
#endif

amrex::ignore_unused(nc);

Expand All @@ -193,43 +210,51 @@ struct Plan
if constexpr (D == Direction::forward) {
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, nc, fwd_type, howmany, &work_size));
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
} else {
cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, nc, nullptr, 1, n, bwd_type, howmany, &work_size));
(cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream()));
}
#elif defined(AMREX_USE_HIP)

auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
const std::size_t length = n;
const std::size_t length[2] = {std::size_t(len[0]), std::size_t(len[1])};
if constexpr (D == Direction::forward) {
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&plan, rocfft_placement_notinplace,
rocfft_transform_type_real_forward, prec, 1,
&length, howmany, nullptr));
rocfft_transform_type_real_forward, prec, rank,
length, howmany, nullptr));
} else {
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&plan, rocfft_placement_notinplace,
rocfft_transform_type_real_inverse, prec, 1,
&length, howmany, nullptr));
rocfft_transform_type_real_inverse, prec, rank,
length, howmany, nullptr));
}

#elif defined(AMREX_USE_SYCL)

auto* pp = new mkl_desc_r(n);
mkl_desc_r* pp;
if (rank == 1) {
pp = new mkl_desc_r(len[0]);
} else {
pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])});
}
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif
pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr);
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
std::vector<std::int64_t> strides = {0,1};
std::vector<std::int64_t> strides;
strides.push_back(0);
if (rank == 2) { strides.push_back(len[1]); }
strides.push_back(1);
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
Expand All @@ -247,21 +272,21 @@ struct Plan
if constexpr (std::is_same_v<float,T>) {
if constexpr (D == Direction::forward) {
plan = fftwf_plan_many_dft_r2c
(1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc,
(rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
} else {
plan = fftwf_plan_many_dft_c2r
(1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n,
(rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
}
} else {
if constexpr (D == Direction::forward) {
plan = fftw_plan_many_dft_r2c
(1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc,
(rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc,
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
} else {
plan = fftw_plan_many_dft_c2r
(1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n,
(rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr,
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
}
}
Expand Down Expand Up @@ -1087,13 +1112,17 @@ namespace detail
template <typename FA1, typename FA2>
std::unique_ptr<char,DataDeleter> make_mfs_share (FA1& fa1, FA2& fa2)
{
bool not_same_fa = true;
if constexpr (std::is_same_v<FA1,FA2>) {
not_same_fa = (&fa1 != &fa2);
}
using FAB1 = typename FA1::FABType::value_type;
using FAB2 = typename FA2::FABType::value_type;
using T1 = typename FAB1::value_type;
using T2 = typename FAB2::value_type;
auto myproc = ParallelContext::MyProcSub();
bool alloc_1 = (myproc < fa1.size());
bool alloc_2 = (myproc < fa2.size());
bool alloc_2 = (myproc < fa2.size()) && not_same_fa;
void* p = nullptr;
if (alloc_1 && alloc_2) {
Box const& box1 = fa1.fabbox(myproc);
Expand Down
41 changes: 31 additions & 10 deletions Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,42 @@ public:
template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
Poisson (Geometry const& geom,
Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> const& bc)
: m_geom(geom), m_bc(bc), m_r2x(geom.Domain(),bc)
{}
: m_geom(geom), m_bc(bc)
{
bool all_periodic = true;
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
all_periodic = all_periodic
&& (bc[idim].first == Boundary::periodic)
&& (bc[idim].second == Boundary::periodic);
}
if (all_periodic) {
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain());
} else {
m_r2x = std::make_unique<R2X<typename MF::value_type>> (m_geom.Domain(), m_bc);
}
}

template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
explicit Poisson (Geometry const& geom)
: m_geom(geom),
m_bc{AMREX_D_DECL(std::make_pair(Boundary::periodic,Boundary::periodic),
std::make_pair(Boundary::periodic,Boundary::periodic),
std::make_pair(Boundary::periodic,Boundary::periodic))},
m_r2x(geom.Domain(),m_bc)
std::make_pair(Boundary::periodic,Boundary::periodic))}
{
AMREX_ALWAYS_ASSERT(m_geom.isAllPeriodic());
if (m_geom.isAllPeriodic()) {
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain());
} else {
amrex::Abort("FFT::Poisson: wrong BC");
}
}

void solve (MF& soln, MF const& rhs);

private:
Geometry m_geom;
Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> m_bc;
R2X<typename MF::value_type> m_r2x;
std::unique_ptr<R2X<typename MF::value_type>> m_r2x;
std::unique_ptr<R2C<typename MF::value_type>> m_r2c;
};

#if (AMREX_SPACEDIM == 3)
Expand Down Expand Up @@ -114,7 +130,7 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
{AMREX_D_DECL(T(2)/T(m_geom.CellSize(0)*m_geom.CellSize(0)),
T(2)/T(m_geom.CellSize(1)*m_geom.CellSize(1)),
T(2)/T(m_geom.CellSize(2)*m_geom.CellSize(2)))};
auto scale = m_r2x.scalingFactor();
auto scale = (m_r2x) ? m_r2x->scalingFactor() : T(1)/T(m_geom.Domain().numPts());

GpuArray<T,AMREX_SPACEDIM> offset{AMREX_D_DECL(T(0),T(0),T(0))};
// Not sure about odd-even and even-odd yet
Expand All @@ -133,8 +149,7 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
}
}

m_r2x.forwardThenBackward(rhs, soln,
[=] AMREX_GPU_DEVICE (int i, int j, int k, auto& spectral_data)
auto f = [=] AMREX_GPU_DEVICE (int i, int j, int k, auto& spectral_data)
{
amrex::ignore_unused(j,k);
AMREX_D_TERM(T a = fac[0]*(i+offset[0]);,
Expand All @@ -147,7 +162,13 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
spectral_data /= k2;
}
spectral_data *= scale;
});
};

if (m_r2x) {
m_r2x->forwardThenBackward(rhs, soln, f);
} else {
m_r2c->forwardThenBackward(rhs, soln, f);
}
}

#if (AMREX_SPACEDIM == 3)
Expand Down
Loading

0 comments on commit 5278f52

Please sign in to comment.