Skip to content

Commit

Permalink
FFT::PoissonHybrid: Add interfaces for user provided dz (#4229)
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang authored Nov 14, 2024
1 parent fc42a39 commit 294b6fe
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 127 deletions.
65 changes: 51 additions & 14 deletions Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ template <typename MF = MultiFab>
class PoissonHybrid
{
public:
using T = typename MF::value_type;

template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
explicit PoissonHybrid (Geometry const& geom)
Expand All @@ -104,6 +105,11 @@ public:
}

void solve (MF& soln, MF const& rhs);
void solve (MF& soln, MF const& rhs, Vector<T> const& dz);
void solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz);

template <typename DZ>
void solve_doit (MF& soln, MF const& rhs, DZ const& dz); // has to be public for cuda

private:
Geometry m_geom;
Expand Down Expand Up @@ -223,16 +229,50 @@ void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)

#endif /* AMREX_SPACEDIM == 3 */

namespace fft_poisson_detail {
template <typename T>
struct DZ {
[[nodiscard]] constexpr T operator[] (int) const { return m_delz; }
T m_delz;
};
}

template <typename MF>
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
{
auto delz = T(m_geom.CellSize(AMREX_SPACEDIM-1));
solve_doit(soln, rhs, fft_poisson_detail::DZ<T>{delz});
}

template <typename MF>
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz)
{
auto const* pdz = dz.dataPtr();
solve_doit(soln, rhs, pdz);
}

template <typename MF>
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
{
#ifdef AMREX_USE_GPU
Gpu::DeviceVector<T> d_dz(dz.size());
Gpu::htod_memcpy_async(d_dz.data(), dz.data(), dz.size()*sizeof(T));
auto const* pdz = d_dz.data();
#else
auto const* pdz = dz.data();
#endif
solve_doit(soln, rhs, pdz);
}

template <typename MF>
template <typename DZ>
void PoissonHybrid<MF>::solve_doit (MF& soln, MF const& rhs, DZ const& dz)
{
BL_PROFILE("FFT::PoissonHybrid::solve");

#if (AMREX_SPACEDIM < 3)
amrex::ignore_unused(soln, rhs);
amrex::ignore_unused(soln, rhs, dz);
#else
using T = typename MF::value_type;

auto facx = T(2)*Math::pi<T>()/T(m_geom.ProbLength(0));
auto facy = T(2)*Math::pi<T>()/T(m_geom.ProbLength(1));
auto dx = T(m_geom.CellSize(0));
Expand All @@ -242,9 +282,6 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
auto ny = m_geom.Domain().length(1);
auto nz = m_geom.Domain().length(2);

Gpu::DeviceVector<T> delzv(nz, T(m_geom.CellSize(2)));
auto const* delz = delzv.data();

Box cdomain = m_geom.Domain();
cdomain.setBig(0,cdomain.length(0)/2);
auto cba = amrex::decompose(cdomain, ParallelContext::NProcsSub(),
Expand Down Expand Up @@ -283,18 +320,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
for(int k=0; k < nz; k++) {
if(k==0) {
ald(i,j,k) = 0.;
cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
cud(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
} else if (k == nz-1) {
ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
ald(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
cud(i,j,k) = 0.;
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
if (i == 0 && j == 0) {
bd(i,j,k) *= 2.0;
}
} else {
ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
ald(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
cud(i,j,k) = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
}
}
Expand Down Expand Up @@ -339,18 +376,18 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
for(int k=0; k < nz; k++) {
if(k==0) {
ald[k] = 0.;
cud[k] = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
cud[k] = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
bd[k] = k2 -ald[k]-cud[k];
} else if (k == nz-1) {
ald[k] = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
ald[k] = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
cud[k] = 0.;
bd[k] = k2 -ald[k]-cud[k];
if (i == 0 && j == 0) {
bd[k] *= 2.0;
}
} else {
ald[k] = 2.0 /(delz[k]*(delz[k]+delz[k-1]));
cud[k] = 2.0 /(delz[k]*(delz[k]+delz[k+1]));
ald[k] = 2.0 /(dz[k]*(dz[k]+dz[k-1]));
cud[k] = 2.0 /(dz[k]*(dz[k]+dz[k+1]));
bd[k] = k2 -ald[k]-cud[k];
}
}
Expand Down
Loading

0 comments on commit 294b6fe

Please sign in to comment.