From c1b18f2c8d5b72f3bc1e54b924adf4d45ddefaa8 Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Fri, 13 Sep 2024 02:59:46 +0100 Subject: [PATCH] rebase fix --- firedrake/assemble.py | 54 +++++++++++++++++++++++++++++--- firedrake/mg/kernels.py | 4 +-- firedrake/pointeval_utils.py | 2 +- firedrake/pointquery_utils.py | 2 +- firedrake/slate/slac/compiler.py | 5 ++- 5 files changed, 57 insertions(+), 10 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index a3afef9fd4..e652e2128a 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1582,6 +1582,8 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo) self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo) + self._active_exterior_facet_orientations = _FormHandler.iter_active_exterior_facet_orientations(form, local_knl.kinfo) + self._active_interior_facet_orientations = _FormHandler.iter_active_interior_facet_orientations(form, local_knl.kinfo) self._map_arg_cache = {} # Cache for holding :class:`op2.MapKernelArg` instances. @@ -1602,6 +1604,8 @@ def build(self): assert_empty(self._constants) assert_empty(self._active_exterior_facets) assert_empty(self._active_interior_facets) + assert_empty(self._active_exterior_facet_orientations) + assert_empty(self._active_interior_facet_orientations) iteration_regions = {"exterior_facet_top": op2.ON_TOP, "exterior_facet_bottom": op2.ON_BOTTOM, @@ -1799,12 +1803,24 @@ def _as_global_kernel_arg_interior_facet(_, self): @_as_global_kernel_arg.register(kernel_args.ExteriorFacetOrientationKernelArg) def _as_global_kernel_arg_exterior_facet_orientation(_, self): - return op2.DatKernelArg((1,)) + mesh = next(self._active_exterior_facet_orientations) + if mesh is self._mesh: + return op2.DatKernelArg((1,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatKernelArg((1,), m._global_kernel_arg) @_as_global_kernel_arg.register(kernel_args.InteriorFacetOrientationKernelArg) def _as_global_kernel_arg_interior_facet_orientation(_, self): - return op2.DatKernelArg((2,)) + mesh = next(self._active_interior_facet_orientations) + if mesh is self._mesh: + return op2.DatKernelArg((2,)) + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatKernelArg((2,), m._global_kernel_arg) @_as_global_kernel_arg.register(CellFacetKernelArg) @@ -1856,6 +1872,8 @@ def __init__(self, form, bcs, local_knl, subdomain_id, self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo) self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo) + self._active_exterior_facet_orientations = _FormHandler.iter_active_exterior_facet_orientations(form, local_knl.kinfo) + self._active_interior_facet_orientations = _FormHandler.iter_active_interior_facet_orientations(form, local_knl.kinfo) def build(self, tensor): """Construct the parloop. @@ -2115,12 +2133,24 @@ def _as_parloop_arg_interior_facet(_, self): @_as_parloop_arg.register(kernel_args.ExteriorFacetOrientationKernelArg) def _as_parloop_arg_exterior_facet_orientation(_, self): - return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_orientation_dat) + mesh = next(self._active_exterior_facet_orientations) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "exterior_facet" + return op2.DatParloopArg(mesh.exterior_facets.local_facet_orientation_dat, m) @_as_parloop_arg.register(kernel_args.InteriorFacetOrientationKernelArg) def _as_parloop_arg_interior_facet_orientation(_, self): - return op2.DatParloopArg(self._mesh.interior_facets.local_facet_orientation_dat) + mesh = next(self._active_interior_facet_orientations) + if mesh is self._mesh: + m = None + else: + m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids) + assert integral_type == "interior_facet" + return op2.DatParloopArg(mesh.interior_facets.local_facet_orientation_dat, m) @_as_parloop_arg.register(CellFacetKernelArg) @@ -2198,6 +2228,22 @@ def iter_active_interior_facets(form, kinfo): mesh = all_meshes[i] yield mesh + @staticmethod + def iter_active_exterior_facet_orientations(form, kinfo): + """Yield the form exterior facet orientations referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.exterior_facet_orientations: + mesh = all_meshes[i] + yield mesh + + @staticmethod + def iter_active_interior_facet_orientations(form, kinfo): + """Yield the form interior facet orientations referenced in ``kinfo``.""" + all_meshes = extract_domains(form) + for i in kinfo.active_domain_numbers.interior_facet_orientations: + mesh = all_meshes[i] + yield mesh + @staticmethod def index_function_spaces(form, indices): """Return the function spaces of the form's arguments, indexed diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index a4ddd28f7e..443012a99a 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -144,7 +144,7 @@ def compile_element(expression, dual_space=None, parameters=None, config = dict(interface=builder, ufl_cell=cell, - integral_type="cell", + domain_integral_type_map={domain: "cell"}, point_indices=(), point_expr=point, argument_multiindices=argument_multiindices, @@ -540,7 +540,6 @@ def dg_injection_kernel(Vf, Vc, ncell): integration_dim, entity_ids = lower_integral_type(Vfe.cell, "cell") macro_cfg = dict(interface=macro_builder, ufl_cell=Vf.ufl_cell(), - integral_type="cell", integration_dim=integration_dim, entity_ids=entity_ids, index_cache=index_cache, @@ -580,7 +579,6 @@ def dg_injection_kernel(Vf, Vc, ncell): coarse_cfg = dict(interface=coarse_builder, ufl_cell=Vc.ufl_cell(), - integral_type="cell", integration_dim=integration_dim, entity_ids=entity_ids, index_cache=index_cache, diff --git a/firedrake/pointeval_utils.py b/firedrake/pointeval_utils.py index 34a64e2964..ac920640e4 100644 --- a/firedrake/pointeval_utils.py +++ b/firedrake/pointeval_utils.py @@ -72,7 +72,7 @@ def compile_element(expression, coordinates, parameters=None): config = dict(interface=builder, ufl_cell=extract_unique_domain(coordinates).ufl_cell(), - integral_type="cell", + domain_integral_type_map={domain: "cell"}, point_indices=(), point_expr=point, scalar_type=utils.ScalarType) diff --git a/firedrake/pointquery_utils.py b/firedrake/pointquery_utils.py index c0dfb83482..a1776caa25 100644 --- a/firedrake/pointquery_utils.py +++ b/firedrake/pointquery_utils.py @@ -160,7 +160,7 @@ def to_reference_coords_newton_step(ufl_coordinate_element, parameters, x0_dtype context = tsfc.fem.GemPointContext( interface=builder, ufl_cell=cell, - integral_type="cell", + domain_integral_type_map={domain: "cell"}, point_indices=(), point_expr=point, scalar_type=parameters["scalar_type"] diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 8bcbed0523..5a05eb3acf 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -196,7 +196,10 @@ def generate_loopy_kernel(slate_expr, compiler_parameters=None): cell_orientations=(0, ) if builder.bag.needs_cell_orientations else (), cell_sizes=(0, ) if builder.bag.needs_cell_sizes else (), exterior_facets=(), - interior_facets=()), + interior_facets=(), + exterior_facet_orientations=(), + interior_facet_orientations=(), + ), coefficient_numbers=coefficient_numbers, constant_numbers=constant_numbers, needs_cell_facets=builder.bag.needs_cell_facets,