
Commit

Merge branch 'master' into pbrubeck/value_shape
pbrubeck committed Nov 13, 2024
2 parents f91615c + a59b15f commit 4aa725f
Showing 6 changed files with 140 additions and 97 deletions.
157 changes: 81 additions & 76 deletions firedrake/assemble.py
@@ -29,6 +29,7 @@
from firedrake.utils import ScalarType, assert_empty, tuplify
from pyop2 import op2
from pyop2.exceptions import MapValueError, SparsityFormatError
from pyop2.types.mat import _GlobalMatPayload, _DatMatPayload
from pyop2.utils import cached_property


@@ -965,22 +966,24 @@ def assemble(self, tensor=None):
Result of assembly: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms.
"""
self._check_tensor(tensor)
if tensor is None:
tensor = self.allocate()
needs_zeroing = False
else:
needs_zeroing = self._needs_zeroing
if annotate_tape():
raise NotImplementedError(
"Taping with explicit FormAssembler objects is not supported yet. "
"Use assemble instead."
)
if needs_zeroing:
type(self)._as_pyop2_type(tensor).zero()

if tensor is None:
tensor = self.allocate()
else:
self._check_tensor(tensor)
if self._needs_zeroing:
self._as_pyop2_type(tensor).zero()

self.execute_parloops(tensor)

for bc in self._bcs:
self._apply_bc(tensor, bc)

return self.result(tensor)

@abc.abstractmethod
@@ -992,9 +995,9 @@ def _check_tensor(self, tensor):
"""Check input tensor."""

@staticmethod
def _as_pyop2_type(tensor):
"""Return tensor as pyop2 type."""
raise NotImplementedError
@abc.abstractmethod
def _as_pyop2_type(tensor, indices=None):
"""Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it."""

def execute_parloops(self, tensor):
for parloop in self.parloops(tensor):
Expand All @@ -1003,29 +1006,27 @@ def execute_parloops(self, tensor):
def parloops(self, tensor):
if hasattr(self, "_parloops"):
for (lknl, _), parloop in zip(self.local_kernels, self._parloops):
data = _FormHandler.index_tensor(tensor, self._form, lknl.indices, self.diagonal)
data = self._as_pyop2_type(tensor, lknl.indices)
parloop.arguments[0].data = data

else:
# Make parloops for one concrete output tensor and cache them.
# TODO: Make parloops only with some symbolic information of the output tensor.
self._parloops = tuple(parloop_builder.build(tensor) for parloop_builder in self.parloop_builders)
return self._parloops

@cached_property
def parloop_builders(self):
out = []
for local_kernel, subdomain_id in self.local_kernels:
out.append(
ParloopBuilder(
parloops_ = []
for local_kernel, subdomain_id in self.local_kernels:
parloop_builder = ParloopBuilder(
self._form,
self._bcs,
local_kernel,
subdomain_id,
self.all_integer_subdomain_ids,
diagonal=self.diagonal,
)
)
return tuple(out)
pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices)
parloop = parloop_builder.build(pyop2_tensor)
parloops_.append(parloop)
self._parloops = tuple(parloops_)

return self._parloops

@cached_property
def local_kernels(self):
@@ -1120,10 +1121,11 @@ def _apply_bc(self, tensor, bc):
pass

def _check_tensor(self, tensor):
assert tensor is None
pass

@staticmethod
def _as_pyop2_type(tensor):
def _as_pyop2_type(tensor, indices=None):
assert not indices
return tensor

def result(self, tensor):
@@ -1198,15 +1200,16 @@ def _apply_dirichlet_bc(self, tensor, bc):
bc.zero(tensor)

def _check_tensor(self, tensor):
rank = len(self._form.arguments())
if rank == 1:
test, = self._form.arguments()
if tensor is not None and test.function_space() != tensor.function_space():
raise ValueError("Form's argument does not match provided result tensor")
if tensor.function_space() != self._form.arguments()[0].function_space():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
def _as_pyop2_type(tensor):
return tensor.dat
def _as_pyop2_type(tensor, indices=None):
if indices is not None and any(index is not None for index in indices):
i, = indices
return tensor.dat[i]
else:
return tensor.dat

def execute_parloops(self, tensor):
# We are repeatedly incrementing into the same Dat so intermediate halo exchanges
@@ -1454,12 +1457,26 @@ def _apply_bcs_mat_real_block(op2tensor, i, j, component, node_set):
dat.zero(subset=node_set)

def _check_tensor(self, tensor):
if tensor is not None and tensor.a.arguments() != self._form.arguments():
if tensor.a.arguments() != self._form.arguments():
raise ValueError("Form's arguments do not match provided result tensor")

@staticmethod
def _as_pyop2_type(tensor):
return tensor.M
def _as_pyop2_type(tensor, indices=None):
if indices is not None and any(index is not None for index in indices):
i, j = indices
mat = tensor.M[i, j]
else:
mat = tensor.M

if mat.handle.getType() == "python":
mat_context = mat.handle.getPythonContext()
if isinstance(mat_context, _GlobalMatPayload):
mat = mat_context.global_
else:
assert isinstance(mat_context, _DatMatPayload)
mat = mat_context.dat

return mat

def result(self, tensor):
tensor.M.assemble()
@@ -1471,7 +1488,7 @@ class MatrixFreeAssembler(FormAssembler):
Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
2-form.
Notes
Expand All @@ -1498,14 +1515,15 @@ def allocate(self):
appctx=self._appctx or {})

def assemble(self, tensor=None):
self._check_tensor(tensor)
if tensor is None:
tensor = self.allocate()
else:
self._check_tensor(tensor)
tensor.assemble()
return tensor

def _check_tensor(self, tensor):
if tensor is not None and tensor.a.arguments() != self._form.arguments():
if tensor.a.arguments() != self._form.arguments():
raise ValueError("Form's arguments do not match provided result tensor")


@@ -1820,12 +1838,12 @@ def __init__(self, form, bcs, local_knl, subdomain_id,
self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo)
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)

def build(self, tensor):
def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop:
"""Construct the parloop.
Parameters
----------
tensor : op2.Global or firedrake.cofunction.Cofunction or matrix.MatrixBase
tensor :
The output tensor.
"""
@@ -1909,17 +1927,28 @@ def collect_lgmaps(self):
:param local_knl: A :class:`tsfc_interface.SplitKernel`.
:param bcs: Iterable of boundary conditions.
"""

if len(self._form.arguments()) == 2 and not self._diagonal:
if not self._bcs:
return None
lgmaps = []
for i, j in self.get_indicess():

if any(i is not None for i in self._local_knl.indices):
i, j = self._local_knl.indices
row_bcs, col_bcs = self._filter_bcs(i, j)
rlgmap, clgmap = self._tensor.M[i, j].local_to_global_maps
# the tensor is already indexed
rlgmap, clgmap = self._tensor.local_to_global_maps
rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap)
clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap)
lgmaps.append((rlgmap, clgmap))
return tuple(lgmaps)
return ((rlgmap, clgmap),)
else:
lgmaps = []
for i, j in self.get_indicess():
row_bcs, col_bcs = self._filter_bcs(i, j)
rlgmap, clgmap = self._tensor[i, j].local_to_global_maps
rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap)
clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap)
lgmaps.append((rlgmap, clgmap))
return tuple(lgmaps)
else:
return None

Expand All @@ -1939,10 +1968,6 @@ def _integral_type(self):
def _indexed_function_spaces(self):
return _FormHandler.index_function_spaces(self._form, self._indices)

@property
def _indexed_tensor(self):
return _FormHandler.index_tensor(self._tensor, self._form, self._indices, self._diagonal)

@cached_property
def _mesh(self):
return tuple(self._form.ufl_domains())[self._kinfo.domain_number]
@@ -1990,28 +2015,27 @@ def _as_parloop_arg(tsfc_arg, self):
@_as_parloop_arg.register(kernel_args.OutputKernelArg)
def _as_parloop_arg_output(_, self):
rank = len(self._form.arguments())
tensor = self._indexed_tensor
Vs = self._indexed_function_spaces

if rank == 0:
return op2.GlobalParloopArg(tensor)
return op2.GlobalParloopArg(self._tensor)
elif rank == 1 or rank == 2 and self._diagonal:
V, = Vs
if V.ufl_element().family() == "Real":
return op2.GlobalParloopArg(tensor)
return op2.GlobalParloopArg(self._tensor)
else:
return op2.DatParloopArg(tensor, self._get_map(V))
return op2.DatParloopArg(self._tensor, self._get_map(V))
elif rank == 2:
rmap, cmap = [self._get_map(V) for V in Vs]

if all(V.ufl_element().family() == "Real" for V in Vs):
assert rmap is None and cmap is None
return op2.GlobalParloopArg(tensor.handle.getPythonContext().global_)
return op2.GlobalParloopArg(self._tensor)
elif any(V.ufl_element().family() == "Real" for V in Vs):
m = rmap or cmap
return op2.DatParloopArg(tensor.handle.getPythonContext().dat, m)
return op2.DatParloopArg(self._tensor, m)
else:
return op2.MatParloopArg(tensor, (rmap, cmap), lgmaps=self.collect_lgmaps())
return op2.MatParloopArg(self._tensor, (rmap, cmap), lgmaps=self.collect_lgmaps())
else:
raise AssertionError

@@ -2122,22 +2146,3 @@ def index_function_spaces(form, indices):
return tuple(a.ufl_function_space()[i] for i, a in zip(indices, form.arguments()))
else:
raise AssertionError

@staticmethod
def index_tensor(tensor, form, indices, diagonal):
"""Return the PyOP2 data structure tied to ``tensor``, indexed
if necessary.
"""
rank = len(form.arguments())
is_indexed = any(i is not None for i in indices)

if rank == 0:
return tensor
elif rank == 1 or rank == 2 and diagonal:
i, = indices
return tensor.dat[i] if is_indexed else tensor.dat
elif rank == 2:
i, j = indices
return tensor.M[i, j] if is_indexed else tensor.M
else:
raise AssertionError
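
The docstring near the top of this diff spells out the contract of `assemble`: a `float` for 0-forms, a `firedrake.cofunction.Cofunction` (or `Function`) for 1-forms, and a `matrix.MatrixBase` for 2-forms; the new `_as_pyop2_type` hook then hands each parloop the matching PyOP2 object, indexed per block when the form lives on a mixed space. A minimal usage sketch of that contract (the mesh, element and forms below are illustrative choices, not taken from this commit):

from firedrake import (UnitSquareMesh, FunctionSpace, Function,
                       TestFunction, TrialFunction, assemble, dx, inner)

mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
u, v = TrialFunction(V), TestFunction(V)
f = Function(V)
f.assign(1.0)

energy = assemble(f * dx)              # 0-form: a plain float
residual = assemble(inner(f, v) * dx)  # 1-form: a Cofunction
mass = assemble(inner(u, v) * dx)      # 2-form: a matrix.MatrixBase
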
6 changes: 2 additions & 4 deletions firedrake/functionspaceimpl.py
@@ -843,8 +843,7 @@ def local_to_global_map(self, bcs, lgmap=None):
return PETSc.LGMap().create(indices, bsize=bsize, comm=lgmap.comm)

def collapse(self):
from firedrake import FunctionSpace
return FunctionSpace(self.mesh(), self.ufl_element())
return type(self)(self.mesh(), self.ufl_element())


class RestrictedFunctionSpace(FunctionSpace):
@@ -1161,8 +1160,7 @@ def _ises(self):
return self.dof_dset.field_ises

def collapse(self):
from firedrake import MixedFunctionSpace
return MixedFunctionSpace([V_ for V_ in self])
return type(self)([V_ for V_ in self], self.mesh())


class ProxyFunctionSpace(FunctionSpace):
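
Both `collapse` hunks above replace a hard-coded constructor (`FunctionSpace`, `MixedFunctionSpace`) with `type(self)(...)`, so a subclass now collapses to an instance of its own class instead of the base class. A standalone sketch of that pattern in plain Python (toy classes, not Firedrake code):

class Space:
    """Toy stand-in for a function space implementation."""
    def __init__(self, mesh, element):
        self._mesh, self._element = mesh, element

    def mesh(self):
        return self._mesh

    def ufl_element(self):
        return self._element

    def collapse(self):
        # type(self), not Space: a RestrictedSpace collapses to a
        # RestrictedSpace, a plain Space collapses to a plain Space.
        return type(self)(self.mesh(), self.ufl_element())


class RestrictedSpace(Space):
    pass


assert type(RestrictedSpace("mesh", "element").collapse()) is RestrictedSpace
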
33 changes: 32 additions & 1 deletion tests/conftest.py
@@ -1,7 +1,7 @@
"""Global test configuration."""

import pytest
from firedrake.petsc import get_external_packages
from firedrake.petsc import PETSc, get_external_packages


def pytest_configure(config):
@@ -122,3 +122,34 @@ def fin():
assert len(tape.get_blocks()) == 0

request.addfinalizer(fin)


class _petsc_raises:
"""Context manager for catching PETSc-raised exceptions.
The usual `pytest.raises` exception handler is not suitable for errors
raised inside a callback to PETSc because the error is wrapped inside a
`PETSc.Error` object and so this context manager unpacks this to access
the actual internal error.
Parameters
----------
exc_type :
The exception type that is expected to be raised inside a PETSc callback.
"""
def __init__(self, exc_type):
self.exc_type = exc_type

def __enter__(self):
pass

def __exit__(self, exc_type, exc_val, traceback):
if exc_type is PETSc.Error and isinstance(exc_val.__cause__, self.exc_type):
return True


@pytest.fixture
def petsc_raises():
# This function is needed because pytest does not support classes as fixtures.
return _petsc_raises
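
As its docstring explains, an exception raised inside a PETSc callback reaches pytest wrapped in a `PETSc.Error`, and `_petsc_raises.__exit__` unwraps `__cause__` to match the original exception type. A hedged usage sketch of the fixture (the test body and the `solve_with_bad_setup` helper are hypothetical, not part of this commit):

def test_error_raised_inside_petsc_callback(petsc_raises):
    with petsc_raises(ValueError):
        # Anything here that raises ValueError from inside a PETSc callback
        # (and therefore surfaces as PETSc.Error) satisfies the check.
        solve_with_bad_setup()  # hypothetical helper that triggers the error
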
