diff --git a/requirements-git.txt b/requirements-git.txt index b6e78108..67a08a00 100644 --- a/requirements-git.txt +++ b/requirements-git.txt @@ -1,4 +1,4 @@ -git+https://github.com/firedrakeproject/ufl.git#egg=fenics-ufl +git+https://github.com/firedrakeproject/ufl.git@ksagiyam/introduce_mixed_map#egg=fenics-ufl git+https://github.com/firedrakeproject/fiat.git#egg=fenics-fiat git+https://github.com/FInAT/FInAT.git#egg=finat git+https://github.com/firedrakeproject/loopy.git#egg=loopy diff --git a/tests/test_mixed_function_space_with_mixed_mesh.py b/tests/test_mixed_function_space_with_mixed_mesh.py new file mode 100644 index 00000000..acd8c75e --- /dev/null +++ b/tests/test_mixed_function_space_with_mixed_mesh.py @@ -0,0 +1,201 @@ +from tsfc import compile_form +from ufl import (triangle, Mesh, MixedMesh, FunctionSpace, TestFunction, TrialFunction, Coefficient, + Measure, SpatialCoordinate, inner, grad, curl, div, split, as_vector, ) +from finat.ufl import FiniteElement, MixedElement, VectorElement +from tsfc.ufl_utils import compute_form_data +from tsfc import kernel_args + + +def test_mixed_function_space_with_mixed_mesh_restrictions_base(): + cell = triangle + elem0 = FiniteElement("Discontinuous Lagrange", cell, 2) + elem1 = FiniteElement("Discontinuous Lagrange", cell, 3) + elem = MixedElement([elem0, elem1]) + mesh0 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=100) + mesh1 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=101) + domain = MixedMesh([mesh0, mesh1]) + V = FunctionSpace(domain, elem) + V0 = FunctionSpace(mesh0, elem0) + V1 = FunctionSpace(mesh1, elem1) + f = Coefficient(V, count=1000) + f0, f1 = split(f) + u1 = TrialFunction(V1) + v0 = TestFunction(V0) + dx1 = Measure("dx", mesh1) + ds1 = Measure("ds", mesh1) + dS0 = Measure("dS", mesh0) + f0_split = Coefficient(V0) + f1_split = Coefficient(V1) + # a + form = inner(grad(f1('|')), as_vector([1, 0])) * ds1(777) + form_data = compute_form_data(form, do_split_coefficients={f: [f0_split, f1_split]}) + integral_data, = form_data.integral_data + assert len(integral_data.domain_integral_type_map) == 1 + assert integral_data.domain_integral_type_map[mesh1] == "exterior_facet" + # b + form = inner(grad(f1('|')), grad(f1('|'))) * dS0(777) + form_data = compute_form_data(form, do_split_coefficients={f: [f0_split, f1_split]}) + integral_data, = form_data.integral_data + assert len(integral_data.domain_integral_type_map) == 2 + assert integral_data.domain_integral_type_map[mesh0] == "interior_facet" + assert integral_data.domain_integral_type_map[mesh1] == "exterior_facet" + # c + form = div(f) * inner(grad(f1), grad(f1)) * inner(grad(u1), grad(v0)) * dx1 + form_data = compute_form_data(form, do_split_coefficients={f: [f0_split, f1_split]}) + integral_data, = form_data.integral_data + assert len(integral_data.domain_integral_type_map) == 2 + assert integral_data.domain_integral_type_map[mesh0] == "cell" + assert integral_data.domain_integral_type_map[mesh1] == "cell" + + +def test_mixed_function_space_with_mixed_mesh_3_cg3_bdm3_dg2_dx1(): + cell = triangle + gdim = 2 + elem0 = FiniteElement("Lagrange", cell, 3) + elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 3) + elem2 = FiniteElement("Discontinuous Lagrange", cell, 2) + elem = MixedElement([elem0, elem1, elem2]) + mesh0 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=100) + mesh1 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=101) + mesh2 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=102) + domain = MixedMesh([mesh0, mesh1, mesh2]) + V = FunctionSpace(domain, elem) + V0 = FunctionSpace(mesh0, elem0) + V1 = FunctionSpace(mesh1, elem1) + V2 = FunctionSpace(mesh2, elem2) + f = Coefficient(V, count=1000) + u0 = TrialFunction(V0) + v1 = TestFunction(V1) + f0, f1, f2 = split(f) + f0_split = Coefficient(V0) + f1_split = Coefficient(V1) + f2_split = Coefficient(V2) + x2 = SpatialCoordinate(mesh2) + dx1 = Measure("dx", mesh1) + form = inner(x2, x2) * f2 * inner(grad(u0), v1) * dx1(999) + form_data = compute_form_data(form, do_split_coefficients={f: [f0_split, f1_split, f2_split]}) + integral_data, = form_data.integral_data + assert len(integral_data.domain_integral_type_map) == 3 + assert integral_data.domain_integral_type_map[mesh0] == "cell" + assert integral_data.domain_integral_type_map[mesh1] == "cell" + assert integral_data.domain_integral_type_map[mesh2] == "cell" + kernel, = compile_form(form) + assert kernel.domain_number == 0 + assert kernel.integral_type == "cell" + assert kernel.subdomain_id == (999, ) + assert kernel.active_domain_numbers.coordinates == (0, 1, 2) + assert kernel.active_domain_numbers.cell_orientations == () + assert kernel.active_domain_numbers.cell_sizes == () + assert kernel.active_domain_numbers.exterior_facets == () + assert kernel.active_domain_numbers.interior_facets == () + assert kernel.coefficient_numbers == ((0, (2, )), ) + assert isinstance(kernel.arguments[0], kernel_args.OutputKernelArg) + assert isinstance(kernel.arguments[1], kernel_args.CoordinatesKernelArg) + assert isinstance(kernel.arguments[2], kernel_args.CoordinatesKernelArg) + assert isinstance(kernel.arguments[3], kernel_args.CoordinatesKernelArg) + assert isinstance(kernel.arguments[4], kernel_args.CoefficientKernelArg) + assert kernel.arguments[0].loopy_arg.shape == (20, 10) + assert kernel.arguments[1].loopy_arg.shape == (3 * gdim, ) + assert kernel.arguments[2].loopy_arg.shape == (3 * gdim, ) + assert kernel.arguments[3].loopy_arg.shape == (3 * gdim, ) + assert kernel.arguments[4].loopy_arg.shape == (6, ) + + +def test_mixed_function_space_with_mixed_mesh_restrictions_bdm3_dg2_dS0(): + cell = triangle + gdim = 2 + elem0 = FiniteElement("Brezzi-Douglas-Marini", cell, 3) + elem1 = FiniteElement("Discontinuous Lagrange", cell, 2) + elem = MixedElement([elem0, elem1]) + mesh0 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=100) + mesh1 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=101) + domain = MixedMesh([mesh0, mesh1]) + V = FunctionSpace(domain, elem) + V0 = FunctionSpace(mesh0, elem0) + V1 = FunctionSpace(mesh1, elem1) + f = Coefficient(V, count=1000) + f0, f1 = split(f) + f0_split = Coefficient(V0) + f1_split = Coefficient(V1) + u1 = TrialFunction(V1) + v0 = TestFunction(V0) + dS0 = Measure("dS", mesh0) + form = inner(curl(f1('|')), curl(f1('|'))) * inner(grad(u1('|')), v0('+')) * dS0(777) + form_data = compute_form_data(form, do_split_coefficients={f: [f0_split, f1_split]}) + integral_data, = form_data.integral_data + assert len(integral_data.domain_integral_type_map) == 2 + assert integral_data.domain_integral_type_map[mesh0] == "interior_facet" + assert integral_data.domain_integral_type_map[mesh1] == "exterior_facet" + kernel, = compile_form(form) + assert kernel.domain_number == 0 + assert kernel.integral_type == "interior_facet" + assert kernel.subdomain_id == (777, ) + assert kernel.active_domain_numbers.coordinates == (0, 1) + assert kernel.active_domain_numbers.cell_orientations == () + assert kernel.active_domain_numbers.cell_sizes == () + assert kernel.active_domain_numbers.exterior_facets == (1, ) + assert kernel.active_domain_numbers.interior_facets == (0, ) + assert kernel.coefficient_numbers == ((0, (1, )), ) + assert isinstance(kernel.arguments[0], kernel_args.OutputKernelArg) + assert isinstance(kernel.arguments[1], kernel_args.CoordinatesKernelArg) + assert isinstance(kernel.arguments[2], kernel_args.CoordinatesKernelArg) + assert isinstance(kernel.arguments[3], kernel_args.CoefficientKernelArg) + assert isinstance(kernel.arguments[4], kernel_args.ExteriorFacetKernelArg) + assert isinstance(kernel.arguments[5], kernel_args.InteriorFacetKernelArg) + assert kernel.arguments[0].loopy_arg.shape == (2 * 20, 6) + assert kernel.arguments[1].loopy_arg.shape == (2 * (3 * gdim), ) + assert kernel.arguments[2].loopy_arg.shape == (3 * gdim, ) + assert kernel.arguments[3].loopy_arg.shape == (6, ) + assert kernel.arguments[4].loopy_arg.shape == (1, ) + assert kernel.arguments[5].loopy_arg.shape == (2, ) + + +def test_mixed_function_space_with_mixed_mesh_restrictions_dg2_dg3_ds1(): + cell = triangle + gdim = 2 + elem0 = FiniteElement("Discontinuous Lagrange", cell, 2) + elem1 = FiniteElement("Discontinuous Lagrange", cell, 3) + elem = MixedElement([elem0, elem1]) + mesh0 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=100) + mesh1 = Mesh(VectorElement("Lagrange", cell, 1), ufl_id=101) + domain = MixedMesh([mesh0, mesh1]) + V = FunctionSpace(domain, elem) + V0 = FunctionSpace(mesh0, elem0) + V1 = FunctionSpace(mesh1, elem1) + f = Coefficient(V, count=1000) + f0_split = Coefficient(V0) + f1_split = Coefficient(V1) + f0, f1 = split(f) + u0 = TrialFunction(V0) + v1 = TestFunction(V1) + ds1 = Measure("ds", mesh1) + form = inner(grad(f1('|')), grad(f0('-'))) * inner(grad(u0('-')), grad(v1('|'))) * ds1(777) + form_data = compute_form_data(form, do_split_coefficients={f: [f0_split, f1_split]}) + integral_data, = form_data.integral_data + assert len(integral_data.domain_integral_type_map) == 2 + assert integral_data.domain_integral_type_map[mesh0] == "interior_facet" + assert integral_data.domain_integral_type_map[mesh1] == "exterior_facet" + kernel, = compile_form(form) + assert kernel.domain_number == 0 + assert kernel.integral_type == "exterior_facet" + assert kernel.subdomain_id == (777, ) + assert kernel.active_domain_numbers.coordinates == (0, 1) + assert kernel.active_domain_numbers.cell_orientations == () + assert kernel.active_domain_numbers.cell_sizes == () + assert kernel.active_domain_numbers.exterior_facets == (0, ) + assert kernel.active_domain_numbers.interior_facets == (1, ) + assert kernel.coefficient_numbers == ((0, (0, 1)), ) + assert isinstance(kernel.arguments[0], kernel_args.OutputKernelArg) + assert isinstance(kernel.arguments[1], kernel_args.CoordinatesKernelArg) + assert isinstance(kernel.arguments[2], kernel_args.CoordinatesKernelArg) + assert isinstance(kernel.arguments[3], kernel_args.CoefficientKernelArg) + assert isinstance(kernel.arguments[4], kernel_args.CoefficientKernelArg) + assert isinstance(kernel.arguments[5], kernel_args.ExteriorFacetKernelArg) + assert isinstance(kernel.arguments[6], kernel_args.InteriorFacetKernelArg) + assert kernel.arguments[0].loopy_arg.shape == (10, 2 * 6) + assert kernel.arguments[1].loopy_arg.shape == (1 * (3 * gdim), ) + assert kernel.arguments[2].loopy_arg.shape == (2 * (3 * gdim), ) + assert kernel.arguments[3].loopy_arg.shape == (2 * 6, ) + assert kernel.arguments[4].loopy_arg.shape == (10, ) + assert kernel.arguments[5].loopy_arg.shape == (1, ) + assert kernel.arguments[6].loopy_arg.shape == (2, ) diff --git a/tests/test_tsfc_182.py b/tests/test_tsfc_182.py index 556a6baf..30e81804 100644 --- a/tests/test_tsfc_182.py +++ b/tests/test_tsfc_182.py @@ -1,6 +1,6 @@ import pytest -from ufl import Coefficient, TestFunction, dx, inner, tetrahedron, Mesh, FunctionSpace +from ufl import Coefficient, TestFunction, dx, inner, tetrahedron, Mesh, MixedMesh, FunctionSpace from finat.ufl import FiniteElement, MixedElement, VectorElement from tsfc import compile_form @@ -20,7 +20,8 @@ def test_delta_elimination(mode): element_chi_lambda = MixedElement(element_eps_p, element_lambda) domain = Mesh(VectorElement("Lagrange", tetrahedron, 1)) - space = FunctionSpace(domain, element_chi_lambda) + domains = MixedMesh([domain, domain]) + space = FunctionSpace(domains, element_chi_lambda) chi_lambda = Coefficient(space) delta_chi_lambda = TestFunction(space) diff --git a/tests/test_tsfc_204.py b/tests/test_tsfc_204.py index 35e2caa0..70d465a8 100644 --- a/tests/test_tsfc_204.py +++ b/tests/test_tsfc_204.py @@ -1,12 +1,13 @@ from tsfc import compile_form from ufl import (Coefficient, FacetNormal, - FunctionSpace, Mesh, as_matrix, + FunctionSpace, Mesh, MixedMesh, as_matrix, dot, dS, ds, dx, facet, grad, inner, outer, split, triangle) from finat.ufl import BrokenElement, FiniteElement, MixedElement, VectorElement def test_physically_mapped_facet(): mesh = Mesh(VectorElement("P", triangle, 1)) + meshes = MixedMesh([mesh, mesh, mesh, mesh, mesh]) # set up variational problem U = FiniteElement("Morley", mesh.ufl_cell(), 2) @@ -15,7 +16,7 @@ def test_physically_mapped_facet(): Vv = VectorElement(BrokenElement(V)) Qhat = VectorElement(BrokenElement(V[facet]), dim=2) Vhat = VectorElement(V[facet], dim=2) - Z = FunctionSpace(mesh, MixedElement(U, Vv, Qhat, Vhat, R)) + Z = FunctionSpace(meshes, MixedElement(U, Vv, Qhat, Vhat, R)) z = Coefficient(Z) u, d, qhat, dhat, lam = split(z) diff --git a/tsfc/driver.py b/tsfc/driver.py index 6e3c3baa..9588406c 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -8,7 +8,7 @@ from ufl.algorithms import extract_arguments, extract_coefficients from ufl.algorithms.analysis import has_type from ufl.classes import Form, GeometricQuantity -from ufl.domain import extract_unique_domain +from ufl.domain import extract_unique_domain, extract_domains import gem import gem.impero_utils as impero_utils @@ -26,9 +26,9 @@ TSFCIntegralDataInfo = collections.namedtuple("TSFCIntegralDataInfo", - ["domain", "integral_type", "subdomain_id", "domain_number", + ["domain", "integral_type", "subdomain_id", "domain_number", "domain_integral_type_map", "arguments", - "coefficients", "coefficient_numbers"]) + "coefficients", "coefficient_split", "coefficient_numbers"]) TSFCIntegralDataInfo.__doc__ = """ Minimal set of objects for kernel builders. @@ -47,7 +47,7 @@ """ -def compile_form(form, prefix="form", parameters=None, interface=None, diagonal=False, log=False): +def compile_form(form, prefix="form", parameters=None, dont_split_numbers=(), diagonal=False, log=False): """Compiles a UFL form into a set of assembly kernels. :arg form: UFL form @@ -65,67 +65,77 @@ def compile_form(form, prefix="form", parameters=None, interface=None, diagonal= # Determine whether in complex mode: complex_mode = parameters and is_complex(parameters.get("scalar_type")) - fd = ufl_utils.compute_form_data(form, complex_mode=complex_mode) + form_data = ufl_utils.compute_form_data(form, + do_split_coefficients=tuple(c for i, c in enumerate(form.coefficients()) if i not in dont_split_numbers), + complex_mode=complex_mode) logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time) kernels = [] - for integral_data in fd.integral_data: + for integral_data in form_data.integral_data: start = time.time() - kernel = compile_integral(integral_data, fd, prefix, parameters, interface=interface, diagonal=diagonal, log=log) - if kernel is not None: - kernels.append(kernel) + if integral_data.integrals: + kernel = compile_integral(integral_data, form_data, prefix, parameters, diagonal=diagonal, log=log) + if kernel is not None: + kernels.append(kernel) logger.info(GREEN % "compile_integral finished in %g seconds.", time.time() - start) logger.info(GREEN % "TSFC finished in %g seconds.", time.time() - cpu_time) return kernels -def compile_integral(integral_data, form_data, prefix, parameters, interface, *, diagonal=False, log=False): +def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=False, log=False): """Compiles a UFL integral into an assembly kernel. :arg integral_data: UFL integral data :arg form_data: UFL form data :arg prefix: kernel name will start with this string :arg parameters: parameters object - :arg interface: backend module for the kernel interface :arg diagonal: Are we building a kernel for the diagonal of a rank-2 element tensor? :arg log: bool if the Kernel should be profiled with Log events :returns: a kernel constructed by the kernel interface """ parameters = preprocess_parameters(parameters) - if interface is None: - interface = firedrake_interface_loopy.KernelBuilder scalar_type = parameters["scalar_type"] integral_type = integral_data.integral_type if integral_type.startswith("interior_facet") and diagonal: raise NotImplementedError("Sorry, we can't assemble the diagonal of a form for interior facet integrals") - mesh = integral_data.domain arguments = form_data.preprocessed_form.arguments() kernel_name = f"{prefix}_{integral_type}_integral" - # Dict mapping domains to index in original_form.ufl_domains() - domain_numbering = form_data.original_form.domain_numbering() - domain_number = domain_numbering[integral_data.domain] - coefficients = [form_data.function_replace_map[c] for c in integral_data.integral_coefficients] # This is which coefficient in the original form the # current coefficient is. # Consider f*v*dx + g*v*ds, the full form contains two # coefficients, but each integral only requires one. - coefficient_numbers = tuple(form_data.original_coefficient_positions[i] - for i, (_, enabled) in enumerate(zip(form_data.reduced_coefficients, integral_data.enabled_coefficients)) - if enabled) + coefficients = [] + coefficient_split = {} + coefficient_numbers = [] + for i, (coeff_orig, enabled) in enumerate(zip(form_data.reduced_coefficients, integral_data.enabled_coefficients)): + if enabled: + coeff = form_data.function_replace_map[coeff_orig] + coefficients.append(coeff) + if coeff in form_data.coefficient_split: + coefficient_split[coeff] = form_data.coefficient_split[coeff] + coefficient_numbers.append(form_data.original_coefficient_positions[i]) + mesh = integral_data.domain + all_meshes = extract_domains(form_data.original_form) + domain_number = all_meshes.index(mesh) integral_data_info = TSFCIntegralDataInfo(domain=integral_data.domain, integral_type=integral_data.integral_type, subdomain_id=integral_data.subdomain_id, domain_number=domain_number, + domain_integral_type_map={mesh: integral_data.domain_integral_type_map[mesh] if mesh in integral_data.domain_integral_type_map else None for mesh in all_meshes}, arguments=arguments, coefficients=coefficients, + coefficient_split=coefficient_split, coefficient_numbers=coefficient_numbers) - builder = interface(integral_data_info, - scalar_type, - diagonal=diagonal) - builder.set_coordinates(mesh) - builder.set_cell_sizes(mesh) - builder.set_coefficients(integral_data, form_data) + builder = firedrake_interface_loopy.KernelBuilder(integral_data_info, + scalar_type, + diagonal=diagonal) + builder.set_entity_numbers(all_meshes) + builder.set_entity_orientations(all_meshes) + builder.set_coordinates(all_meshes) + builder.set_cell_orientations(all_meshes) + builder.set_cell_sizes(all_meshes) + builder.set_coefficients() # TODO: We do not want pass constants to kernels that do not need them # so we should attach the constants to integral data instead builder.set_constants(form_data.constants) @@ -133,8 +143,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, *, for integral in integral_data.integrals: params = parameters.copy() params.update(integral.metadata()) # integral metadata overrides - integrand = ufl.replace(integral.integrand(), form_data.function_replace_map) - integrand_exprs = builder.compile_integrand(integrand, params, ctx) + integrand_exprs = builder.compile_integrand(integral.integrand(), params, ctx) integral_exprs = builder.construct_integrals(integrand_exprs, params) builder.stash_integrals(integral_exprs, params, ctx) return builder.construct_kernel(kernel_name, ctx, log) @@ -207,6 +216,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if domain is None: domain = extract_unique_domain(expression) assert domain is not None + builder._domain_integral_type_map = {domain: "cell"} # Collect required coefficients and determine numbering coefficients = extract_coefficients(expression) @@ -219,7 +229,8 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Create a fake coordinate coefficient for a domain. coords_coefficient = ufl.Coefficient(ufl.FunctionSpace(domain, domain.ufl_coordinate_element())) builder.domain_coordinate[domain] = coords_coefficient - builder.set_cell_sizes(domain) + builder.set_cell_orientations((domain, )) + builder.set_cell_sizes((domain, )) coefficients = [coords_coefficient] + coefficients needs_external_coords = True builder.set_coefficients(coefficients) @@ -235,7 +246,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, ufl_cell=domain.ufl_cell(), # FIXME: change if we ever implement # interpolation on facets. - integral_type="cell", + domain_integral_type_map={domain: "cell"}, argument_multiindices=argument_multiindices, index_cache={}, scalar_type=parameters["scalar_type"]) diff --git a/tsfc/fem.py b/tsfc/fem.py index abc8bc7c..8f15d2d5 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -27,10 +27,11 @@ PositiveRestricted, QuadratureWeight, ReferenceCellEdgeVectors, ReferenceCellVolume, ReferenceFacetVolume, ReferenceNormal, - SpatialCoordinate) + SingleValueRestricted, SpatialCoordinate) from ufl.corealg.map_dag import map_expr_dag, map_expr_dags from ufl.corealg.multifunction import MultiFunction from ufl.domain import extract_unique_domain +from ufl.algorithms import extract_arguments from tsfc import ufl2gem from tsfc.finatinterface import as_fiat_cell, create_element @@ -48,7 +49,7 @@ class ContextBase(ProxyKernelInterface): keywords = ('ufl_cell', 'fiat_cell', - 'integral_type', + 'domain_integral_type_map', 'integration_dim', 'entity_ids', 'argument_multiindices', @@ -82,7 +83,7 @@ def epsilon(self): def complex_mode(self): return is_complex(self.scalar_type) - def entity_selector(self, callback, restriction): + def entity_selector(self, callback, domain, restriction): """Selects code for the correct entity at run-time. Callback generates code for a specified entity. @@ -96,7 +97,7 @@ def entity_selector(self, callback, restriction): if len(self.entity_ids) == 1: return callback(self.entity_ids[0]) else: - f = self.entity_number(restriction) + f = self.entity_number(domain, restriction) return gem.select_expression(list(map(callback, self.entity_ids)), f) argument_multiindices = () @@ -112,7 +113,17 @@ def translator(self): @cached_property def use_canonical_quadrature_point_ordering(self): - return isinstance(self.fiat_cell, UFCHexahedron) and self.integral_type in ['exterior_facet', 'interior_facet'] + cell_integral_type_map = {as_fiat_cell(domain.ufl_cell()): integral_type + for domain, integral_type in self.domain_integral_type_map.items() + if integral_type is not None} + if all(integral_type == 'cell' for integral_type in cell_integral_type_map.values()): + return False + elif all(integral_type in ['exterior_facet', 'interior_facet'] for integral_type in cell_integral_type_map.values()): + if all(isinstance(cell, UFCHexahedron) for cell in cell_integral_type_map): + return True + elif len(set(cell_integral_type_map)) > 1: # mixed cell types + return True + return False class CoordinateMapping(PhysicalGeometry): @@ -137,19 +148,20 @@ def preprocess(self, expr, context): :arg context: The translation context. :returns: A new UFL expression """ - ifacet = self.interface.integral_type.startswith("interior_facet") + domain = extract_unique_domain(self.mt.terminal) + ifacet = self.interface.domain_integral_type_map[domain].startswith("interior_facet") return preprocess_expression(expr, complex_mode=context.complex_mode, do_apply_restrictions=ifacet) @property def config(self): config = {name: getattr(self.interface, name) - for name in ["ufl_cell", "index_cache", "scalar_type"]} + for name in ["ufl_cell", "index_cache", "scalar_type", "domain_integral_type_map"]} config["interface"] = self.interface return config def cell_size(self): - return self.interface.cell_size(self.mt.restriction) + return self.interface.cell_size(extract_unique_domain(self.mt.terminal), self.mt.restriction) def jacobian_at(self, point): ps = PointSingleton(point) @@ -159,6 +171,10 @@ def jacobian_at(self, point): expr = PositiveRestricted(expr) elif self.mt.restriction == '-': expr = NegativeRestricted(expr) + elif self.mt.restriction == '|': + expr = SingleValueRestricted(expr) + elif self.mt.restriction == '?': + raise RuntimeError("Not expecting '?' restriction at this stage") config = {"point_set": PointSingleton(point)} config.update(self.config) context = PointSetContext(**config) @@ -171,6 +187,10 @@ def detJ_at(self, point): expr = PositiveRestricted(expr) elif self.mt.restriction == '-': expr = NegativeRestricted(expr) + elif self.mt.restriction == '|': + expr = SingleValueRestricted(expr) + elif self.mt.restriction == '?': + raise RuntimeError("Not expecting '?' restriction at this stage") config = {"point_set": PointSingleton(point)} config.update(self.config) context = PointSetContext(**config) @@ -214,6 +234,10 @@ def physical_edge_lengths(self): expr = PositiveRestricted(expr) elif self.mt.restriction == '-': expr = NegativeRestricted(expr) + elif self.mt.restriction == '|': + expr = SingleValueRestricted(expr) + elif self.mt.restriction == '?': + raise RuntimeError("Not expecting '?' restriction at this stage") cell = self.interface.fiat_cell sd = cell.get_spatial_dimension() @@ -238,6 +262,10 @@ def physical_points(self, point_set, entity=None): expr = PositiveRestricted(expr) elif self.mt.restriction == '-': expr = NegativeRestricted(expr) + elif self.mt.restriction == '|': + expr = SingleValueRestricted(expr) + elif self.mt.restriction == '?': + raise RuntimeError("Not expecting '?' restriction at this stage") config = {"point_set": point_set} config.update(self.config) if entity is not None: @@ -337,35 +365,37 @@ def __init__(self, context): # Can't put these in the ufl2gem mixin, since they (unlike # everything else) want access to the translation context. def cell_avg(self, o): - if self.context.integral_type != "cell": + domain = extract_unique_domain(o) + integral_type = self.context.domain_integral_type_map[domain] + if integral_type != "cell": # Need to create a cell-based quadrature rule and # translate the expression using that (c.f. CellVolume # below). raise NotImplementedError("CellAvg on non-cell integrals not yet implemented") integrand, = o.ufl_operands - domain = extract_unique_domain(o) - measure = ufl.Measure(self.context.integral_type, domain=domain) + measure = ufl.Measure(integral_type, domain=domain) integrand, degree, argument_multiindices = entity_avg(integrand / CellVolume(domain), measure, self.context.argument_multiindices) config = {name: getattr(self.context, name) - for name in ["ufl_cell", "index_cache", "scalar_type"]} + for name in ["ufl_cell", "index_cache", "scalar_type", "domain_integral_type_map"]} config.update(quadrature_degree=degree, interface=self.context, argument_multiindices=argument_multiindices) expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) return expr def facet_avg(self, o): - if self.context.integral_type == "cell": + domain = extract_unique_domain(o) + integral_type = self.context.domain_integral_type_map[domain] + if integral_type == "cell": raise ValueError("Can't take FacetAvg in cell integral") integrand, = o.ufl_operands - domain = extract_unique_domain(o) - measure = ufl.Measure(self.context.integral_type, domain=domain) + measure = ufl.Measure(integral_type, domain=domain) integrand, degree, argument_multiindices = entity_avg(integrand / FacetArea(domain), measure, self.context.argument_multiindices) config = {name: getattr(self.context, name) for name in ["ufl_cell", "index_cache", "scalar_type", "integration_dim", "entity_ids", - "integral_type"]} + "domain_integral_type_map"]} config.update(quadrature_degree=degree, interface=self.context, argument_multiindices=argument_multiindices) expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) @@ -402,7 +432,7 @@ def translate_geometricquantity(terminal, mt, ctx): @translate.register(CellOrientation) def translate_cell_orientation(terminal, mt, ctx): - return ctx.cell_orientation(mt.restriction) + return ctx.cell_orientation(extract_unique_domain(terminal), mt.restriction) @translate.register(ReferenceCellVolume) @@ -412,7 +442,7 @@ def translate_reference_cell_volume(terminal, mt, ctx): @translate.register(ReferenceFacetVolume) def translate_reference_facet_volume(terminal, mt, ctx): - assert ctx.integral_type != "cell" + assert ctx.domain_integral_type_map[extract_unique_domain(terminal)] != "cell" # Sum of quadrature weights is entity volume return gem.optimise.aggressive_unroll(gem.index_sum(ctx.weight_expr, ctx.point_indices)) @@ -426,7 +456,7 @@ def translate_cell_facet_jacobian(terminal, mt, ctx): def callback(entity_id): return gem.Literal(make_cell_facet_jacobian(cell, facet_dim, entity_id)) - return ctx.entity_selector(callback, mt.restriction) + return ctx.entity_selector(callback, extract_unique_domain(terminal), mt.restriction) def make_cell_facet_jacobian(cell, facet_dim, facet_i): @@ -451,7 +481,7 @@ def translate_reference_normal(terminal, mt, ctx): def callback(facet_i): n = ctx.fiat_cell.compute_reference_normal(ctx.integration_dim, facet_i) return gem.Literal(n) - return ctx.entity_selector(callback, mt.restriction) + return ctx.entity_selector(callback, extract_unique_domain(terminal), mt.restriction) @translate.register(ReferenceCellEdgeVectors) @@ -484,7 +514,7 @@ def callback(entity_id): data = numpy.asarray(list(map(t, ps.points))) return gem.Literal(data.reshape(point_shape + data.shape[1:])) - return gem.partial_indexed(ctx.entity_selector(callback, mt.restriction), + return gem.partial_indexed(ctx.entity_selector(callback, extract_unique_domain(terminal), mt.restriction), ps.indices) @@ -527,7 +557,7 @@ def translate_cellvolume(terminal, mt, ctx): interface = CellVolumeKernelInterface(ctx, mt.restriction) config = {name: getattr(ctx, name) - for name in ["ufl_cell", "index_cache", "scalar_type"]} + for name in ["ufl_cell", "index_cache", "scalar_type", "domain_integral_type_map"]} config.update(interface=interface, quadrature_degree=degree) expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) return expr @@ -535,13 +565,14 @@ def translate_cellvolume(terminal, mt, ctx): @translate.register(FacetArea) def translate_facetarea(terminal, mt, ctx): - assert ctx.integral_type != 'cell' domain = extract_unique_domain(terminal) - integrand, degree = one_times(ufl.Measure(ctx.integral_type, domain=domain)) + integral_type = ctx.domain_integral_type_map[domain] + assert integral_type != 'cell' + integrand, degree = one_times(ufl.Measure(integral_type, domain=domain)) config = {name: getattr(ctx, name) for name in ["ufl_cell", "integration_dim", "scalar_type", - "entity_ids", "index_cache"]} + "entity_ids", "index_cache", "domain_integral_type_map"]} config.update(interface=ctx, quadrature_degree=degree) expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) return expr @@ -555,7 +586,7 @@ def translate_cellorigin(terminal, mt, ctx): point_set = PointSingleton((0.0,) * domain.topological_dimension()) config = {name: getattr(ctx, name) - for name in ["ufl_cell", "index_cache", "scalar_type"]} + for name in ["ufl_cell", "index_cache", "scalar_type", "domain_integral_type_map"]} config.update(interface=ctx, point_set=point_set) context = PointSetContext(**config) return context.translator(expression) @@ -568,7 +599,7 @@ def translate_cell_vertices(terminal, mt, ctx): ps = PointSet(numpy.array(ctx.fiat_cell.get_vertices())) config = {name: getattr(ctx, name) - for name in ["ufl_cell", "index_cache", "scalar_type"]} + for name in ["ufl_cell", "index_cache", "scalar_type", "domain_integral_type_map"]} config.update(interface=ctx, point_set=ps) context = PointSetContext(**config) expr = context.translator(ufl_expr) @@ -637,10 +668,10 @@ def callback(entity_id): # A numerical hack that FFC used to apply on FIAT tables still # lives on after ditching FFC and switching to FInAT. return ffc_rounding(square, ctx.epsilon) - table = ctx.entity_selector(callback, mt.restriction) + table = ctx.entity_selector(callback, extract_unique_domain(terminal), mt.restriction) if ctx.use_canonical_quadrature_point_ordering: quad_multiindex = ctx.quadrature_rule.point_set.indices - quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx) + quad_multiindex_permuted = _make_quad_multiindex_permuted(terminal, mt, ctx) mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices) table = mapper(table, tuple(zip(quad_multiindex, quad_multiindex_permuted))) return gem.ComponentTensor(gem.Indexed(table, argument_multiindex + sigma), sigma) @@ -682,7 +713,7 @@ def take_singleton(xs): per_derivative = {alpha: take_singleton(tables) for alpha, tables in per_derivative.items()} else: - f = ctx.entity_number(mt.restriction) + f = ctx.entity_number(extract_unique_domain(terminal), mt.restriction) per_derivative = {alpha: gem.select_expression(tables, f) for alpha, tables in per_derivative.items()} @@ -715,13 +746,13 @@ def take_singleton(xs): if ctx.use_canonical_quadrature_point_ordering: quad_multiindex = ctx.quadrature_rule.point_set.indices - quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx) + quad_multiindex_permuted = _make_quad_multiindex_permuted(terminal, mt, ctx) mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices) result = mapper(result, tuple(zip(quad_multiindex, quad_multiindex_permuted))) return result -def _make_quad_multiindex_permuted(mt, ctx): +def _make_quad_multiindex_permuted(terminal, mt, ctx): quad_rule = ctx.quadrature_rule # Note that each quad index here represents quad points on a physical # cell axis, but the table is indexed by indices representing the points @@ -734,7 +765,8 @@ def _make_quad_multiindex_permuted(mt, ctx): if len(extents) != 1: raise ValueError("Must have the same number of quadrature points in each symmetric axis") quad_multiindex_permuted = [] - o = ctx.entity_orientation(mt.restriction) + domain = extract_unique_domain(terminal) + o = ctx.entity_orientation(domain, mt.restriction) if not isinstance(o, FIATOrientation): raise ValueError(f"Expecting an instance of FIATOrientation : got {o}") eo = cell.extract_extrinsic_orientation(o) @@ -749,27 +781,23 @@ def _make_quad_multiindex_permuted(mt, ctx): return tuple(quad_multiindex_permuted) -def compile_ufl(expression, context, interior_facet=False, point_sum=False): +def compile_ufl(expression, context, point_sum=False): """Translate a UFL expression to GEM. :arg expression: The UFL expression to compile. :arg context: translation context - either a :class:`GemPointContext` or :class:`PointSetContext` - :arg interior_facet: If ``true``, treat expression as an interior - facet integral (default ``False``) :arg point_sum: If ``true``, return a `gem.IndexSum` of the final gem expression along the ``context.point_indices`` (if present). """ # Abs-simplification expression = simplify_abs(expression, context.complex_mode) - if interior_facet: - expressions = [] - for rs in itertools.product(("+", "-"), repeat=len(context.argument_multiindices)): - expressions.append(map_expr_dag(PickRestriction(*rs), expression)) - else: - expressions = [expression] - + arguments = extract_arguments(expression) + domains = [extract_unique_domain(argument) for argument in arguments] + integral_types = [context.domain_integral_type_map[domain] for domain in domains] + rs_tuples = [("+", "-") if integral_type.startswith("interior_facet") else (None, ) for integral_type in integral_types] + expressions = [map_expr_dag(PickRestriction(*rs), expression) for rs in itertools.product(*rs_tuples)] # Translate UFL to GEM, lowering finite element specific nodes result = map_expr_dags(context.translator, expressions) if point_sum: diff --git a/tsfc/kernel_interface/__init__.py b/tsfc/kernel_interface/__init__.py index 51142638..75754799 100644 --- a/tsfc/kernel_interface/__init__.py +++ b/tsfc/kernel_interface/__init__.py @@ -22,19 +22,19 @@ def constant(self, const): """Return the GEM expression corresponding to the constant.""" @abstractmethod - def cell_orientation(self, restriction): + def cell_orientation(self, domain, restriction): """Cell orientation as a GEM expression.""" @abstractmethod - def cell_size(self, restriction): + def cell_size(self, domain, restriction): """Mesh cell size as a GEM expression. Shape (nvertex, ) in FIAT vertex ordering.""" @abstractmethod - def entity_number(self, restriction): + def entity_number(self, domain, restriction): """Facet or vertex number as a GEM index.""" @abstractmethod - def entity_orientation(self, restriction): + def entity_orientation(self, domain, restriction): """Entity orientation as a GEM index.""" @abstractmethod diff --git a/tsfc/kernel_interface/common.py b/tsfc/kernel_interface/common.py index df7e879f..385b2f09 100644 --- a/tsfc/kernel_interface/common.py +++ b/tsfc/kernel_interface/common.py @@ -3,6 +3,10 @@ import string from functools import reduce from itertools import chain, product +import copy + +from ufl.utils.sequences import max_degree +from ufl.domain import extract_unique_domain import gem import gem.impero_utils as impero_utils @@ -15,24 +19,18 @@ from gem.optimise import remove_componenttensors as prune from gem.utils import cached_property from numpy import asarray -from tsfc import fem, ufl_utils +from tsfc import fem from tsfc.finatinterface import as_fiat_cell, create_element from tsfc.kernel_interface import KernelInterface from tsfc.logging import logger -from ufl.utils.sequences import max_degree class KernelBuilderBase(KernelInterface): """Helper class for building local assembly kernels.""" - def __init__(self, scalar_type, interior_facet=False): - """Initialise a kernel builder. - - :arg interior_facet: kernel accesses two cells - """ - assert isinstance(interior_facet, bool) + def __init__(self, scalar_type): + """Initialise a kernel builder.""" self.scalar_type = scalar_type - self.interior_facet = interior_facet self.prepare = [] self.finalise = [] @@ -56,10 +54,13 @@ def coordinate(self, domain): def coefficient(self, ufl_coefficient, restriction): """A function that maps :class:`ufl.Coefficient`s to GEM expressions.""" + if restriction == '?': + raise RuntimeError("Not expecting '?' restriction at this stage") kernel_arg = self.coefficient_map[ufl_coefficient] + domain = extract_unique_domain(ufl_coefficient) if ufl_coefficient.ufl_element().family() == 'Real': return kernel_arg - elif not self.interior_facet: + elif not self._domain_integral_type_map[domain].startswith("interior_facet"): # '|' is for exterior_facet return kernel_arg else: return kernel_arg[{'+': 0, '-': 1}[restriction]] @@ -67,34 +68,41 @@ def coefficient(self, ufl_coefficient, restriction): def constant(self, const): return self.constant_map[const] - def cell_orientation(self, restriction): + def cell_orientation(self, domain, restriction): """Cell orientation as a GEM expression.""" - f = {None: 0, '+': 0, '-': 1}[restriction] - # Assume self._cell_orientations tuple is set up at this point. - co_int = self._cell_orientations[f] + if restriction == '?': + raise RuntimeError("Not expecting '?' restriction at this stage") + if not hasattr(self, "_cell_orientations"): + raise RuntimeError("Haven't called set_cell_orientations") + f = {None: 0, '|': 0, '+': 0, '-': 1}[restriction] + co_int = self._cell_orientations[domain][f] return gem.Conditional(gem.Comparison("==", co_int, gem.Literal(1)), gem.Literal(-1), gem.Conditional(gem.Comparison("==", co_int, gem.Zero()), gem.Literal(1), gem.Literal(numpy.nan))) - def cell_size(self, restriction): + def cell_size(self, domain, restriction): + if restriction == '?': + raise RuntimeError("Not expecting '?' restriction at this stage") if not hasattr(self, "_cell_sizes"): raise RuntimeError("Haven't called set_cell_sizes") - if self.interior_facet: - return self._cell_sizes[{'+': 0, '-': 1}[restriction]] + if self._domain_integral_type_map[domain].startswith("interior_facet"): + return self._cell_sizes[domain][{'+': 0, '-': 1}[restriction]] else: - return self._cell_sizes + return self._cell_sizes[domain] - def entity_number(self, restriction): + def entity_number(self, domain, restriction): """Facet or vertex number as a GEM index.""" - # Assume self._entity_number dict is set up at this point. - return self._entity_number[restriction] + if not hasattr(self, "_entity_numbers"): + raise RuntimeError("Haven't called set_entity_numbers") + return self._entity_numbers[domain][restriction] - def entity_orientation(self, restriction): + def entity_orientation(self, domain, restriction): """Facet orientation as a GEM index.""" - # Assume self._entity_orientation dict is set up at this point. - return self._entity_orientation[restriction] + if not hasattr(self, "_entity_orientations"): + raise RuntimeError("Haven't called set_entity_orientations") + return self._entity_orientations[domain][restriction] def apply_glue(self, prepare=None, finalise=None): """Append glue code for operations that are not handled in the @@ -132,21 +140,18 @@ def compile_integrand(self, integrand, params, ctx): See :meth:`create_context` for typical calling sequence. """ - # Split Coefficients - if self.coefficient_split: - integrand = ufl_utils.split_coefficients(integrand, self.coefficient_split) # Compile: ufl -> gem info = self.integral_data_info functions = list(info.arguments) + [self.coordinate(info.domain)] + list(info.coefficients) set_quad_rule(params, info.domain.ufl_cell(), info.integral_type, functions) quad_rule = params["quadrature_rule"] config = self.fem_config() + config['domain_integral_type_map'] = self._domain_integral_type_map config['argument_multiindices'] = self.argument_multiindices config['quadrature_rule'] = quad_rule config['index_cache'] = ctx['index_cache'] expressions = fem.compile_ufl(integrand, - fem.PointSetContext(**config), - interior_facet=self.interior_facet) + fem.PointSetContext(**config)) ctx['quadrature_indices'].extend(quad_rule.point_set.indices) return expressions @@ -243,7 +248,6 @@ def fem_config(self): integration_dim, entity_ids = lower_integral_type(fiat_cell, integral_type) return dict(interface=self, ufl_cell=cell, - integral_type=integral_type, integration_dim=integration_dim, entity_ids=entity_ids, scalar_type=self.fem_scalar_type) @@ -434,9 +438,9 @@ def check_requirements(ir): rt_tabs = {} for node in traversal(ir): if isinstance(node, gem.Variable): - if node.name == "cell_orientations": + if node.name == "cell_orientations_0": cell_orientations = True - elif node.name == "cell_sizes": + elif node.name == "cell_sizes_0": cell_sizes = True elif node.name.startswith("rt_"): rt_tabs[node.name] = node.shape @@ -457,55 +461,74 @@ def prepare_constant(constant, number): constant.ufl_shape) -def prepare_coefficient(coefficient, name, interior_facet=False): +def prepare_coefficient(coefficient, name, domain_integral_type_map): """Bridges the kernel interface and the GEM abstraction for Coefficients. - :arg coefficient: UFL Coefficient - :arg name: unique name to refer to the Coefficient in the kernel - :arg interior_facet: interior facet integral? - :returns: (funarg, expression) - expression - GEM expression referring to the Coefficient - values - """ - assert isinstance(interior_facet, bool) + Parameters + ---------- + coefficient : ufl.Coefficient + UFL Coefficient. + name : str + Unique name to refer to the Coefficient in the kernel. + domain_integral_type_map : dict + Map from domain to integral_type. + + Returns + ------- + gem.Node + GEM expression referring to the Coefficient values. + """ if coefficient.ufl_element().family() == 'Real': # Constant value_size = coefficient.ufl_function_space().value_size expression = gem.reshape(gem.Variable(name, (value_size,)), coefficient.ufl_shape) return expression - finat_element = create_element(coefficient.ufl_element()) shape = finat_element.index_shape size = numpy.prod(shape, dtype=int) - - if not interior_facet: - expression = gem.reshape(gem.Variable(name, (size,)), shape) - else: + domain = extract_unique_domain(coefficient) + integral_type = domain_integral_type_map[domain] + if integral_type is None: + # This means that this coefficient does not exist in the DAG, + # so corresponding gem expression will never be needed. + expression = None + elif integral_type.startswith("interior_facet"): varexp = gem.Variable(name, (2 * size,)) plus = gem.view(varexp, slice(size)) minus = gem.view(varexp, slice(size, 2 * size)) expression = (gem.reshape(plus, shape), gem.reshape(minus, shape)) + else: + expression = gem.reshape(gem.Variable(name, (size,)), shape) return expression -def prepare_arguments(arguments, multiindices, interior_facet=False, diagonal=False): +def prepare_arguments(arguments, multiindices, domain_integral_type_map, diagonal=False): """Bridges the kernel interface and the GEM abstraction for Arguments. Vector Arguments are rearranged here for interior facet integrals. - :arg arguments: UFL Arguments - :arg multiindices: Argument multiindices - :arg interior_facet: interior facet integral? - :arg diagonal: Are we assembling the diagonal of a rank-2 element tensor? - :returns: (funarg, expression) - expressions - GEM expressions referring to the argument - tensor - """ - assert isinstance(interior_facet, bool) + Parameters + ---------- + arguments : tuple + UFL Arguments. + multiindices : tuple + Argument multiindices. + domain_integral_type_map : dict + Map from domain to integral_type. + diagonal : bool + Are we assembling the diagonal of a rank-2 element tensor? + + Returns + ------- + tuple + Tuple of function arg and GEM expressions referring to the argument tensor. + """ + if len(multiindices) != len(arguments): + raise ValueError(f"Got inconsistent lengths of arguments ({len(arguments)}) and multiindices ({len(multiindices)})") if len(arguments) == 0: # No arguments expression = gem.Indexed(gem.Variable("A", (1,)), (0,)) @@ -531,15 +554,20 @@ def expression(restricted): tuple(chain(*multiindices))) u_shape = numpy.array([numpy.prod(shape, dtype=int) for shape in shapes]) - if interior_facet: - c_shape = tuple(2 * u_shape) - slicez = [[slice(r * s, (r + 1) * s) - for r, s in zip(restrictions, u_shape)] - for restrictions in product((0, 1), repeat=len(arguments))] - else: - c_shape = tuple(u_shape) - slicez = [[slice(s) for s in u_shape]] - - varexp = gem.Variable("A", c_shape) + c_shape = copy.deepcopy(u_shape) + rs_tuples = [] + for arg_num, arg in enumerate(arguments): + integral_type = domain_integral_type_map[extract_unique_domain(arg)] + if integral_type is None: + raise RuntimeError(f"Can not determine integral_type on {arg}") + if integral_type.startswith("interior_facet"): + rs_tuples.append((0, 1)) + c_shape[arg_num] *= 2 + else: + rs_tuples.append((0, )) + slicez = [[slice(r * s, (r + 1) * s) + for r, s in zip(restrictions, u_shape)] + for restrictions in product(*rs_tuples)] + varexp = gem.Variable("A", tuple(c_shape)) expressions = [expression(gem.view(varexp, *slices)) for slices in slicez] return tuple(prune(expressions)) diff --git a/tsfc/kernel_interface/firedrake_loopy.py b/tsfc/kernel_interface/firedrake_loopy.py index cc35a7c5..9639ab85 100644 --- a/tsfc/kernel_interface/firedrake_loopy.py +++ b/tsfc/kernel_interface/firedrake_loopy.py @@ -1,9 +1,8 @@ import numpy from collections import namedtuple, OrderedDict -from functools import partial from ufl import Coefficient, FunctionSpace -from ufl.domain import extract_unique_domain +from ufl.domain import MixedMesh from finat.ufl import MixedElement as ufl_MixedElement, FiniteElement import gem @@ -11,7 +10,7 @@ import loopy as lp -from tsfc import kernel_args, fem +from tsfc import kernel_args from tsfc.finatinterface import create_element from tsfc.kernel_interface.common import KernelBuilderBase as _KernelBuilderBase, KernelBuilderMixin, get_index_names, check_requirements, prepare_coefficient, prepare_arguments, prepare_constant from tsfc.loopy import generate as generate_loopy @@ -25,20 +24,28 @@ 'flop_count', 'event']) -def make_builder(*args, **kwargs): - return partial(KernelBuilder, *args, **kwargs) +ActiveDomainNumbers = namedtuple('ActiveDomainNumbers', ['coordinates', + 'cell_orientations', + 'cell_sizes', + 'exterior_facets', + 'interior_facets', + 'exterior_facet_orientations', + 'interior_facet_orientations']) +ActiveDomainNumbers.__doc__ = """ + Active domain numbers collected for each key. + + """ class Kernel: - __slots__ = ("ast", "arguments", "integral_type", "oriented", "subdomain_id", - "domain_number", "needs_cell_sizes", "tabulations", + __slots__ = ("ast", "arguments", "integral_type", "subdomain_id", + "domain_number", "active_domain_numbers", "tabulations", "coefficient_numbers", "name", "flop_count", "event", "__weakref__") """A compiled Kernel object. :kwarg ast: The loopy kernel object. :kwarg integral_type: The type of integral. - :kwarg oriented: Does the kernel require cell_orientations. :kwarg subdomain_id: What is the subdomain id for this kernel. :kwarg domain_number: Which domain number in the original form does this kernel correspond to (can be used to index into @@ -46,15 +53,13 @@ class Kernel: :kwarg coefficient_numbers: A list of which coefficients from the form the kernel needs. :kwarg tabulations: The runtime tabulations this kernel requires - :kwarg needs_cell_sizes: Does the kernel require cell sizes. :kwarg name: The name of this kernel. :kwarg flop_count: Estimated total flops for this kernel. :kwarg event: name for logging event """ - def __init__(self, ast=None, arguments=None, integral_type=None, oriented=False, - subdomain_id=None, domain_number=None, + def __init__(self, ast=None, arguments=None, integral_type=None, + subdomain_id=None, domain_number=None, active_domain_numbers=None, coefficient_numbers=(), - needs_cell_sizes=False, tabulations=None, flop_count=0, name=None, @@ -63,11 +68,10 @@ def __init__(self, ast=None, arguments=None, integral_type=None, oriented=False, self.ast = ast self.arguments = arguments self.integral_type = integral_type - self.oriented = oriented self.domain_number = domain_number + self.active_domain_numbers = active_domain_numbers self.subdomain_id = subdomain_id self.coefficient_numbers = coefficient_numbers - self.needs_cell_sizes = needs_cell_sizes self.tabulations = tabulations self.flop_count = flop_count self.name = name @@ -76,21 +80,9 @@ def __init__(self, ast=None, arguments=None, integral_type=None, oriented=False, class KernelBuilderBase(_KernelBuilderBase): - def __init__(self, scalar_type, interior_facet=False): - """Initialise a kernel builder. - - :arg interior_facet: kernel accesses two cells - """ - super().__init__(scalar_type=scalar_type, interior_facet=interior_facet) - - # Cell orientation - if self.interior_facet: - cell_orientations = gem.Variable("cell_orientations", (2,)) - self._cell_orientations = (gem.Indexed(cell_orientations, (0,)), - gem.Indexed(cell_orientations, (1,))) - else: - cell_orientations = gem.Variable("cell_orientations", (1,)) - self._cell_orientations = (gem.Indexed(cell_orientations, (0,)),) + def __init__(self, scalar_type): + """Initialise a kernel builder.""" + super().__init__(scalar_type=scalar_type) def _coefficient(self, coefficient, name): """Prepare a coefficient. Adds glue code for the coefficient @@ -100,14 +92,41 @@ def _coefficient(self, coefficient, name): :arg name: coefficient name :returns: GEM expression representing the coefficient """ - expr = prepare_coefficient(coefficient, name, interior_facet=self.interior_facet) + expr = prepare_coefficient(coefficient, name, self._domain_integral_type_map) self.coefficient_map[coefficient] = expr return expr - def set_cell_sizes(self, domain): - """Setup a fake coefficient for "cell sizes". + def set_cell_orientations(self, domains): + """Set cell orientations for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. - :arg domain: The domain of the integral. + """ + # Cell orientation + self._cell_orientations = {} + for i, domain in enumerate(domains): + integral_type = self._domain_integral_type_map[domain] + if integral_type is None: + # See comment in prepare_coefficient. + self._cell_orientations[domain] = None + elif integral_type.startswith("interior_facet"): + cell_orientations = gem.Variable(f"cell_orientations_{i}", (2,)) + self._cell_orientations[domain] = (gem.Indexed(cell_orientations, (0,)), + gem.Indexed(cell_orientations, (1,))) + else: + cell_orientations = gem.Variable(f"cell_orientations_{i}", (1,)) + self._cell_orientations[domain] = (gem.Indexed(cell_orientations, (0,)),) + + def set_cell_sizes(self, domains): + """Setup a fake coefficient for "cell sizes" for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. This is required for scaling of derivative basis functions on physically mapped elements (Argyris, Bell, etc...). We need a @@ -117,13 +136,15 @@ def set_cell_sizes(self, domain): Should the domain have topological dimension 0 this does nothing. """ - if domain.ufl_cell().topological_dimension() > 0: - # Can't create P1 since only P0 is a valid finite element if - # topological_dimension is 0 and the concept of "cell size" - # is not useful for a vertex. - f = Coefficient(FunctionSpace(domain, FiniteElement("P", domain.ufl_cell(), 1))) - expr = prepare_coefficient(f, "cell_sizes", interior_facet=self.interior_facet) - self._cell_sizes = expr + self._cell_sizes = {} + for i, domain in enumerate(domains): + if domain.ufl_cell().topological_dimension() > 0: + # Can't create P1 since only P0 is a valid finite element if + # topological_dimension is 0 and the concept of "cell size" + # is not useful for a vertex. + f = Coefficient(FunctionSpace(domain, FiniteElement("P", domain.ufl_cell(), 1))) + expr = prepare_coefficient(f, f"cell_sizes_{i}", self._domain_integral_type_map) + self._cell_sizes[domain] = expr def create_element(self, element, **kwargs): """Create a FInAT element (suitable for tabulating with) given @@ -210,10 +231,12 @@ def construct_kernel(self, impero_c, index_names, needs_external_coords, log=Fal """ args = [self.output_arg] if self.oriented: - funarg = self.generate_arg_from_expression(self._cell_orientations, dtype=numpy.int32) + cell_orientations, = tuple(self._cell_orientations.values()) + funarg = self.generate_arg_from_expression(cell_orientations, dtype=numpy.int32) args.append(kernel_args.CellOrientationsKernelArg(funarg)) if self.cell_sizes: - funarg = self.generate_arg_from_expression(self._cell_sizes) + cell_sizes, = tuple(self._cell_sizes.values()) + funarg = self.generate_arg_from_expression(cell_sizes) args.append(kernel_args.CellSizesKernelArg(funarg)) for _, expr in self.coefficient_map.items(): # coefficient_map is OrderedDict. @@ -245,50 +268,19 @@ class KernelBuilder(KernelBuilderBase, KernelBuilderMixin): def __init__(self, integral_data_info, scalar_type, dont_split=(), diagonal=False): """Initialise a kernel builder.""" - integral_type = integral_data_info.integral_type - super(KernelBuilder, self).__init__(scalar_type, integral_type.startswith("interior_facet")) + super(KernelBuilder, self).__init__(scalar_type) self.fem_scalar_type = scalar_type - self.diagonal = diagonal self.local_tensor = None - self.coefficient_split = {} self.coefficient_number_index_map = OrderedDict() self.dont_split = frozenset(dont_split) - - # Facet number - if integral_type in ['exterior_facet', 'exterior_facet_vert']: - facet = gem.Variable('facet', (1,)) - self._entity_number = {None: gem.VariableIndex(gem.Indexed(facet, (0,)))} - facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) - self._entity_orientation = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))} - elif integral_type in ['interior_facet', 'interior_facet_vert']: - facet = gem.Variable('facet', (2,)) - self._entity_number = { - '+': gem.VariableIndex(gem.Indexed(facet, (0,))), - '-': gem.VariableIndex(gem.Indexed(facet, (1,))) - } - facet_orientation = gem.Variable('facet_orientation', (2,), dtype=gem.uint_type) - self._entity_orientation = { - '+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), - '-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (1,))) - } - elif integral_type == 'interior_facet_horiz': - self._entity_number = {'+': 1, '-': 0} - facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) # base mesh entity orientation - self._entity_orientation = { - '+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), - '-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))) - } - - self.set_arguments(integral_data_info.arguments) self.integral_data_info = integral_data_info + self._domain_integral_type_map = integral_data_info.domain_integral_type_map # For consistency with ExpressionKernelBuilder. + self.set_arguments() - def set_arguments(self, arguments): - """Process arguments. - - :arg arguments: :class:`ufl.Argument`s - :returns: GEM expression representing the return variable - """ + def set_arguments(self): + """Process arguments.""" + arguments = self.integral_data_info.arguments argument_multiindices = tuple(create_element(arg.ufl_element()).get_indices() for arg in arguments) if self.diagonal: @@ -299,53 +291,97 @@ def set_arguments(self, arguments): argument_multiindices = (a, a) return_variables = prepare_arguments(arguments, argument_multiindices, - interior_facet=self.interior_facet, + self.integral_data_info.domain_integral_type_map, diagonal=self.diagonal) self.return_variables = return_variables self.argument_multiindices = argument_multiindices - def set_coordinates(self, domain): - """Prepare the coordinate field. + def set_entity_numbers(self, domains): + """Set entity numbers for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. - :arg domain: :class:`ufl.Domain` """ - # Create a fake coordinate coefficient for a domain. - f = Coefficient(FunctionSpace(domain, domain.ufl_coordinate_element())) - self.domain_coordinate[domain] = f - self._coefficient(f, "coords") + self._entity_numbers = {} + for i, domain in enumerate(domains): + # Facet number + integral_type = self.integral_data_info.domain_integral_type_map[domain] + if integral_type in ['exterior_facet', 'exterior_facet_vert']: + facet = gem.Variable(f'facet_{i}', (1,)) + self._entity_numbers[domain] = {None: gem.VariableIndex(gem.Indexed(facet, (0,))), + '|': gem.VariableIndex(gem.Indexed(facet, (0,)))} + elif integral_type in ['interior_facet', 'interior_facet_vert']: + facet = gem.Variable(f'facet_{i}', (2,)) + self._entity_numbers[domain] = { + '+': gem.VariableIndex(gem.Indexed(facet, (0,))), + '-': gem.VariableIndex(gem.Indexed(facet, (1,))) + } + elif integral_type == 'interior_facet_horiz': + self._entity_numbers[domain] = {'+': 1, '-': 0} + + def set_entity_orientations(self, domains): + """Set entity orientations for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. - def set_coefficients(self, integral_data, form_data): - """Prepare the coefficients of the form. + """ + self._entity_orientations = {} + for i, domain in enumerate(domains): + integral_type = self.integral_data_info.domain_integral_type_map[domain] + if integral_type in ['exterior_facet', 'exterior_facet_vert']: + facet_orientation = gem.Variable(f'facet_orientation_{i}', (1,), dtype=gem.uint_type) + self._entity_orientations[domain] = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), + '|': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))} + elif integral_type in ['interior_facet', 'interior_facet_vert']: + facet_orientation = gem.Variable(f'facet_orientation_{i}', (2,), dtype=gem.uint_type) + self._entity_orientations[domain] = { + '+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), + '-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (1,))) + } + elif integral_type == 'interior_facet_horiz': + facet_orientation = gem.Variable(f'facet_orientation_{i}', (1,), dtype=gem.uint_type) # base mesh entity orientation + self._entity_orientations[domain] = { + '+': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))), + '-': gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,))) + } + + def set_coordinates(self, domains): + """Set coordinates for each domain. + + Parameters + ---------- + domains : list or tuple + All domains in the form. - :arg integral_data: UFL integral data - :arg form_data: UFL form data """ - # enabled_coefficients is a boolean array that indicates which - # of reduced_coefficients the integral requires. - n, k = 0, 0 - for i in range(len(integral_data.enabled_coefficients)): - if integral_data.enabled_coefficients[i]: - original = form_data.reduced_coefficients[i] - coefficient = form_data.function_replace_map[original] - if type(coefficient.ufl_element()) == ufl_MixedElement: - if original in self.dont_split: - self.coefficient_split[coefficient] = [coefficient] - self._coefficient(coefficient, f"w_{k}") - self.coefficient_number_index_map[coefficient] = (n, 0) - k += 1 - else: - self.coefficient_split[coefficient] = [] - for j, element in enumerate(coefficient.ufl_element().sub_elements): - c = Coefficient(FunctionSpace(extract_unique_domain(coefficient), element)) - self.coefficient_split[coefficient].append(c) - self._coefficient(c, f"w_{k}") - self.coefficient_number_index_map[c] = (n, j) - k += 1 - else: - self._coefficient(coefficient, f"w_{k}") - self.coefficient_number_index_map[coefficient] = (n, 0) + # Create a fake coordinate coefficient for a domain. + for i, domain in enumerate(domains): + if isinstance(domain, MixedMesh): + raise RuntimeError("Found a MixedMesh") + f = Coefficient(FunctionSpace(domain, domain.ufl_coordinate_element())) + self.domain_coordinate[domain] = f + self._coefficient(f, f"coords_{i}") + + def set_coefficients(self): + """Prepare the coefficients of the form.""" + info = self.integral_data_info + k = 0 + for n, coeff in enumerate(info.coefficients): + if coeff in info.coefficient_split: + for i, c in enumerate(info.coefficient_split[coeff]): + self.coefficient_number_index_map[c] = (n, i) + self._coefficient(c, f"w_{k}") k += 1 - n += 1 + else: + self.coefficient_number_index_map[coeff] = (n, 0) + self._coefficient(coeff, f"w_{k}") + k += 1 def set_constants(self, constants): for i, const in enumerate(constants): @@ -368,7 +404,7 @@ def construct_kernel(self, name, ctx, log=False): :arg log: bool if the Kernel should be profiled with Log events :returns: :class:`Kernel` object """ - impero_c, oriented, needs_cell_sizes, tabulations, active_variables = self.compile_gem(ctx) + impero_c, _, _, tabulations, active_variables = self.compile_gem(ctx) if impero_c is None: return self.construct_empty_kernel(name) info = self.integral_data_info @@ -384,48 +420,72 @@ def construct_kernel(self, name, ctx, log=False): # Add return arg funarg = self.generate_arg_from_expression(self.return_variables) args = [kernel_args.OutputKernelArg(funarg)] - # Add coordinates arg - coord = self.domain_coordinate[info.domain] - expr = self.coefficient_map[coord] - funarg = self.generate_arg_from_expression(expr) - args.append(kernel_args.CoordinatesKernelArg(funarg)) - if oriented: - funarg = self.generate_arg_from_expression(self._cell_orientations, dtype=numpy.int32) - args.append(kernel_args.CellOrientationsKernelArg(funarg)) - if needs_cell_sizes: - funarg = self.generate_arg_from_expression(self._cell_sizes) - args.append(kernel_args.CellSizesKernelArg(funarg)) + active_domain_numbers_coordinates, args_ = self.make_active_domain_numbers({d: self.coefficient_map[c] for d, c in self.domain_coordinate.items()}, + active_variables, + kernel_args.CoordinatesKernelArg) + args.extend(args_) + active_domain_numbers_cell_orientations, args_ = self.make_active_domain_numbers(self._cell_orientations, + active_variables, + kernel_args.CellOrientationsKernelArg, + dtype=numpy.int32) + args.extend(args_) + active_domain_numbers_cell_sizes, args_ = self.make_active_domain_numbers(self._cell_sizes, + active_variables, + kernel_args.CellSizesKernelArg) + args.extend(args_) coefficient_indices = OrderedDict() for coeff, (number, index) in self.coefficient_number_index_map.items(): a = coefficient_indices.setdefault(number, []) expr = self.coefficient_map[coeff] + if expr is None: + # See comment in prepare_coefficient. + continue var, = gem.extract_type(expr if isinstance(expr, tuple) else (expr, ), gem.Variable) if var in active_variables: funarg = self.generate_arg_from_expression(expr) args.append(kernel_args.CoefficientKernelArg(funarg)) a.append(index) - - # now constants for gemexpr in self.constant_map.values(): funarg = self.generate_arg_from_expression(gemexpr) args.append(kernel_args.ConstantKernelArg(funarg)) - coefficient_indices = tuple(tuple(v) for v in coefficient_indices.values()) assert len(coefficient_indices) == len(info.coefficient_numbers) - if info.integral_type in ["exterior_facet", "exterior_facet_vert"]: - ext_loopy_arg = lp.GlobalArg("facet", numpy.uint32, shape=(1,)) - args.append(kernel_args.ExteriorFacetKernelArg(ext_loopy_arg)) - elif info.integral_type in ["interior_facet", "interior_facet_vert"]: - int_loopy_arg = lp.GlobalArg("facet", numpy.uint32, shape=(2,)) - args.append(kernel_args.InteriorFacetKernelArg(int_loopy_arg)) - # Will generalise this in the submesh PR. - if fem.PointSetContext(**self.fem_config()).use_canonical_quadrature_point_ordering: - if info.integral_type == "exterior_facet": - ext_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(1,)) - args.append(kernel_args.ExteriorFacetOrientationKernelArg(ext_ornt_loopy_arg)) - elif info.integral_type == "interior_facet": - int_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(2,)) - args.append(kernel_args.InteriorFacetOrientationKernelArg(int_ornt_loopy_arg)) + ext_dict = {} + for domain, expr in self._entity_numbers.items(): + integral_type = info.domain_integral_type_map[domain] + ext_dict[domain] = expr[None].expression if integral_type in ["exterior_facet", "exterior_facet_vert"] else None + active_domain_numbers_exterior_facets, args_ = self.make_active_domain_numbers(ext_dict, + active_variables, + kernel_args.ExteriorFacetKernelArg, + dtype=numpy.uint32) + args.extend(args_) + int_dict = {} + for domain, expr in self._entity_numbers.items(): + integral_type = info.domain_integral_type_map[domain] + int_dict[domain] = expr['+'].expression if integral_type in ["interior_facet", "interior_facet_vert"] else None + active_domain_numbers_interior_facets, args_ = self.make_active_domain_numbers(int_dict, + active_variables, + kernel_args.InteriorFacetKernelArg, + dtype=numpy.uint32) + args.extend(args_) + ext_dict = {} + for domain, expr in self._entity_orientations.items(): + integral_type = info.domain_integral_type_map[domain] + ext_dict[domain] = expr[None].expression if integral_type in ["exterior_facet", "exterior_facet_vert"] else None + active_domain_numbers_exterior_facet_orientations, args_ = self.make_active_domain_numbers(ext_dict, + active_variables, + kernel_args.ExteriorFacetOrientationKernelArg, + dtype=gem.uint_type) + args.extend(args_) + int_dict = {} + for domain, expr in self._entity_orientations.items(): + integral_type = info.domain_integral_type_map[domain] + int_dict[domain] = expr['+'].expression if integral_type in ["interior_facet", "interior_facet_vert", "interior_facet_horiz"] else None + active_domain_numbers_interior_facet_orientations, args_ = self.make_active_domain_numbers(int_dict, + active_variables, + kernel_args.InteriorFacetOrientationKernelArg, + dtype=gem.uint_type) + args.extend(args_) for name_, shape in tabulations: tab_loopy_arg = lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape) args.append(kernel_args.TabulationKernelArg(tab_loopy_arg)) @@ -438,9 +498,14 @@ def construct_kernel(self, name, ctx, log=False): integral_type=info.integral_type, subdomain_id=info.subdomain_id, domain_number=info.domain_number, + active_domain_numbers=ActiveDomainNumbers(coordinates=tuple(active_domain_numbers_coordinates), + cell_orientations=tuple(active_domain_numbers_cell_orientations), + cell_sizes=tuple(active_domain_numbers_cell_sizes), + exterior_facets=tuple(active_domain_numbers_exterior_facets), + interior_facets=tuple(active_domain_numbers_interior_facets), + exterior_facet_orientations=tuple(active_domain_numbers_exterior_facet_orientations), + interior_facet_orientations=tuple(active_domain_numbers_interior_facet_orientations),), coefficient_numbers=tuple(zip(info.coefficient_numbers, coefficient_indices)), - oriented=oriented, - needs_cell_sizes=needs_cell_sizes, tabulations=tabulations, flop_count=flop_count, name=name, @@ -453,3 +518,36 @@ def construct_empty_kernel(self, name): :returns: None """ return None + + def make_active_domain_numbers(self, domain_expr_dict, active_variables, kernel_arg_type, dtype=None): + """Make active domain numbers. + + Parameters + ---------- + domain_expr_dict : dict + Map from domains to expressions; must be ordered as extract_domains(form). + active_variables : tuple + Active variables in the DAG. + kernel_arg_type : KernelArg + Type of `KernelArg`. + dtype : numpy.dtype + dtype. + + Returns + ------- + tuple + Tuple of active domain numbers and corresponding kernel args. + + """ + active_dns = [] + args = [] + for i, expr in enumerate(domain_expr_dict.values()): + if expr is None: + var = None + else: + var, = gem.extract_type(expr if isinstance(expr, tuple) else (expr, ), gem.Variable) + if var in active_variables: + funarg = self.generate_arg_from_expression(expr, dtype=dtype) + args.append(kernel_arg_type(funarg)) + active_dns.append(i) + return tuple(active_dns), tuple(args) diff --git a/tsfc/modified_terminals.py b/tsfc/modified_terminals.py index 8c5162bf..9e694f2e 100644 --- a/tsfc/modified_terminals.py +++ b/tsfc/modified_terminals.py @@ -21,7 +21,7 @@ """Definitions of 'modified terminals', a core concept in uflacs.""" from ufl.classes import (ReferenceValue, ReferenceGrad, - NegativeRestricted, PositiveRestricted, + NegativeRestricted, PositiveRestricted, SingleValueRestricted, ToBeRestricted, Restricted, ConstantValue, Jacobian, SpatialCoordinate, Zero) from ufl.checks import is_cellwise_constant @@ -39,7 +39,7 @@ class ModifiedTerminal(object): terminal - the underlying Terminal object local_derivatives - tuple of ints, each meaning derivative in that local direction reference_value - bool, whether this is represented in reference frame - restriction - None, '+' or '-' + restriction - None, '+', '-', '|', or '?' """ def __init__(self, expr, terminal, local_derivatives, restriction, reference_value): @@ -175,5 +175,9 @@ def construct_modified_terminal(mt, terminal): expr = PositiveRestricted(expr) elif mt.restriction == '-': expr = NegativeRestricted(expr) + elif mt.restriction == '|': + expr = SingleValueRestricted(expr) + elif mt.restriction == '?': + expr = ToBeRestricted(expr) return expr diff --git a/tsfc/ufl_utils.py b/tsfc/ufl_utils.py index 46192663..ae5fd0da 100644 --- a/tsfc/ufl_utils.py +++ b/tsfc/ufl_utils.py @@ -43,8 +43,11 @@ def compute_form_data(form, do_apply_integral_scaling=True, do_apply_geometry_lowering=True, preserve_geometry_types=preserve_geometry_types, + do_apply_default_restrictions=True, do_apply_restrictions=True, do_estimate_degrees=True, + do_split_coefficients=None, + do_assume_single_integral_type=False, complex_mode=False): """Preprocess UFL form in a format suitable for TSFC. Return form data. @@ -59,8 +62,11 @@ def compute_form_data(form, do_apply_integral_scaling=do_apply_integral_scaling, do_apply_geometry_lowering=do_apply_geometry_lowering, preserve_geometry_types=preserve_geometry_types, + do_apply_default_restrictions=do_apply_default_restrictions, do_apply_restrictions=do_apply_restrictions, do_estimate_degrees=do_estimate_degrees, + do_split_coefficients=do_split_coefficients, + do_assume_single_integral_type=do_assume_single_integral_type, complex_mode=complex_mode ) constants = extract_firedrake_constants(form) @@ -166,6 +172,8 @@ def _modified_terminal(self, o): positive_restricted = _modified_terminal negative_restricted = _modified_terminal + single_value_restricted = _modified_terminal + to_be_restricted = _modified_terminal reference_grad = _modified_terminal reference_value = _modified_terminal @@ -250,8 +258,13 @@ def modified_terminal(self, o): mt = analyse_modified_terminal(o) t = mt.terminal r = mt.restriction - if isinstance(t, Argument) and r != self.restrictions[t.number()]: - return Zero(o.ufl_shape, o.ufl_free_indices, o.ufl_index_dimensions) + if r == '?': + raise RuntimeError("Not expecting '?' restriction at this stage") + if isinstance(t, Argument) and r in ['+', '-']: + if r == self.restrictions[t.number()]: + return o + else: + return Zero(o.ufl_shape, o.ufl_free_indices, o.ufl_index_dimensions) else: return o