Skip to content

Commit

Permalink
Data independent form compilation for Python interface (#3263)
Browse files Browse the repository at this point in the history
* Create mesh independent form compilator for Python interface

* Ruff formatting

* Mypy

* Parameterize test over all dtypes

* Ruff format

* Fix compile options handling

* Fix coeff and constant name mapping

* Ruff

* Remove extra definition of basix element and check that mesh c_el is consistent

* Add support for linear and bilinear forms

* Fix typo

* Apply suggestions from code review

Co-authored-by: Chris Richardson <[email protected]>

* Apply suggestions from code review

Co-authored-by: Chris Richardson <[email protected]>

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Igor Baratta <[email protected]>

* Apply suggestions from code review

* Add documentation and default scalar type

---------

Co-authored-by: Chris Richardson <[email protected]>
Co-authored-by: Igor Baratta <[email protected]>
  • Loading branch information
3 people authored Jun 25, 2024
1 parent f72fcb9 commit 17c71a3
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 28 deletions.
11 changes: 10 additions & 1 deletion python/dolfinx/fem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@
)
from dolfinx.fem.dofmap import DofMap
from dolfinx.fem.element import CoordinateElement, coordinate_element
from dolfinx.fem.forms import Form, extract_function_spaces, form, form_cpp_class
from dolfinx.fem.forms import (
Form,
compile_form,
create_form,
extract_function_spaces,
form,
form_cpp_class,
)
from dolfinx.fem.function import (
Constant,
ElementMetaData,
Expand Down Expand Up @@ -148,4 +155,6 @@ def interpolation_matrix(space0: FunctionSpace, space1: FunctionSpace) -> _Matri
"CoordinateElement",
"coordinate_element",
"form_cpp_class",
"create_form",
"compile_form",
]
222 changes: 195 additions & 27 deletions python/dolfinx/fem/forms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2017-2023 Chris N. Richardson, Garth N. Wells and Michal Habera
# Copyright (C) 2017-2024 Chris N. Richardson, Garth N. Wells, Michal Habera and Jørgen S. Dokken
#
# This file is part of DOLFINx (https://www.fenicsproject.org)
#
Expand All @@ -8,20 +8,24 @@

import collections
import typing
from dataclasses import dataclass
from itertools import chain

import numpy as np
import numpy.typing as npt

import ffcx
import ufl
from dolfinx import cpp as _cpp
from dolfinx import default_scalar_type, jit
from dolfinx.fem import IntegralType
from dolfinx.fem.function import FunctionSpace

if typing.TYPE_CHECKING:
from mpi4py import MPI

from dolfinx.fem import function
from dolfinx.mesh import Mesh
from dolfinx.mesh import Mesh, MeshTags


class Form:
Expand Down Expand Up @@ -96,6 +100,54 @@ def integral_types(self):
return self._cpp_object.integral_types


def get_integration_domains(
integral_type: IntegralType,
subdomain: typing.Optional[typing.Union[MeshTags, list[tuple[int, np.ndarray]]]],
subdomain_ids: list[int],
) -> list[tuple[int, np.ndarray]]:
"""Get integration domains from subdomain data.
The subdomain data is a meshtags object consisting of markers, or a None object.
If it is a None object we do not pack any integration entities.
Integration domains are defined as a list of tuples, where each input `subdomain_ids`
is mapped to an array of integration entities, where an integration entity for a cell
integral is the list of cells.
For an exterior facet integral each integration entity is a
tuple (cell_index, local_facet_index).
For an interior facet integral each integration entity is a
uple (cell_index0, local_facet_index0,
cell_index1, local_facet_index1). Where the first cell-facet pair is
the '+' restriction, the second the '-' restriction.
Args:
integral_type: The type of integral to pack integration entitites for
subdomain: A meshtag with markers or manually specified integration domains.
subdomain_ids: List of ids to integrate over
"""
if subdomain is None:
return []
else:
domains = []
try:
if integral_type in (IntegralType.exterior_facet, IntegralType.interior_facet):
tdim = subdomain.topology.dim # type: ignore
subdomain._cpp_object.topology.create_connectivity(tdim - 1, tdim) # type: ignore
subdomain._cpp_object.topology.create_connectivity(tdim, tdim - 1) # type: ignore
# Compute integration domains only for each subdomain id in the integrals
# If a process has no integral entities, insert an empty array
for id in subdomain_ids:
integration_entities = _cpp.fem.compute_integration_domains(
integral_type,
subdomain._cpp_object.topology, # type: ignore
subdomain.find(id), # type: ignore
subdomain.dim, # type: ignore
)
domains.append((id, integration_entities))
return [(s[0], np.array(s[1])) for s in domains]
except AttributeError:
return [(s[0], np.array(s[1])) for s in subdomain] # type: ignore


def form_cpp_class(
dtype: npt.DTypeLike,
) -> typing.Union[
Expand Down Expand Up @@ -222,31 +274,6 @@ def _form(form):
flattened_ids.sort()
subdomain_ids[itg_type] = flattened_ids

def get_integration_domains(integral_type, subdomain, subdomain_ids):
"""Get integration domains from subdomain data"""
if subdomain is None:
return []
else:
domains = []
try:
if integral_type in (IntegralType.exterior_facet, IntegralType.interior_facet):
tdim = subdomain.topology.dim
subdomain._cpp_object.topology.create_connectivity(tdim - 1, tdim)
subdomain._cpp_object.topology.create_connectivity(tdim, tdim - 1)
# Compute integration domains only for each subdomain id in the integrals
# If a process has no integral entities, insert an empty array
for id in subdomain_ids:
integration_entities = _cpp.fem.compute_integration_domains(
integral_type,
subdomain._cpp_object.topology,
subdomain.find(id),
subdomain.dim,
)
domains.append((id, integration_entities))
return [(s[0], np.array(s[1])) for s in domains]
except AttributeError:
return [(s[0], np.array(s[1])) for s in subdomain]

# Subdomain markers (possibly empty list for some integral types)
subdomains = {
_ufl_to_dolfinx_domain[key]: get_integration_domains(
Expand Down Expand Up @@ -333,3 +360,144 @@ def unique_spaces(V):
return list(unique_spaces(V.transpose()))

raise RuntimeError("Unsupported array of forms")


@dataclass
class CompiledForm:
"""
Compiled UFL form without associated DOLFINx data
"""

ufl_form: ufl.Form # The original ufl form
ufcx_form: typing.Any # The compiled form
module: typing.Any # The module
code: str # The source code
dtype: npt.DTypeLike # data type used for the `ufcx_form`


def compile_form(
comm: MPI.Intracomm,
form: ufl.Form,
form_compiler_options: typing.Optional[dict] = {"scalar_type": default_scalar_type},
jit_options: typing.Optional[dict] = None,
) -> CompiledForm:
"""Compile UFL form without associated DOLFINx data
Args:
comm: The MPI communicator used when compiling the form
form: The UFL form to compile
form_compiler_options: See :func:`ffcx_jit <dolfinx.jit.ffcx_jit>`
jit_options: See :func:`ffcx_jit <dolfinx.jit.ffcx_jit>`.
"""
p_ffcx = ffcx.get_options(form_compiler_options)
p_jit = jit.get_options(jit_options)
ufcx_form, module, code = jit.ffcx_jit(comm, form, p_ffcx, p_jit)
return CompiledForm(form, ufcx_form, module, code, p_ffcx["scalar_type"])


def form_cpp_creator(
dtype: npt.DTypeLike,
) -> typing.Union[
_cpp.fem.Form_float32,
_cpp.fem.Form_float64,
_cpp.fem.Form_complex64,
_cpp.fem.Form_complex128,
]:
"""Return the wrapped C++ constructor for creating a variational form of a specific scalar type.
Args:
dtype: Scalar type of the required form class.
Returns:
Wrapped C++ form class of the requested type.
Note:
This function is for advanced usage, typically when writing
custom kernels using Numba or C.
"""
if np.issubdtype(dtype, np.float32):
return _cpp.fem.create_form_float32
elif np.issubdtype(dtype, np.float64):
return _cpp.fem.create_form_float64
elif np.issubdtype(dtype, np.complex64):
return _cpp.fem.create_form_complex64
elif np.issubdtype(dtype, np.complex128):
return _cpp.fem.create_form_complex128
else:
raise NotImplementedError(f"Type {dtype} not supported.")


def create_form(
form: CompiledForm,
function_spaces: list[function.FunctionSpace],
mesh: Mesh,
coefficient_map: dict[ufl.Function, function.Function],
constant_map: dict[ufl.Constant, function.Constant],
) -> Form:
"""
Create a Form object from a data-independent compiled form
Args:
form: Compiled ufl form
function_spaces: List of function spaces associated with the form.
Should match the number of arguments in the form.
mesh: Mesh to associate form with
coefficient_map: Map from UFL coefficient to function with data
constant_map: Map from UFL constant to constant with data
"""
sd = form.ufl_form.subdomain_data()
(domain,) = list(sd.keys()) # Assuming single domain

# Make map from integral_type to subdomain id
subdomain_ids: dict[IntegralType, list[list[int]]] = {
type: [] for type in sd.get(domain).keys()
}
flattened_subdomain_ids: dict[IntegralType, list[int]] = {
type: [] for type in sd.get(domain).keys()
}
for integral in form.ufl_form.integrals():
if integral.subdomain_data() is not None:
# Subdomain ids can be strings, its or tuples with strings and ints
if integral.subdomain_id() != "everywhere":
try:
ids = [sid for sid in integral.subdomain_id() if sid != "everywhere"]
except TypeError:
# If not tuple, but single integer id
ids = [integral.subdomain_id()]
else:
ids = []
subdomain_ids[integral.integral_type()].append(ids) # type: ignore

# Chain and sort subdomain ids
for itg_type, marker_ids in subdomain_ids.items():
flattened_ids: list[int] = list(chain.from_iterable(marker_ids))
flattened_ids.sort()
flattened_subdomain_ids[itg_type] = flattened_ids

# Subdomain markers (possibly empty list for some integral types)
subdomains = {
_ufl_to_dolfinx_domain[key]: get_integration_domains(
_ufl_to_dolfinx_domain[key], subdomain_data[0], flattened_subdomain_ids[key]
)
for (key, subdomain_data) in sd.get(domain).items()
}

# Extract name of ufl objects and map them to their corresponding C++ object
ufl_coefficients = ufl.algorithms.extract_coefficients(form.ufl_form)
coefficients = {
f"w{ufl_coefficients.index(u)}": uh._cpp_object for (u, uh) in coefficient_map.items()
}
ufl_constants = ufl.algorithms.analysis.extract_constants(form.ufl_form)
constants = {f"c{ufl_constants.index(u)}": uh._cpp_object for (u, uh) in constant_map.items()}

ftype = form_cpp_creator(form.dtype)
f = ftype(
form.module.ffi.cast("uintptr_t", form.module.ffi.addressof(form.ufcx_form)),
[fs._cpp_object for fs in function_spaces],
coefficients,
constants,
subdomains,
mesh._cpp_object,
)
return Form(f, form.ufcx_form, form.code)
36 changes: 36 additions & 0 deletions python/dolfinx/wrappers/fem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,42 @@ void declare_form(nb::module_& m, std::string type)
nb::arg("form"), nb::arg("spaces"), nb::arg("coefficients"),
nb::arg("constants"), nb::arg("subdomains"), nb::arg("mesh"),
"Create Form from a pointer to ufcx_form.");
m.def(
pymethod_create_form.c_str(),
[](std::uintptr_t form,
const std::vector<
std::shared_ptr<const dolfinx::fem::FunctionSpace<U>>>& spaces,
const std::map<std::string,
std::shared_ptr<const dolfinx::fem::Function<T, U>>>&
coefficients,
const std::map<std::string,
std::shared_ptr<const dolfinx::fem::Constant<T>>>&
constants,
const std::map<dolfinx::fem::IntegralType,
std::vector<std::pair<std::int32_t,
std::span<const std::int32_t>>>>&
subdomains,
std::shared_ptr<const dolfinx::mesh::Mesh<U>> mesh = nullptr)
{
std::map<
dolfinx::fem::IntegralType,
std::vector<std::pair<std::int32_t, std::span<const std::int32_t>>>>
sd;
for (auto& [itg, data] : subdomains)
{
std::vector<std::pair<std::int32_t, std::span<const std::int32_t>>> x;
for (auto& [id, idx] : data)
x.emplace_back(id, std::span(idx.data(), idx.size()));
sd.insert({itg, std::move(x)});
}

ufcx_form* p = reinterpret_cast<ufcx_form*>(form);
return dolfinx::fem::create_form<T, U>(*p, spaces, coefficients,
constants, sd, mesh);
},
nb::arg("form"), nb::arg("spaces"), nb::arg("coefficients"),
nb::arg("constants"), nb::arg("subdomains"), nb::arg("mesh"),
"Create Form from a pointer to ufcx_form.");

m.def("create_sparsity_pattern",
&dolfinx::fem ::create_sparsity_pattern<T, U>, nb::arg("a"),
Expand Down
56 changes: 56 additions & 0 deletions python/test/unit/fem/test_assemble_mesh_independent_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from mpi4py import MPI

import numpy as np
import pytest

import basix.ufl
import dolfinx
import ufl


@pytest.mark.parametrize(
"dtype",
[
np.float32,
np.float64,
pytest.param(np.complex64, marks=pytest.mark.xfail_win32_complex),
pytest.param(np.complex128, marks=pytest.mark.xfail_win32_complex),
],
)
def test_compiled_form(dtype):
"""
Compile a form without an associated mesh and assemble a form over a sequence of meshes
"""
real_type = dtype(0).real.dtype
c_el = basix.ufl.element("Lagrange", "triangle", 1, shape=(2,), dtype=real_type)
domain = ufl.Mesh(c_el)
el = basix.ufl.element("Lagrange", "triangle", 2, dtype=real_type)
V = ufl.FunctionSpace(domain, el)
u = ufl.Coefficient(V)
w = ufl.Coefficient(V)
c = ufl.Constant(domain)
e = ufl.Constant(domain)
J = c * e * u * w * ufl.dx(domain=domain)

# Compile form using dolfinx.jit.ffcx_jit
compiled_form = dolfinx.fem.compile_form(
MPI.COMM_WORLD, J, form_compiler_options={"scalar_type": dtype}
)

def create_and_integrate(N, compiled_form):
mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, N, N, dtype=real_type)
assert mesh.ufl_domain().ufl_coordinate_element() == c_el
Vh = dolfinx.fem.functionspace(mesh, u.ufl_element())
uh = dolfinx.fem.Function(Vh, dtype=dtype)
uh.interpolate(lambda x: x[0])
wh = dolfinx.fem.Function(Vh, dtype=dtype)
wh.interpolate(lambda x: x[1])
eh = dolfinx.fem.Constant(mesh, dtype(3.0))
ch = dolfinx.fem.Constant(mesh, dtype(2.0))
form = dolfinx.fem.create_form(compiled_form, [], mesh, {u: uh, w: wh}, {c: ch, e: eh})
assert np.isclose(mesh.comm.allreduce(dolfinx.fem.assemble_scalar(form), op=MPI.SUM), 1.5)

# Create various meshes, that all uses this compiled form with a map from ufl
# to dolfinx functions and constants
for i in range(1, 4):
create_and_integrate(i, compiled_form)

0 comments on commit 17c71a3

Please sign in to comment.