Skip to content

Commit

Permalink
Wrap create_nonmatching_meshes_interpolation_data in Python (#3039)
Browse files Browse the repository at this point in the history
* wrap create_nonmatching_meshes_interpolation_data

* missing imports

* _

* fix

* typo

* fix typing

* Geometry typing

* ]

* return

* padding=

* tidy

* typing and docstring

* Use const spans in C++ for interpolation data.
Add name-tuple documenting returntype of create_nonmatching_meshes_interpolation_data.
Use views into nb-array to pass to C++.
Return ndarrays from C++

* Fix documentation of interpolate + Python interface for normal interpolation

* Wrap empty constructor as namedtuple

* Fix capitalization and docstring

* Fix import order

* Sort import

---------

Co-authored-by: Jørgen Schartum Dokken <[email protected]>
Co-authored-by: Jorgen S. Dokken <[email protected]>
  • Loading branch information
3 people authored Feb 20, 2024
1 parent e3fb2a1 commit 641e839
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 38 deletions.
14 changes: 8 additions & 6 deletions cpp/dolfinx/fem/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,10 @@ class Function
void interpolate(
const Function<value_type, geometry_type>& v,
std::span<const std::int32_t> cells,
const std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<geometry_type>, std::vector<std::int32_t>>&
nmm_interpolation_data
const std::tuple<std::span<const std::int32_t>,
std::span<const std::int32_t>,
std::span<const geometry_type>,
std::span<const std::int32_t>>& nmm_interpolation_data
= {})
{
fem::interpolate(*this, v, cells, nmm_interpolation_data);
Expand All @@ -174,9 +175,10 @@ class Function
/// generate_nonmatching_meshes_interpolation_data (optional).
void interpolate(
const Function<value_type, geometry_type>& v,
const std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<geometry_type>, std::vector<std::int32_t>>&
nmm_interpolation_data
const std::tuple<std::span<const std::int32_t>,
std::span<const std::int32_t>,
std::span<const geometry_type>,
std::span<const std::int32_t>>& nmm_interpolation_data
= {})
{
assert(_function_space);
Expand Down
12 changes: 6 additions & 6 deletions cpp/dolfinx/fem/interpolate.h
Original file line number Diff line number Diff line change
Expand Up @@ -671,9 +671,9 @@ template <dolfinx::scalar T, std::floating_point U>
void interpolate_nonmatching_meshes(
Function<T, U>& u, const Function<T, U>& v,
std::span<const std::int32_t> cells,
const std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<U>, std::vector<std::int32_t>>&
nmm_interpolation_data)
const std::tuple<std::span<const std::int32_t>,
std::span<const std::int32_t>, std::span<const U>,
std::span<const std::int32_t>>& nmm_interpolation_data)
{
auto mesh = u.function_space()->mesh();
assert(mesh);
Expand Down Expand Up @@ -1127,9 +1127,9 @@ template <dolfinx::scalar T, std::floating_point U>
void interpolate(
Function<T, U>& u, const Function<T, U>& v,
std::span<const std::int32_t> cells,
const std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<U>, std::vector<std::int32_t>>&
nmm_interpolation_data
const std::tuple<std::span<const std::int32_t>,
std::span<const std::int32_t>, std::span<const U>,
std::span<const std::int32_t>>& nmm_interpolation_data
= {})
{
assert(u.function_space());
Expand Down
49 changes: 46 additions & 3 deletions python/dolfinx/fem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Tools for assembling and manipulating finite element forms."""

import typing

import numpy as np
import numpy.typing as npt

from dolfinx.cpp.fem import FiniteElement_float32 as _FiniteElement_float32
from dolfinx.cpp.fem import FiniteElement_float64 as _FiniteElement_float64
from dolfinx.cpp.fem import IntegralType, transpose_dofmap
from dolfinx.cpp.fem import (
IntegralType,
create_nonmatching_meshes_interpolation_data,
transpose_dofmap,
create_nonmatching_meshes_interpolation_data as _create_nonmatching_meshes_interpolation_data,
)
from dolfinx.cpp.fem import create_sparsity_pattern as _create_sparsity_pattern
from dolfinx.cpp.fem import discrete_gradient as _discrete_gradient
from dolfinx.cpp.mesh import Geometry_float32 as _Geometry_float32
from dolfinx.cpp.mesh import Geometry_float64 as _Geometry_float64
from dolfinx.fem.assemble import (
apply_lifting,
assemble_matrix,
Expand All @@ -37,9 +45,11 @@
Expression,
Function,
FunctionSpace,
PointOwnershipData,
functionspace,
)
from dolfinx.la import MatrixCSR as _MatrixCSR
from dolfinx.mesh import Mesh as _Mesh


def create_sparsity_pattern(a: Form):
Expand All @@ -59,6 +69,39 @@ def create_sparsity_pattern(a: Form):
return _create_sparsity_pattern(a._cpp_object)


def create_nonmatching_meshes_interpolation_data(
mesh_to: typing.Union[_Mesh, _Geometry_float64, _Geometry_float32],
element: typing.Union[_FiniteElement_float32, _FiniteElement_float64],
mesh_from: _Mesh,
cells: typing.Optional[npt.NDArray[np.int32]] = None,
padding: float = 1e-14,
) -> PointOwnershipData:
"""Generate data needed to interpolate discrete functions across different meshes.
Args:
mesh_to: Mesh or geometry of the mesh of the function space to interpolate into
element: Element of the function space to interpolate into
mesh_from: Mesh that the function to interpolate from is defined on
cells: Indices of the cells in the destination mesh on which to interpolate.
padding: Absolute padding of bounding boxes of all entities on mesh_to
Returns:
Data needed to interpolation functions defined on function spaces on the meshes.
"""
if cells is None:
return PointOwnershipData(
*_create_nonmatching_meshes_interpolation_data(
mesh_to._cpp_object, element, mesh_from._cpp_object, padding
)
)
else:
return PointOwnershipData(
*_create_nonmatching_meshes_interpolation_data(
mesh_to, element, mesh_from._cpp_object, cells, padding
)
)


def discrete_gradient(space0: FunctionSpace, space1: FunctionSpace) -> _MatrixCSR:
"""Assemble a discrete gradient operator.
Expand Down
27 changes: 26 additions & 1 deletion python/dolfinx/fem/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@
from dolfinx.mesh import Mesh


class PointOwnershipData(typing.NamedTuple):
"""Convenience class for storing data related to the ownership of points.
Attributes:
src_owner: Ranks owning each point sent into ownership determination for current process
dest_owners: Ranks that sent `dest_points` to current process
dest_points: Points owned by current rank
dest_cells: Cell indices (local to process) where each entry of `dest_points` is located
"""

src_owner: npt.NDArray[np.int32]
dest_owners: npt.NDArray[np.int32]
dest_points: npt.NDArray[np.floating]
dest_cells: npt.NDArray[np.int32]


class Constant(ufl.Constant):
_cpp_object: typing.Union[
_cpp.fem.Constant_complex64,
Expand Down Expand Up @@ -388,15 +404,24 @@ def interpolate(
self,
u: typing.Union[typing.Callable, Expression, Function],
cells: typing.Optional[np.ndarray] = None,
nmm_interpolation_data=((), (), (), ()),
nmm_interpolation_data: typing.Optional[PointOwnershipData] = None,
) -> None:
"""Interpolate an expression
Args:
u: The function, Expression or Function to interpolate.
cells: The cells to interpolate over. If `None` then all
cells are interpolated over.
nmm_interpolation_data: Data needed to interpolate functions defined on other meshes
"""
if nmm_interpolation_data is None:
x_dtype = self.function_space.mesh.geometry.x.dtype
nmm_interpolation_data = PointOwnershipData(
src_owner=np.empty(0, dtype=np.int32),
dest_owners=np.empty(0, dtype=np.int32),
dest_points=np.empty(0, dtype=x_dtype),
dest_cells=np.empty(0, dtype=np.int32),
)

@singledispatch
def _interpolate(u, cells: typing.Optional[np.ndarray] = None):
Expand Down
48 changes: 38 additions & 10 deletions python/dolfinx/wrappers/fem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,30 @@ void declare_objects(nb::module_& m, const std::string& type)
[](dolfinx::fem::Function<T, U>& self,
dolfinx::fem::Function<T, U>& u,
nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig> cells,
const std::tuple<std::vector<std::int32_t>,
std::vector<std::int32_t>, std::vector<U>,
std::vector<std::int32_t>>& interpolation_data)
const std::tuple<
nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig>,
nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig>,
nb::ndarray<const U, nb::ndim<1>, nb::c_contig>,
nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig>>&
interpolation_data)
{
std::tuple<std::span<const std::int32_t>,
std::span<const std::int32_t>, std::span<const U>,
std::span<const std::int32_t>>
_interpolation_data(
std::span<const std::int32_t>(
std::get<0>(interpolation_data).data(),
std::get<0>(interpolation_data).size()),
std::span<const std::int32_t>(
std::get<1>(interpolation_data).data(),
std::get<1>(interpolation_data).size()),
std::span<const U>(std::get<2>(interpolation_data).data(),
std::get<2>(interpolation_data).size()),
std::span<const std::int32_t>(
std::get<3>(interpolation_data).data(),
std::get<3>(interpolation_data).size()));
self.interpolate(u, std::span(cells.data(), cells.size()),
interpolation_data);
_interpolation_data);
},
nb::arg("u"), nb::arg("cells"), nb::arg("nmm_interpolation_data"),
"Interpolate a finite element function")
Expand Down Expand Up @@ -925,9 +943,14 @@ void declare_real_functions(nb::module_& m)
= cell_map->size_local() + cell_map->num_ghosts();
std::vector<std::int32_t> cells(num_cells, 0);
std::iota(cells.begin(), cells.end(), 0);
return dolfinx::fem::create_nonmatching_meshes_interpolation_data(
mesh0.geometry(), element0, mesh1,
std::span(cells.data(), cells.size()), padding);
auto [src_owner, dest_owner, dest_points, dest_cells]
= dolfinx::fem::create_nonmatching_meshes_interpolation_data(
mesh0.geometry(), element0, mesh1,
std::span(cells.data(), cells.size()), padding);
return std::tuple(dolfinx_wrappers::as_nbarray(std::move(src_owner)),
dolfinx_wrappers::as_nbarray(std::move(dest_owner)),
dolfinx_wrappers::as_nbarray(std::move(dest_points)),
dolfinx_wrappers::as_nbarray(std::move(dest_cells)));
},
nb::arg("mesh0"), nb::arg("element0"), nb::arg("mesh1"),
nb::arg("padding"));
Expand All @@ -939,9 +962,14 @@ void declare_real_functions(nb::module_& m)
nb::ndarray<const std::int32_t, nb::ndim<1>, nb::c_contig> cells,
T padding)
{
return dolfinx::fem::create_nonmatching_meshes_interpolation_data(
geometry0, element0, mesh1, std::span(cells.data(), cells.size()),
padding);
auto [src_owner, dest_owner, dest_points, dest_cells]
= dolfinx::fem::create_nonmatching_meshes_interpolation_data(
geometry0, element0, mesh1,
std::span(cells.data(), cells.size()), padding);
return std::tuple(dolfinx_wrappers::as_nbarray(std::move(src_owner)),
dolfinx_wrappers::as_nbarray(std::move(dest_owner)),
dolfinx_wrappers::as_nbarray(std::move(dest_points)),
dolfinx_wrappers::as_nbarray(std::move(dest_cells)));
},
nb::arg("geometry0"), nb::arg("element0"), nb::arg("mesh1"),
nb::arg("cells"), nb ::arg("padding"));
Expand Down
18 changes: 6 additions & 12 deletions python/test/unit/fem/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,9 +887,9 @@ def f(x):
u1.interpolate(
u0,
nmm_interpolation_data=create_nonmatching_meshes_interpolation_data(
u1.function_space.mesh._cpp_object,
u1.function_space.mesh,
u1.function_space.element,
u0.function_space.mesh._cpp_object,
u0.function_space.mesh,
padding=padding,
),
)
Expand All @@ -907,9 +907,9 @@ def f(x):
u0_2.interpolate(
u1,
nmm_interpolation_data=create_nonmatching_meshes_interpolation_data(
u0_2.function_space.mesh._cpp_object,
u0_2.function_space.mesh,
u0_2.function_space.element,
u1.function_space.mesh._cpp_object,
u1.function_space.mesh,
padding=padding,
),
)
Expand Down Expand Up @@ -964,10 +964,7 @@ def f_test1(x):
u1.x.scatter_forward()
padding = 1e-14
u1_2_u2_nmm_data = create_nonmatching_meshes_interpolation_data(
u2.function_space.mesh._cpp_object,
u2.function_space.element,
u1.function_space.mesh._cpp_object,
padding=padding,
u2.function_space.mesh, u2.function_space.element, u1.function_space.mesh, padding=padding
)

u2.interpolate(u1, nmm_interpolation_data=u1_2_u2_nmm_data)
Expand All @@ -992,10 +989,7 @@ def f_test2(x):
u2.x.scatter_forward()
padding = 1e-14
u2_2_u1_nmm_data = create_nonmatching_meshes_interpolation_data(
u1.function_space.mesh._cpp_object,
u1.function_space.element,
u2.function_space.mesh._cpp_object,
padding,
u1.function_space.mesh, u1.function_space.element, u2.function_space.mesh, padding=padding
)

u1.interpolate(u2, nmm_interpolation_data=u2_2_u1_nmm_data)
Expand Down

0 comments on commit 641e839

Please sign in to comment.