Skip to content

Commit

Permalink
Remove empty non matching mesh interpolation data exception (#2748)
Browse files Browse the repository at this point in the history
* remove check

* update nonmatching mesh interpolation tests

* linting

* Readability improvements

* minor cosmetic updates

* Doc fix

---------

Co-authored-by: Garth N. Wells <[email protected]>
  • Loading branch information
nate-sime and garth-wells authored Aug 24, 2023
1 parent 18e0b13 commit 200e94c
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 161 deletions.
179 changes: 84 additions & 95 deletions cpp/dolfinx/fem/interpolate.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ std::vector<T> interpolation_coords(const fem::FiniteElement<T>& element,

// Evaluate coordinate element basis at reference points
namespace stdex = std::experimental;
using cmdspan4_t = stdex::mdspan<const T, stdex::dextents<std::size_t, 4>>;
std::array<std::size_t, 4> phi_shape = cmap.tabulate_shape(0, Xshape[0]);
std::vector<T> phi_b(
std::reduce(phi_shape.begin(), phi_shape.end(), 1, std::multiplies{}));
cmdspan4_t phi_full(phi_b.data(), phi_shape);
stdex::mdspan<const T, stdex::dextents<std::size_t, 4>> phi_full(phi_b.data(),
phi_shape);
cmap.tabulate(0, X, Xshape, phi_b);
auto phi = stdex::submdspan(phi_full, 0, stdex::full_extent,
stdex::full_extent, 0);
Expand Down Expand Up @@ -101,7 +101,7 @@ std::vector<T> interpolation_coords(const fem::FiniteElement<T>& element,

/// @brief Interpolate an expression f(x) in a finite element space.
///
/// @param[out] u The function to interpolate into
/// @param[out] u The Function object to interpolate into
/// @param[in] f Evaluation of the function `f(x)` at the physical
/// points `x` given by fem::interpolation_coords. The element used in
/// fem::interpolation_coords should be the same element as associated
Expand All @@ -120,34 +120,39 @@ void interpolate(Function<T, U>& u, std::span<const T> f,

namespace impl
{
namespace stdex = std::experimental;

/// @brief Convenience typdef
template <typename T, std::size_t D>
using mdspan_t = stdex::mdspan<T, stdex::dextents<std::size_t, D>>;

/// @brief Scatter data into non-contiguous memory.
///
/// Scatter blocked data `send_values` to its corresponding src_rank and
/// Scatter blocked data `send_values` to its corresponding `src_rank` and
/// insert the data into `recv_values`. The insert location in
/// `recv_values` is determined by `dest_ranks`. If the j-th dest rank
/// is -1, then `recv_values[j*block_size:(j+1)*block_size]) = 0.
/// is -1, then `recv_values[j*block_size:(j+1)*block_size]) = 0`.
///
/// @param[in] comm The mpi communicator
/// @param[in] src_ranks The rank owning the values of each row in
/// send_values
/// @param[in] dest_ranks List of ranks receiving data. Size of array is
/// how many values we are receiving (not unrolled for blcok_size).
/// how many values we are receiving (not unrolled for block_size).
/// @param[in] send_values The values to send back to owner. Shape
/// (src_ranks.size(), block_size). Storage is row-major.
/// @param[in] s_shape Shape of send_values
/// @param[in,out] recv_values Array to fill with values Shape
/// (src_ranks.size(), block_size).
/// @param[in,out] recv_values Array to fill with values. Shape
/// (dest_ranks.size(), block_size). Storage is row-major.
/// @pre It is required that src_ranks are sorted.
/// @note dest_ranks can contain repeated entries
/// @note dest_ranks might contain -1 (no process owns the point)
template <dolfinx::scalar T>
void scatter_values(MPI_Comm comm, std::span<const std::int32_t> src_ranks,
std::span<const std::int32_t> dest_ranks,
std::span<const T> send_values,
std::array<std::size_t, 2> s_shape,
std::span<T> recv_values)
void scatter_values(
MPI_Comm comm, std::span<const std::int32_t> src_ranks,
std::span<const std::int32_t> dest_ranks,
stdex::mdspan<const T, stdex::dextents<std::size_t, 2>> send_values,
std::span<T> recv_values)
{
const std::size_t block_size = s_shape[1];
const std::size_t block_size = send_values.extent(1);
assert(src_ranks.size() * block_size == send_values.size());
assert(recv_values.size() == dest_ranks.size() * block_size);

Expand Down Expand Up @@ -236,11 +241,10 @@ void scatter_values(MPI_Comm comm, std::span<const std::int32_t> src_ranks,
std::partial_sum(send_sizes.begin(), send_sizes.end(),
std::next(send_offsets.begin(), 1));

std::stringstream cc;
// Send values to dest ranks
std::vector<T> values(recv_offsets.back());
values.reserve(1);
MPI_Neighbor_alltoallv(send_values.data(), send_sizes.data(),
MPI_Neighbor_alltoallv(send_values.data_handle(), send_sizes.data(),
send_offsets.data(), dolfinx::MPI::mpi_type<T>(),
values.data(), recv_sizes.data(), recv_offsets.data(),
dolfinx::MPI::mpi_type<T>(), reverse_comm);
Expand All @@ -257,20 +261,27 @@ void scatter_values(MPI_Comm comm, std::span<const std::int32_t> src_ranks,
};

/// @brief Apply interpolation operator Pi to data to evaluate the dof
/// coefficients
/// coefficients.
/// @param[in] Pi The interpolation matrix (shape = (num dofs,
/// num_points * value_size))
/// num_points * value_size)).
/// @param[in] data Function evaluations, by point, e.g. (f0(x0),
/// f1(x0), f0(x1), f1(x1), ...)
/// @param[out] coeffs The degrees of freedom to compute
/// @param[in] bs The block size
/// f1(x0), f0(x1), f1(x1), ...).
/// @param[out] coeffs The degrees of freedom to compute.
/// @param[in] bs The block size.
template <typename U, typename V, dolfinx::scalar T>
void interpolation_apply(const U& Pi, const V& data, std::span<T> coeffs,
int bs)
requires requires {
requires std::convertible_to<
U,
std::experimental::mdspan<const typename std::decay_t<U>::value_type,
std::experimental::dextents<std::size_t, 2>>>;
requires std::convertible_to<
V,
std::experimental::mdspan<const typename std::decay_t<V>::value_type,
std::experimental::dextents<std::size_t, 2>>>;
}
void interpolation_apply(U&& Pi, V&& data, std::span<T> coeffs, int bs)
{
using X = typename dolfinx::scalar_value_type_t<T>;
static_assert(U::rank() == 2, "Must be rank 2");
static_assert(V::rank() == 2, "Must be rank 2");

// Compute coefficients = Pi * x (matrix-vector multiply)
if (bs == 1)
Expand Down Expand Up @@ -392,10 +403,11 @@ void interpolate_same_map(Function<T, U>& u1, const Function<T, U>& u0,
}
}

/// Interpolate from one finite element Function to another on the same
/// mesh. The function is for cases where the finite element basis
/// functions for the two elements are mapped differently, e.g. one may
/// be Piola mapped and the other with a standard isoparametric map.
/// Interpolate from one finite element Function to another on the same mesh.
/// This interpolation function is for cases where the finite element basis
/// functions for the two elements are mapped differently, e.g. one may be
/// subject to a Piola mapping and the other to a standard isoparametric
/// mapping.
/// @param[out] u1 The function to interpolate to
/// @param[in] u0 The function to interpolate from
/// @param[in] cells The cells to interpolate on
Expand Down Expand Up @@ -461,60 +473,54 @@ void interpolate_nonmatching_maps(Function<T, U>& u1, const Function<T, U>& u0,
const std::size_t num_dofs_g = cmap.dim();
std::span<const U> x_g = mesh->geometry().x();

namespace stdex = std::experimental;
using cmdspan2_t = stdex::mdspan<const U, stdex::dextents<std::size_t, 2>>;
using cmdspan4_t = stdex::mdspan<const U, stdex::dextents<std::size_t, 4>>;
using mdspan2_t = stdex::mdspan<U, stdex::dextents<std::size_t, 2>>;
using mdspan3_t = stdex::mdspan<U, stdex::dextents<std::size_t, 3>>;
using mdspan3T_t = stdex::mdspan<T, stdex::dextents<std::size_t, 3>>;

// Evaluate coordinate map basis at reference interpolation points
const std::array<std::size_t, 4> phi_shape
= cmap.tabulate_shape(1, Xshape[0]);
std::vector<U> phi_b(
std::reduce(phi_shape.begin(), phi_shape.end(), 1, std::multiplies{}));
cmdspan4_t phi(phi_b.data(), phi_shape);
mdspan_t<const U, 4> phi(phi_b.data(), phi_shape);
cmap.tabulate(1, X, Xshape, phi_b);

// Evaluate v basis functions at reference interpolation points
const auto [_basis_derivatives_reference0, b0shape]
= element0->tabulate(X, Xshape, 0);
cmdspan4_t basis_derivatives_reference0(_basis_derivatives_reference0.data(),
b0shape);
mdspan_t<const U, 4> basis_derivatives_reference0(
_basis_derivatives_reference0.data(), b0shape);

// Create working arrays
std::vector<T> local1(element1->space_dimension());
std::vector<T> coeffs0(element0->space_dimension());

std::vector<U> basis0_b(Xshape[0] * dim0 * value_size0);
mdspan3_t basis0(basis0_b.data(), Xshape[0], dim0, value_size0);
impl::mdspan_t<U, 3> basis0(basis0_b.data(), Xshape[0], dim0, value_size0);

std::vector<U> basis_reference0_b(Xshape[0] * dim0 * value_size_ref0);
mdspan3_t basis_reference0(basis_reference0_b.data(), Xshape[0], dim0,
value_size_ref0);
impl::mdspan_t<U, 3> basis_reference0(basis_reference0_b.data(), Xshape[0],
dim0, value_size_ref0);

std::vector<T> values0_b(Xshape[0] * 1 * element1->value_size());
mdspan3T_t values0(values0_b.data(), Xshape[0], 1, element1->value_size());
impl::mdspan_t<T, 3> values0(values0_b.data(), Xshape[0], 1,
element1->value_size());

std::vector<T> mapped_values_b(Xshape[0] * 1 * element1->value_size());
mdspan3T_t mapped_values0(mapped_values_b.data(), Xshape[0], 1,
element1->value_size());
impl::mdspan_t<T, 3> mapped_values0(mapped_values_b.data(), Xshape[0], 1,
element1->value_size());

std::vector<U> coord_dofs_b(num_dofs_g * gdim);
mdspan2_t coord_dofs(coord_dofs_b.data(), num_dofs_g, gdim);
impl::mdspan_t<U, 2> coord_dofs(coord_dofs_b.data(), num_dofs_g, gdim);

std::vector<U> J_b(Xshape[0] * gdim * tdim);
mdspan3_t J(J_b.data(), Xshape[0], gdim, tdim);
impl::mdspan_t<U, 3> J(J_b.data(), Xshape[0], gdim, tdim);
std::vector<U> K_b(Xshape[0] * tdim * gdim);
mdspan3_t K(K_b.data(), Xshape[0], tdim, gdim);
impl::mdspan_t<U, 3> K(K_b.data(), Xshape[0], tdim, gdim);
std::vector<U> detJ(Xshape[0]);
std::vector<U> det_scratch(2 * gdim * tdim);

// Get interpolation operator
const auto [_Pi_1, pi_shape] = element1->interpolation_operator();
cmdspan2_t Pi_1(_Pi_1.data(), pi_shape);
impl::mdspan_t<const U, 2> Pi_1(_Pi_1.data(), pi_shape);

using u_t = stdex::mdspan<U, stdex::dextents<std::size_t, 2>>;
using u_t = impl::mdspan_t<U, 2>;
using U_t = stdex::mdspan<const U, stdex::dextents<std::size_t, 2>>;
using J_t = stdex::mdspan<const U, stdex::dextents<std::size_t, 2>>;
using K_t = stdex::mdspan<const U, stdex::dextents<std::size_t, 2>>;
Expand Down Expand Up @@ -638,67 +644,50 @@ void interpolate_nonmatching_meshes(
std::vector<U>, std::vector<std::int32_t>>&
nmm_interpolation_data)
{
int result;
namespace stdex = std::experimental;

auto mesh = u.function_space()->mesh();
auto mesh_v = v.function_space()->mesh();
MPI_Comm_compare(mesh->comm(), mesh_v->comm(), &result);
assert(mesh);
MPI_Comm comm = mesh->comm();

if (result == MPI_UNEQUAL)
throw std::runtime_error("Interpolation on different meshes is only "
"supported with the same communicator.");
{
auto mesh_v = v.function_space()->mesh();
assert(mesh_v);
int result;
MPI_Comm_compare(comm, mesh_v->comm(), &result);
if (result == MPI_UNEQUAL)
{
throw std::runtime_error("Interpolation on different meshes is only "
"supported with the same communicator.");
}
}

MPI_Comm comm = mesh->comm();
const int tdim = mesh->topology()->dim();
auto cell_map = mesh->topology()->index_map(tdim);
assert(mesh->topology());
auto cell_map = mesh->topology()->index_map(mesh->topology()->dim());
assert(cell_map);

auto element_u = u.function_space()->element();
assert(element_u);
const std::size_t value_size = element_u->value_size();

if (std::get<0>(nmm_interpolation_data).empty())
{
throw std::runtime_error(
"In order to interpolate on nonmatching meshes, the user needs to "
"provide the necessary interpolation data. This can be computed "
"with fem::create_nonmatching_meshes_interpolation_data.");
}

const std::tuple_element_t<
0, std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<U>, std::vector<std::int32_t>>>& dest_ranks
= std::get<0>(nmm_interpolation_data);
const std::tuple_element_t<
1, std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<U>, std::vector<std::int32_t>>>& src_ranks
= std::get<1>(nmm_interpolation_data);
const std::tuple_element_t<
2, std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<U>, std::vector<std::int32_t>>>& received_points
= std::get<2>(nmm_interpolation_data);
const std::tuple_element_t<
3, std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<U>, std::vector<std::int32_t>>>&
evaluation_cells
= std::get<3>(nmm_interpolation_data);
const auto& [dest_ranks, src_ranks, recv_points, evaluation_cells]
= nmm_interpolation_data;

// Evaluate the interpolating function where possible
std::vector<T> send_values(received_points.size() / 3 * value_size);
v.eval(received_points, {received_points.size() / 3, (std::size_t)3},
evaluation_cells, send_values,
{received_points.size() / 3, (std::size_t)value_size});
std::vector<T> send_values(recv_points.size() / 3 * value_size);
v.eval(recv_points, {recv_points.size() / 3, (std::size_t)3},
evaluation_cells, send_values, {recv_points.size() / 3, value_size});

// Send values back to owning process
std::array<std::size_t, 2> v_shape = {src_ranks.size(), value_size};
std::vector<T> values_b(dest_ranks.size() * value_size);
impl::scatter_values(comm, src_ranks, dest_ranks,
std::span<const T>(send_values), v_shape,
std::span<T>(values_b));
stdex::mdspan<const T, stdex::dextents<std::size_t, 2>> _send_values(
send_values.data(), src_ranks.size(), value_size);
impl::scatter_values(comm, src_ranks, dest_ranks, _send_values,
std::span(values_b));

// Transpose received data
namespace stdex = std::experimental;
stdex::mdspan<const T, stdex::dextents<std::size_t, 2>> values(
values_b.data(), dest_ranks.size(), value_size);

std::vector<T> valuesT_b(value_size * dest_ranks.size());
stdex::mdspan<T, stdex::dextents<std::size_t, 2>> valuesT(
valuesT_b.data(), value_size, dest_ranks.size());
Expand Down
58 changes: 2 additions & 56 deletions python/test/unit/fem/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
import ufl
from basix.ufl import element, mixed_element
from dolfinx.fem import (Function, FunctionSpace, TensorFunctionSpace,
VectorFunctionSpace, assemble_scalar,
create_nonmatching_meshes_interpolation_data, form)
VectorFunctionSpace)
from dolfinx.geometry import (bb_tree, compute_colliding_cells,
compute_collisions_points)
from dolfinx.mesh import (CellType, create_mesh, create_unit_cube,
create_unit_square, locate_entities_boundary,
meshtags)
from dolfinx.mesh import create_mesh, create_unit_cube
from mpi4py import MPI

from dolfinx import default_real_type, la
Expand Down Expand Up @@ -174,57 +171,6 @@ def f(x):
assert round(w.x.norm(la.Norm.l1) - 6 * num_vertices, 7) == 0


@pytest.mark.parametrize("xtype", [np.float64])
@pytest.mark.parametrize("cell_type0", [CellType.hexahedron, CellType.tetrahedron])
@pytest.mark.parametrize("cell_type1", [CellType.triangle, CellType.quadrilateral])
def test_nonmatching_interpolation(xtype, cell_type0, cell_type1):
mesh0 = create_unit_cube(MPI.COMM_WORLD, 5, 6, 7, cell_type=cell_type0, dtype=xtype)
mesh1 = create_unit_square(MPI.COMM_WORLD, 25, 24, cell_type=cell_type1, dtype=xtype)

def f(x):
return (7 * x[1], 3 * x[0], x[2] + 0.4)

el0 = element("Lagrange", mesh0.basix_cell(), 1, shape=(3, ))
V0 = FunctionSpace(mesh0, el0)
el1 = element("Lagrange", mesh1.basix_cell(), 1, shape=(3, ))
V1 = FunctionSpace(mesh1, el1)

# Interpolate on 3D mesh
u0 = Function(V0, dtype=xtype)
u0.interpolate(f)
u0.x.scatter_forward()

# Interpolate 3D->2D
u1 = Function(V1, dtype=xtype)
u1.interpolate(u0, nmm_interpolation_data=create_nonmatching_meshes_interpolation_data(
u1.function_space.mesh._cpp_object,
u1.function_space.element,
u0.function_space.mesh._cpp_object))
u1.x.scatter_forward()

# Exact interpolation on 2D mesh
u1_ex = Function(V1, dtype=xtype)
u1_ex.interpolate(f)
u1_ex.x.scatter_forward()

assert np.allclose(u1_ex.x.array, u1.x.array, rtol=1.0e-4, atol=1.0e-6)

# Interpolate 2D->3D
u0_2 = Function(V0, dtype=xtype)
u0_2.interpolate(u1, nmm_interpolation_data=create_nonmatching_meshes_interpolation_data(
u0_2.function_space.mesh._cpp_object,
u0_2.function_space.element,
u1.function_space.mesh._cpp_object))

# Check that function values over facets of 3D mesh of the twice interpolated property is preserved
def locate_bottom_facets(x):
return np.isclose(x[2], 0)
facets = locate_entities_boundary(mesh0, mesh0.topology.dim - 1, locate_bottom_facets)
facet_tag = meshtags(mesh0, mesh0.topology.dim - 1, facets, np.full(len(facets), 1, dtype=np.int32))
residual = ufl.inner(u0 - u0_2, u0 - u0_2) * ufl.ds(domain=mesh0, subdomain_data=facet_tag, subdomain_id=1)
assert np.isclose(assemble_scalar(form(residual, dtype=xtype)), 0)


@pytest.mark.parametrize("types", [
# (np.float32, "float"), # Fails on Redhat CI, needs further investigation
(np.float64, "double")
Expand Down
Loading

0 comments on commit 200e94c

Please sign in to comment.