From 3a81fe838b34af3c8e49a960f93a0c87404d738b Mon Sep 17 00:00:00 2001
From: Chris Richardson
Date: Fri, 25 Aug 2023 14:40:31 +0100
Subject: [PATCH] Adjust for dtype

---
 ffcx/codegeneration/C/c_implementation.py   | 378 ++++++++++++++++++++
 ffcx/codegeneration/C/expressions.py        |   6 +-
 ffcx/codegeneration/C/integrals.py          |   5 +-
 ffcx/codegeneration/access.py               |  19 +-
 ffcx/codegeneration/backend.py              |  17 +-
 ffcx/codegeneration/definitions.py          |  17 +-
 ffcx/codegeneration/expression_generator.py |  17 +-
 ffcx/codegeneration/geometry.py             |  16 +-
 ffcx/codegeneration/integral_generator.py   | 117 +++----
 ffcx/codegeneration/symbols.py              |  50 ++-
 ffcx/codegeneration/utils.py                |  34 ++
 11 files changed, 510 insertions(+), 166 deletions(-)
 create mode 100644 ffcx/codegeneration/C/c_implementation.py
 create mode 100644 ffcx/codegeneration/utils.py

diff --git a/ffcx/codegeneration/C/c_implementation.py b/ffcx/codegeneration/C/c_implementation.py
new file mode 100644
index 000000000..f735673c8
--- /dev/null
+++ b/ffcx/codegeneration/C/c_implementation.py
@@ -0,0 +1,378 @@
+# Copyright (C) 2023 Chris Richardson
+#
+# This file is part of FFCx. (https://www.fenicsproject.org)
+#
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Format LNodes statements and expressions as C source code."""
+
+import warnings
+import ffcx.codegeneration.lnodes as L
+from ffcx.codegeneration.utils import scalar_to_value_type
+
+# C math function names, keyed by C scalar type name and then by the
+# language-neutral function name. A name with no entry is emitted
+# unchanged by CFormatter.format_math_function.
+math_table = {
+    "double": {
+        "sqrt": "sqrt",
+        "abs": "fabs",
+        "cos": "cos",
+        "sin": "sin",
+        "tan": "tan",
+        "acos": "acos",
+        "asin": "asin",
+        "atan": "atan",
+        "cosh": "cosh",
+        "sinh": "sinh",
+        "tanh": "tanh",
+        "acosh": "acosh",
+        "asinh": "asinh",
+        "atanh": "atanh",
+        "power": "pow",
+        "exp": "exp",
+        "ln": "log",
+        "erf": "erf",
+        "atan_2": "atan2",
+        "min_value": "fmin",
+        "max_value": "fmax",
+        "bessel_y": "yn",
+        "bessel_j": "jn",
+    },
+    "float": {
+        "sqrt": "sqrtf",
+        "abs": "fabsf",
+        "cos": "cosf",
+        "sin": "sinf",
+        "tan": "tanf",
+        "acos": "acosf",
+        "asin": "asinf",
+        "atan": "atanf",
+        "cosh": "coshf",
+        "sinh": "sinhf",
+        "tanh": "tanhf",
+        "acosh": "acoshf",
+        "asinh": "asinhf",
+        "atanh": "atanhf",
+        "power": "powf",
+        "exp": "expf",
+        "ln": "logf",
+        "erf": "erff",
+        "atan_2": "atan2f",
+        "min_value": "fminf",
+        "max_value": "fmaxf",
+        "bessel_y": "yn",
+        "bessel_j": "jn",
+    },
+    "long double": {
+        "sqrt": "sqrtl",
+        "abs": "fabsl",
+        "cos": "cosl",
+        "sin": "sinl",
+        "tan": "tanl",
+        "acos": "acosl",
+        "asin": "asinl",
+        "atan": "atanl",
+        "cosh": "coshl",
+        "sinh": "sinhl",
+        "tanh": "tanhl",
+        "acosh": "acoshl",
+        "asinh": "asinhl",
+        "atanh": "atanhl",
+        "power": "powl",
+        "exp": "expl",
+        "ln": "logl",
+        "erf": "erfl",
+        "atan_2": "atan2l",
+        "min_value": "fminl",
+        "max_value": "fmaxl",
+    },
+    "double _Complex": {
+        "sqrt": "csqrt",
+        "abs": "cabs",
+        "cos": "ccos",
+        "sin": "csin",
+        "tan": "ctan",
+        "acos": "cacos",
+        "asin": "casin",
+        "atan": "catan",
+        "cosh": "ccosh",
+        "sinh": "csinh",
+        "tanh": "ctanh",
+        "acosh": "cacosh",
+        "asinh": "casinh",
+        "atanh": "catanh",
+        "power": "cpow",
+        "exp": "cexp",
+        "ln": "clog",
+        "real": "creal",
+        "imag": "cimag",
+        "conj": "conj",
+        "max_value": "fmax",
+        "min_value": "fmin",
+        "bessel_y": "yn",
+        "bessel_j": "jn",
+    },
+    "float _Complex": {
+        "sqrt": "csqrtf",
+        "abs": "cabsf",
+        "cos": "ccosf",
+        "sin": "csinf",
+        "tan": "ctanf",
+        "acos": "cacosf",
+        "asin": "casinf",
+        "atan": "catanf",
+        "cosh": "ccoshf",
+        "sinh": "csinhf",
+        "tanh": "ctanhf",
+        "acosh": "cacoshf",
+        "asinh": "casinhf",
+        "atanh": "catanhf",
+        "power": "cpowf",
+        "exp": "cexpf",
+        "ln": "clogf",
+        "real": "crealf",
+        "imag": "cimagf",
+        "conj": "conjf",
+        "max_value": "fmaxf",
+        "min_value": "fminf",
+        "bessel_y": "yn",
+        "bessel_j": "jn",
+    },
+}
+
+
+def build_initializer_lists(values):
+    """Return a C brace initializer for an array of values.
+
+    ``values`` must have a ``shape`` and be iterable along its first
+    axis. A one-dimensional input gives ``{a, b, c}``; every further
+    dimension adds one level of nested braces.
+    """
+    arr = "{"
+    if len(values.shape) == 1:
+        arr += ", ".join(str(v) for v in values)
+    elif len(values.shape) > 1:
+        arr += ",\n ".join(build_initializer_lists(v) for v in values)
+    arr += "}"
+    return arr
+
+
+class CFormatter(object):
+    """Translate LNodes statements and expressions into C code strings."""
+
+    def __init__(self, scalar) -> None:
+        """Store the C scalar type name and derive the matching real type."""
+        self.scalar_type = scalar
+        self.real_type = scalar_to_value_type(scalar)
+
+    def _dtype_to_name(self, dtype) -> str:
+        """Return the C type name for an LNodes data type.
+
+        Raises ValueError for a dtype that has no C type here, so an
+        unsupported declaration fails with a clear message.
+        """
+        if dtype == L.DataType.SCALAR:
+            return self.scalar_type
+        if dtype == L.DataType.REAL:
+            return self.real_type
+        if dtype == L.DataType.INT:
+            return "int"
+        raise ValueError(f"Invalid dtype: {dtype}")
+
+    def format_statement_list(self, slist) -> str:
+        """Format each statement of the list and concatenate the results."""
+        return "".join(self.c_format(s) for s in slist.statements)
+
+    def format_comment(self, c) -> str:
+        """Format a comment as a ``//`` line comment."""
+        return "// " + c.comment + "\n"
+
+    def format_array_decl(self, arr) -> str:
+        """Format an array declaration, with an initializer if it has values."""
+        dtype = arr.symbol.dtype
+        assert dtype is not None
+        typename = self._dtype_to_name(dtype)
+
+        symbol = self.c_format(arr.symbol)
+        dims = "".join([f"[{i}]" for i in arr.sizes])
+        if arr.values is None:
+            assert arr.const is False
+            return f"{typename} {symbol}{dims};\n"
+
+        vals = build_initializer_lists(arr.values)
+        cstr = "static const " if arr.const else ""
+        return f"{cstr}{typename} {symbol}{dims} = {vals};\n"
+
+    def format_array_access(self, arr) -> str:
+        """Format an array element access such as ``A[i][j]``."""
+        name = self.c_format(arr.array)
+        indices = f"[{']['.join(self.c_format(i) for i in arr.indices)}]"
+        return f"{name}{indices}"
+
+    def format_variable_decl(self, v) -> str:
+        """Format the declaration and initialisation of one variable."""
+        val = self.c_format(v.value)
+        symbol = self.c_format(v.symbol)
+        assert v.symbol.dtype
+        typename = self._dtype_to_name(v.symbol.dtype)
+        return f"{typename} {symbol} = {val};\n"
+
+    def format_nary_op(self, oper) -> str:
+        """Format an operator applied to any number of arguments."""
+        args = []
+        for arg in oper.args:
+            text = self.c_format(arg)
+            # Parenthesise arguments that do not bind tighter than oper
+            if arg.precedence >= oper.precedence:
+                text = f"({text})"
+            args.append(text)
+
+        # Return combined string
+        return f" {oper.op} ".join(args)
+
+    def format_binary_op(self, oper) -> str:
+        """Format a binary operator with its two operands."""
+        # Format children
+        lhs = self.c_format(oper.lhs)
+        rhs = self.c_format(oper.rhs)
+
+        # Apply parentheses
+        if oper.lhs.precedence >= oper.precedence:
+            lhs = f"({lhs})"
+        if oper.rhs.precedence >= oper.precedence:
+            rhs = f"({rhs})"
+
+        # Return combined string
+        return f"{lhs} {oper.op} {rhs}"
+
+    def format_neg(self, val) -> str:
+        """Format unary negation, parenthesising the operand when needed."""
+        arg = self.c_format(val.arg)
+        # Unparenthesised, -(a + b) would be emitted as -a + b, and a
+        # negated operand starting with "-" would form the C "--" token.
+        if val.arg.precedence >= val.precedence or arg.startswith("-"):
+            arg = f"({arg})"
+        return f"-{arg}"
+
+    def format_not(self, val) -> str:
+        """Format a prefix operator applied to a parenthesised operand."""
+        arg = self.c_format(val.arg)
+        return f"{val.op}({arg})"
+
+    def format_literal_float(self, val) -> str:
+        """Format a floating point literal."""
+        return f"{val.value}"
+
+    def format_literal_int(self, val) -> str:
+        """Format an integer literal."""
+        return f"{val.value}"
+
+    def format_for_range(self, r) -> str:
+        """Format a ``for`` loop over ``[begin, end)`` with an indented body."""
+        begin = self.c_format(r.begin)
+        end = self.c_format(r.end)
+        index = self.c_format(r.index)
+        output = f"for (int {index} = {begin}; {index} < {end}; ++{index})\n"
+        output += "{\n"
+        body = self.c_format(r.body)
+        for line in body.split("\n"):
+            if len(line) > 0:
+                output += f" {line}\n"
+        output += "}\n"
+        return output
+
+    def format_statement(self, s) -> str:
+        """Format the expression wrapped by a statement."""
+        return self.c_format(s.expr)
+
+    def format_assign(self, expr) -> str:
+        """Format an assignment, using the operator stored on the node."""
+        rhs = self.c_format(expr.rhs)
+        lhs = self.c_format(expr.lhs)
+        return f"{lhs} {expr.op} {rhs};\n"
+
+    def format_conditional(self, s) -> str:
+        """Format a ternary ``condition ? true : false`` expression."""
+        # Format children
+        c = self.c_format(s.condition)
+        t = self.c_format(s.true)
+        f = self.c_format(s.false)
+
+        # Apply parentheses
+        if s.condition.precedence >= s.precedence:
+            c = "(" + c + ")"
+        if s.true.precedence >= s.precedence:
+            t = "(" + t + ")"
+        if s.false.precedence >= s.precedence:
+            f = "(" + f + ")"
+
+        # Return combined string
+        return c + " ? " + t + " : " + f
+
+    def format_symbol(self, s) -> str:
+        """Format a symbol as its name."""
+        return f"{s.name}"
+
+    def format_math_function(self, c) -> str:
+        """Format a math function call, picking the C name for the argument type."""
+        # Get a table of functions for this type, if available
+        arg_type = self.scalar_type
+        if hasattr(c.args[0], "dtype"):
+            if c.args[0].dtype == L.DataType.REAL:
+                arg_type = self.real_type
+        else:
+            warnings.warn(f"Syntax item without dtype {c.args[0]}")
+
+        dtype_math_table = math_table.get(arg_type, {})
+
+        # Get a function from the table, if available, else just use bare name
+        func = dtype_math_table.get(c.function, c.function)
+        args = ", ".join(self.c_format(arg) for arg in c.args)
+        return f"{func}({args})"
+
+    c_impl = {
+        "StatementList": format_statement_list,
+        "Comment": format_comment,
+        "ArrayDecl": format_array_decl,
+        "ArrayAccess": format_array_access,
+        "VariableDecl": format_variable_decl,
+        "ForRange": format_for_range,
+        "Statement": format_statement,
+        "Assign": format_assign,
+        "AssignAdd": format_assign,
+        "Product": format_nary_op,
+        "Neg": format_neg,
+        "Sum": format_nary_op,
+        "Add": format_binary_op,
+        "Sub": format_binary_op,
+        "Mul": format_binary_op,
+        "Div": format_binary_op,
+        "Not": format_not,
+        "LiteralFloat": format_literal_float,
+        "LiteralInt": format_literal_int,
+        "Symbol": format_symbol,
+        "Conditional": format_conditional,
+        "MathFunction": format_math_function,
+        "And": format_binary_op,
+        "Or": format_binary_op,
+        "NE": format_binary_op,
+        "EQ": format_binary_op,
+        "GE": format_binary_op,
+        "LE": format_binary_op,
+        "GT": format_binary_op,
+        "LT": format_binary_op,
+    }
+
+    def c_format(self, s) -> str:
+        """Format any supported LNodes object, dispatching on its class name.
+
+        Raises RuntimeError if the class has no registered formatter.
+        """
+        name = s.__class__.__name__
+        # Guard only the lookup, so that a KeyError raised inside a
+        # formatter is not misreported as an unknown statement.
+        try:
+            formatter = self.c_impl[name]
+        except KeyError:
+            raise RuntimeError(f"Unknown statement: {name}") from None
+        return formatter(self, s)
diff --git
a/ffcx/codegeneration/C/expressions.py b/ffcx/codegeneration/C/expressions.py index a994b8c5c..530c0ab2a 100644 --- a/ffcx/codegeneration/C/expressions.py +++ b/ffcx/codegeneration/C/expressions.py @@ -9,7 +9,7 @@ from ffcx.codegeneration.C import expressions_template from ffcx.codegeneration.expression_generator import ExpressionGenerator from ffcx.codegeneration.backend import FFCXBackend -from ffcx.codegeneration.C.format_lines import format_indented_lines +from ffcx.codegeneration.C.c_implementation import CFormatter from ffcx.naming import cdtype_to_numpy, scalar_to_value_type logger = logging.getLogger("ffcx") @@ -36,8 +36,8 @@ def generator(ir, options): parts = eg.generate() - body = format_indented_lines(parts.cs_format(), 1) - d["tabulate_expression"] = body + CF = CFormatter(options["scalar_type"]) + d["tabulate_expression"] = CF.c_format(parts) if len(ir.original_coefficient_positions) > 0: d["original_coefficient_positions"] = f"original_coefficient_positions_{ir.name}" diff --git a/ffcx/codegeneration/C/integrals.py b/ffcx/codegeneration/C/integrals.py index b1b73c525..c6916591a 100644 --- a/ffcx/codegeneration/C/integrals.py +++ b/ffcx/codegeneration/C/integrals.py @@ -9,7 +9,7 @@ from ffcx.codegeneration.integral_generator import IntegralGenerator from ffcx.codegeneration.C import integrals_template as ufcx_integrals from ffcx.codegeneration.backend import FFCXBackend -from ffcx.codegeneration.C.format_lines import format_indented_lines +from ffcx.codegeneration.C.c_implementation import CFormatter from ffcx.naming import cdtype_to_numpy, scalar_to_value_type logger = logging.getLogger("ffcx") @@ -36,7 +36,8 @@ def generator(ir, options): parts = ig.generate() # Format code as string - body = format_indented_lines(parts.cs_format(ir.precision), 1) + CF = CFormatter(options["scalar_type"]) + body = CF.c_format(parts) # Generate generic FFCx code snippets and add specific parts code = {} diff --git a/ffcx/codegeneration/access.py 
b/ffcx/codegeneration/access.py index 9d1dcc985..b88de608b 100644 --- a/ffcx/codegeneration/access.py +++ b/ffcx/codegeneration/access.py @@ -11,6 +11,7 @@ import ufl import basix.ufl from ffcx.element_interface import convert_element +import ffcx.codegeneration.lnodes as L logger = logging.getLogger("ffcx") @@ -18,12 +19,11 @@ class FFCXBackendAccess(object): """FFCx specific cpp formatter class.""" - def __init__(self, ir, language, symbols, options): + def __init__(self, ir, symbols, options): # Store ir and options self.entitytype = ir.entitytype self.integral_type = ir.integral_type - self.language = language self.symbols = symbols self.options = options @@ -178,10 +178,9 @@ def jacobian(self, e, mt, tabledata, num_points): return self.symbols.J_component(mt) def reference_cell_volume(self, e, mt, tabledata, access): - L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"): - return L.Symbol(f"{cellname}_reference_cell_volume") + return L.Symbol(f"{cellname}_reference_cell_volume", dtype=L.DataType.REAL) else: raise RuntimeError(f"Unhandled cell types {cellname}.") @@ -189,7 +188,7 @@ def reference_facet_volume(self, e, mt, tabledata, access): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"): - return L.Symbol(f"{cellname}_reference_facet_volume") + return L.Symbol(f"{cellname}_reference_facet_volume", dtype=L.DataType.REAL) else: raise RuntimeError(f"Unhandled cell types {cellname}.") @@ -197,7 +196,7 @@ def reference_normal(self, e, mt, tabledata, access): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"): - table = L.Symbol(f"{cellname}_reference_facet_normals") + table 
= L.Symbol(f"{cellname}_reference_facet_normals", dtype=L.DataType.REAL) facet = self.symbols.entity("facet", mt.restriction) return table[facet][mt.component[0]] else: @@ -207,7 +206,7 @@ def cell_facet_jacobian(self, e, mt, tabledata, num_points): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"): - table = L.Symbol(f"{cellname}_reference_facet_jacobian") + table = L.Symbol(f"{cellname}_reference_facet_jacobian", dtype=L.DataType.REAL) facet = self.symbols.entity("facet", mt.restriction) return table[facet][mt.component[0]][mt.component[1]] elif cellname == "interval": @@ -219,7 +218,7 @@ def reference_cell_edge_vectors(self, e, mt, tabledata, num_points): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"): - table = L.Symbol(f"{cellname}_reference_edge_vectors") + table = L.Symbol(f"{cellname}_reference_edge_vectors", dtype=L.DataType.REAL) return table[mt.component[0]][mt.component[1]] elif cellname == "interval": raise RuntimeError("The reference cell edge vectors doesn't make sense for interval cell.") @@ -230,7 +229,7 @@ def reference_facet_edge_vectors(self, e, mt, tabledata, num_points): L = self.language cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() if cellname in ("tetrahedron", "hexahedron"): - table = L.Symbol(f"{cellname}_reference_edge_vectors") + table = L.Symbol(f"{cellname}_reference_edge_vectors", dtype=L.DataType.REAL) facet = self.symbols.entity("facet", mt.restriction) return table[facet][mt.component[0]][mt.component[1]] elif cellname in ("interval", "triangle", "quadrilateral"): @@ -246,7 +245,7 @@ def facet_orientation(self, e, mt, tabledata, num_points): if cellname not in ("interval", "triangle", "tetrahedron"): raise RuntimeError(f"Unhandled cell types {cellname}.") - 
table = L.Symbol(f"{cellname}_facet_orientations") + table = L.Symbol(f"{cellname}_facet_orientations", dtype=L.DataType.INT) facet = self.symbols.entity("facet", mt.restriction) return table[facet] diff --git a/ffcx/codegeneration/backend.py b/ffcx/codegeneration/backend.py index 0b9c5d8d2..b874196e4 100644 --- a/ffcx/codegeneration/backend.py +++ b/ffcx/codegeneration/backend.py @@ -5,11 +5,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Collection of FFCx specific pieces for the code generation phase.""" -import types - -import ffcx.codegeneration.C.cnodes from ffcx.codegeneration.access import FFCXBackendAccess -from ffcx.codegeneration.C.ufl_to_cnodes import UFL2CNodesTranslatorCpp from ffcx.codegeneration.definitions import FFCXBackendDefinitions from ffcx.codegeneration.symbols import FFCXBackendSymbols @@ -19,19 +15,12 @@ class FFCXBackend(object): def __init__(self, ir, options): - # This is the seam where cnodes/C is chosen for the FFCx backend - self.language: types.ModuleType = ffcx.codegeneration.C.cnodes - scalar_type = options["scalar_type"] - self.ufl_to_language = UFL2CNodesTranslatorCpp(self.language, scalar_type) - coefficient_numbering = ir.coefficient_numbering coefficient_offsets = ir.coefficient_offsets original_constant_offsets = ir.original_constant_offsets - self.symbols = FFCXBackendSymbols(self.language, coefficient_numbering, + self.symbols = FFCXBackendSymbols(coefficient_numbering, coefficient_offsets, original_constant_offsets) - self.definitions = FFCXBackendDefinitions(ir, self.language, - self.symbols, options) - self.access = FFCXBackendAccess(ir, self.language, self.symbols, - options) + self.definitions = FFCXBackendDefinitions(ir, self.symbols, options) + self.access = FFCXBackendAccess(ir, self.symbols, options) diff --git a/ffcx/codegeneration/definitions.py b/ffcx/codegeneration/definitions.py index 1b26de95f..07390fb1c 100644 --- a/ffcx/codegeneration/definitions.py +++ b/ffcx/codegeneration/definitions.py @@ -9,7 
+9,7 @@ import ufl from ffcx.element_interface import convert_element -from ffcx.naming import scalar_to_value_type +import ffcx.codegeneration.lnodes as L logger = logging.getLogger("ffcx") @@ -17,11 +17,10 @@ class FFCXBackendDefinitions(object): """FFCx specific code definitions.""" - def __init__(self, ir, language, symbols, options): + def __init__(self, ir, symbols, options): # Store ir and options self.integral_type = ir.integral_type self.entitytype = ir.entitytype - self.language = language self.symbols = symbols self.options = options @@ -64,8 +63,6 @@ def get(self, t, mt, tabledata, quadrature_rule, access): def coefficient(self, t, mt, tabledata, quadrature_rule, access): """Return definition code for coefficients.""" - L = self.language - ttype = tabledata.ttype num_dofs = tabledata.values.shape[3] bs = tabledata.block_size @@ -106,7 +103,7 @@ def coefficient(self, t, mt, tabledata, quadrature_rule, access): dof_access = self.symbols.coefficient_dof_access(mt.terminal, ic * bs + begin) body = [L.AssignAdd(access, dof_access * FE[ic])] - code += [L.VariableDecl(self.options["scalar_type"], access, 0.0)] + code += [L.VariableDecl(access, 0.0)] code += [L.ForRange(ic, 0, num_dofs, body)] return pre_code, code @@ -119,8 +116,6 @@ def constant(self, t, mt, tabledata, quadrature_rule, access): def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, access): """Define x or J as a linear combination of coordinate dofs with given table data.""" - L = self.language - # Get properties of domain domain = ufl.domain.extract_unique_domain(mt.terminal) coordinate_element = domain.ufl_coordinate_element() @@ -140,7 +135,7 @@ def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, acc # Get access to element table FE = self.symbols.element_table(tabledata, self.entitytype, mt.restriction) ic = self.symbols.coefficient_dof_sum_index() - dof_access = self.symbols.S("coordinate_dofs") + dof_access = L.Symbol("coordinate_dofs", 
dtype=L.DataType.REAL) # coordinate dofs is always 3d dim = 3 @@ -148,11 +143,9 @@ def _define_coordinate_dofs_lincomb(self, e, mt, tabledata, quadrature_rule, acc if mt.restriction == "-": offset = num_scalar_dofs * dim - value_type = scalar_to_value_type(self.options["scalar_type"]) - code = [] body = [L.AssignAdd(access, dof_access[ic * dim + begin + offset] * FE[ic])] - code += [L.VariableDecl(f"{value_type}", access, 0.0)] + code += [L.VariableDecl(access, 0.0)] code += [L.ForRange(ic, 0, num_scalar_dofs, body)] return [], code diff --git a/ffcx/codegeneration/expression_generator.py b/ffcx/codegeneration/expression_generator.py index 553e8b315..28ee90e29 100644 --- a/ffcx/codegeneration/expression_generator.py +++ b/ffcx/codegeneration/expression_generator.py @@ -12,7 +12,8 @@ import ufl from ffcx.codegeneration import geometry from ffcx.codegeneration.backend import FFCXBackend -from ffcx.codegeneration.C.cnodes import CNode +import ffcx.codegeneration.lnodes as L +from ffcx.codegeneration.lnodes import LNode from ffcx.ir.representation import ExpressionIR from ffcx.naming import scalar_to_value_type @@ -27,7 +28,7 @@ def __init__(self, ir: ExpressionIR, backend: FFCXBackend): self.ir = ir self.backend = backend - self.scope: Dict[Any, CNode] = {} + self.scope: Dict[Any, LNode] = {} self._ufl_names: Set[Any] = set() self.symbol_counters: DefaultDict[Any, int] = collections.defaultdict(int) self.shared_symbols: Dict[Any, Any] = {} @@ -58,10 +59,8 @@ def generate(self): return L.StatementList(parts) - def generate_geometry_tables(self, float_type: str): + def generate_geometry_tables(self): """Generate static tables of geometry data.""" - L = self.backend.language - # Currently we only support circumradius ufl_geometry = { ufl.geometry.ReferenceCellVolume: "reference_cell_volume", @@ -79,24 +78,20 @@ def generate_geometry_tables(self, float_type: str): parts = [] for i, cell_list in cells.items(): for c in cell_list: - parts.append(geometry.write_table(L, 
ufl_geometry[i], c, float_type)) + parts.append(geometry.write_table(L, ufl_geometry[i], c)) return parts def generate_element_tables(self, float_type: str): """Generate tables of FE basis evaluated at specified points.""" - L = self.backend.language parts = [] tables = self.ir.unique_tables - - padlen = self.ir.options["padlen"] table_names = sorted(tables) for name in table_names: table = tables[name] - decl = L.ArrayDecl( - f"static const {float_type}", name, table.shape, table, padlen=padlen) + decl = L.ArrayDecl(name, table) parts += [decl] # Add leading comment if there are any tables diff --git a/ffcx/codegeneration/geometry.py b/ffcx/codegeneration/geometry.py index 2df2bcd93..271a438fa 100644 --- a/ffcx/codegeneration/geometry.py +++ b/ffcx/codegeneration/geometry.py @@ -9,23 +9,23 @@ import basix -def write_table(L, tablename, cellname, type: str): +def write_table(L, tablename, cellname): if tablename == "facet_edge_vertices": return facet_edge_vertices(L, tablename, cellname) if tablename == "reference_facet_jacobian": - return reference_facet_jacobian(L, tablename, cellname, type) + return reference_facet_jacobian(L, tablename, cellname) if tablename == "reference_cell_volume": - return reference_cell_volume(L, tablename, cellname, type) + return reference_cell_volume(L, tablename, cellname) if tablename == "reference_facet_volume": - return reference_facet_volume(L, tablename, cellname, type) + return reference_facet_volume(L, tablename, cellname) if tablename == "reference_edge_vectors": - return reference_edge_vectors(L, tablename, cellname, type) + return reference_edge_vectors(L, tablename, cellname) if tablename == "facet_reference_edge_vectors": - return facet_reference_edge_vectors(L, tablename, cellname, type) + return facet_reference_edge_vectors(L, tablename, cellname) if tablename == "reference_facet_normals": - return reference_facet_normals(L, tablename, cellname, type) + return reference_facet_normals(L, tablename, cellname) if tablename 
== "facet_orientation": - return facet_orientation(L, tablename, cellname, type) + return facet_orientation(L, tablename, cellname) raise ValueError(f"Unknown geometry table name: {tablename}") diff --git a/ffcx/codegeneration/integral_generator.py b/ffcx/codegeneration/integral_generator.py index 1ebcef70f..c58111737 100644 --- a/ffcx/codegeneration/integral_generator.py +++ b/ffcx/codegeneration/integral_generator.py @@ -1,4 +1,4 @@ -# Copyright (C) 2015-2021 Martin Sandve Alnæs, Michal Habera, Igor Baratta +# Copyright (C) 2015-2023 Martin Sandve Alnæs, Michal Habera, Igor Baratta, Chris Richardson # # This file is part of FFCx. (https://www.fenicsproject.org) # @@ -10,11 +10,11 @@ import ufl from ffcx.codegeneration import geometry -from ffcx.codegeneration.C.cnodes import BinOp, CNode from ffcx.ir.elementtables import piecewise_ttypes from ffcx.ir.integral import BlockDataT +import ffcx.codegeneration.lnodes as L +from ffcx.codegeneration.lnodes import LNode, BinOp from ffcx.ir.representationutils import QuadratureRule -from ffcx.naming import scalar_to_value_type logger = logging.getLogger("ffcx") @@ -25,11 +25,11 @@ def __init__(self, ir, backend): self.ir = ir # Backend specific plugin with attributes - # - language: for translating ufl operators to target language # - symbols: for translating ufl operators to target language # - definitions: for defining backend specific variables # - access: for accessing backend specific variables self.backend = backend + self.ufl_to_language = L.UFL2LNodes() # Set of operator names code has been generated for, used in the # end for selecting necessary includes @@ -57,7 +57,7 @@ def set_var(self, quadrature_rule, v, vaccess): Scope is determined by quadrature_rule which identifies the quadrature loop scope or None if outside quadrature loops. - v is the ufl expression and vaccess is the CNodes + v is the ufl expression and vaccess is the LNodes expression to access the value in the code. 
""" @@ -72,10 +72,10 @@ def get_var(self, quadrature_rule, v): If v is not found in quadrature loop scope, the piecewise scope (None) is checked. - Returns the CNodes expression to access the value in the code. + Returns the LNodes expression to access the value in the code. """ if v._ufl_is_literal_: - return self.backend.ufl_to_language.get(v) + return self.ufl_to_language.get(v) f = self.scopes[quadrature_rule].get(v) if f is None: f = self.scopes[None].get(v) @@ -83,13 +83,12 @@ def get_var(self, quadrature_rule, v): def new_temp_symbol(self, basename): """Create a new code symbol named basename + running counter.""" - L = self.backend.language name = "%s%d" % (basename, self.symbol_counters[basename]) self.symbol_counters[basename] += 1 - return L.Symbol(name) + return L.Symbol(name, dtype=L.DataType.SCALAR) def get_temp_symbol(self, tempname, key): - key = (tempname, ) + key + key = (tempname,) + key s = self.shared_symbols.get(key) defined = s is not None if not defined: @@ -104,32 +103,21 @@ def generate(self): context that matches a suitable version of the UFC tabulate_tensor signatures. 
""" - L = self.backend.language - # Assert that scopes are empty: expecting this to be called only # once assert not any(d for d in self.scopes.values()) parts = [] - scalar_type = self.backend.access.options["scalar_type"] - value_type = scalar_to_value_type(scalar_type) - alignment = self.ir.options['assume_aligned'] - if alignment != -1: - scalar_type = self.backend.access.options["scalar_type"] - parts += [L.VerbatimStatement(f"A = ({scalar_type}*)__builtin_assume_aligned(A, {alignment});"), - L.VerbatimStatement(f"w = (const {scalar_type}*)__builtin_assume_aligned(w, {alignment});"), - L.VerbatimStatement(f"c = (const {scalar_type}*)__builtin_assume_aligned(c, {alignment});"), - L.VerbatimStatement(f"coordinate_dofs = (const {value_type}*)__builtin_assume_aligned(coordinate_dofs, {alignment});")] # noqa # Generate the tables of quadrature points and weights - parts += self.generate_quadrature_tables(value_type) + parts += self.generate_quadrature_tables() # Generate the tables of basis function values and # pre-integrated blocks - parts += self.generate_element_tables(value_type) + parts += self.generate_element_tables() # Generate the tables of geometry data that are needed - parts += self.generate_geometry_tables(value_type) + parts += self.generate_geometry_tables() # Loop generation code will produce parts to go before # quadloops, to define the quadloops, and to go after the @@ -160,11 +148,9 @@ def generate(self): return L.StatementList(parts) - def generate_quadrature_tables(self, value_type: str) -> List[str]: + def generate_quadrature_tables(self): """Generate static tables of quadrature points and weights.""" - L = self.backend.language - - parts: List[str] = [] + parts = [] # No quadrature tables for custom (given argument) or point # (evaluation in single vertex) @@ -172,25 +158,18 @@ def generate_quadrature_tables(self, value_type: str) -> List[str]: if self.ir.integral_type in skip: return parts - padlen = self.ir.options["padlen"] - # Loop over 
quadrature rules for quadrature_rule, integrand in self.ir.integrand.items(): - num_points = quadrature_rule.weights.shape[0] - # Generate quadrature weights array wsym = self.backend.symbols.weights_table(quadrature_rule) - parts += [L.ArrayDecl(f"static const {value_type}", wsym, num_points, - quadrature_rule.weights, padlen=padlen)] + parts += [L.ArrayDecl(wsym, values=quadrature_rule.weights, const=True)] # Add leading comment if there are any tables parts = L.commented_code_list(parts, "Quadrature rules") return parts - def generate_geometry_tables(self, float_type: str): + def generate_geometry_tables(self): """Generate static tables of geometry data.""" - L = self.backend.language - ufl_geometry = { ufl.geometry.FacetEdgeVectors: "facet_edge_vertices", ufl.geometry.CellFacetJacobian: "reference_facet_jacobian", @@ -214,17 +193,15 @@ def generate_geometry_tables(self, float_type: str): parts = [] for i, cell_list in cells.items(): for c in cell_list: - parts.append(geometry.write_table(L, ufl_geometry[i], c, float_type)) + parts.append(geometry.write_table(L, ufl_geometry[i], c)) return parts - def generate_element_tables(self, float_type: str): + def generate_element_tables(self): """Generate static tables with precomputed element basisfunction values in quadrature points.""" - L = self.backend.language parts = [] tables = self.ir.unique_tables table_types = self.ir.unique_table_types - padlen = self.ir.options["padlen"] if self.ir.integral_type in ufl.custom_integral_types: # Define only piecewise tables table_names = [name for name in sorted(tables) if table_types[name] in piecewise_ttypes] @@ -234,7 +211,7 @@ def generate_element_tables(self, float_type: str): for name in table_names: table = tables[name] - parts += self.declare_table(name, table, padlen, float_type) + parts += self.declare_table(name, table) # Add leading comment if there are any tables parts = L.commented_code_list(parts, [ @@ -242,19 +219,18 @@ def generate_element_tables(self, 
float_type: str): "FE* dimensions: [permutation][entities][points][dofs]"]) return parts - def declare_table(self, name, table, padlen, value_type: str): + def declare_table(self, name, table): """Declare a table. If the dof dimensions of the table have dof rotations, apply these rotations. """ - L = self.backend.language - return [L.ArrayDecl(f"static const {value_type}", name, table.shape, table, padlen=padlen)] + table_symbol = L.Symbol(name, dtype=L.DataType.REAL) + return [L.ArrayDecl(table_symbol, values=table, const=True)] def generate_quadrature_loop(self, quadrature_rule: QuadratureRule): """Generate quadrature loop with for this quadrature_rule.""" - L = self.backend.language # Generate varying partition pre_definitions, body = self.generate_varying_partition(quadrature_rule) @@ -278,12 +254,10 @@ def generate_quadrature_loop(self, quadrature_rule: QuadratureRule): return pre_definitions, preparts, quadparts def generate_piecewise_partition(self, quadrature_rule): - L = self.backend.language - # Get annotated graph of factorisation F = self.ir.integrand[quadrature_rule]["factorization"] - arraysymbol = L.Symbol(f"sp_{quadrature_rule.id()}") + arraysymbol = L.Symbol(f"sp_{quadrature_rule.id()}", dtype=L.DataType.SCALAR) pre_definitions, parts = self.generate_partition(arraysymbol, F, "piecewise", None) assert len(pre_definitions) == 0, "Quadrature independent code should have not pre-definitions" parts = L.commented_code_list( @@ -292,19 +266,17 @@ def generate_piecewise_partition(self, quadrature_rule): return parts def generate_varying_partition(self, quadrature_rule): - L = self.backend.language # Get annotated graph of factorisation F = self.ir.integrand[quadrature_rule]["factorization"] - arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}") + arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}", dtype=L.DataType.SCALAR) pre_definitions, parts = self.generate_partition(arraysymbol, F, "varying", quadrature_rule) parts = L.commented_code_list(parts, 
f"Varying computations for quadrature rule {quadrature_rule.id()}") return pre_definitions, parts def generate_partition(self, symbol, F, mode, quadrature_rule): - L = self.backend.language definitions = dict() pre_definitions = dict() @@ -322,7 +294,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): # cache if not self.get_var(quadrature_rule, v): if v._ufl_is_literal_: - vaccess = self.backend.ufl_to_language.get(v) + vaccess = self.ufl_to_language.get(v) elif mt is not None: # All finite element based terminals have table # data, as well as some, but not all, of the @@ -352,7 +324,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): # Mapping UFL operator to target language self._ufl_names.add(v._ufl_handler_name_) - vexpr = self.backend.ufl_to_language.get(v, *vops) + vexpr = self.ufl_to_language.get(v, *vops) # Create a new intermediate for each subexpression # except boolean conditions and its childs @@ -379,9 +351,8 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): vaccess = symbol[j] intermediates.append(L.Assign(vaccess, vexpr)) else: - scalar_type = self.backend.access.options["scalar_type"] vaccess = L.Symbol("%s_%d" % (symbol.name, j)) - intermediates.append(L.VariableDecl(f"const {scalar_type}", vaccess, vexpr)) + intermediates.append(L.VariableDecl(vaccess, vexpr)) # Store access node for future reference self.set_var(quadrature_rule, v, vaccess) @@ -393,9 +364,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): if intermediates: if use_symbol_array: - padlen = self.ir.options["padlen"] - parts += [L.ArrayDecl(self.backend.access.options["scalar_type"], - symbol, len(intermediates), padlen=padlen)] + parts += [L.ArrayDecl(symbol, sizes=len(intermediates))] parts += intermediates return pre_definitions, parts @@ -467,11 +436,9 @@ def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, Should be called with quadrature_rule=None for quadloop-independent blocks. 
""" - L = self.backend.language - # The parts to return - preparts: List[CNode] = [] - quadparts: List[CNode] = [] + preparts: List[LNode] = [] + quadparts: List[LNode] = [] # RHS expressions grouped by LHS "dofmap" rhs_expressions = collections.defaultdict(list) @@ -523,8 +490,7 @@ def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, key = (quadrature_rule, factor_index, blockdata.all_factors_piecewise) fw, defined = self.get_temp_symbol("fw", key) if not defined: - scalar_type = self.backend.access.options["scalar_type"] - quadparts.append(L.VariableDecl(f"const {scalar_type}", fw, fw_rhs)) + quadparts.append(L.VariableDecl(fw, fw_rhs)) assert not blockdata.transposed, "Not handled yet" A_shape = self.ir.tensor_shape @@ -551,7 +517,7 @@ def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, # List of statements to keep in the inner loop keep = collections.defaultdict(list) # List of temporary array declarations - pre_loop: List[CNode] = [] + pre_loop: List[LNode] = [] # List of loop invariant expressions to hoist hoist: List[BinOp] = [] @@ -577,34 +543,29 @@ def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, # floating point operations (factorize expressions by # grouping) for statement in hoist_rhs: - sum = [] - for rhs in hoist_rhs[statement]: - sum.append(L.float_product(rhs)) - sum = L.Sum(sum) + sum = L.Sum([L.float_product(rhs) for rhs in hoist_rhs[statement]]) lhs = None for h in hoist: - if (h.rhs == sum): + if h.rhs == sum: lhs = h.lhs break if lhs: keep[indices].append(L.float_product([statement, lhs])) else: t = self.new_temp_symbol("t") - scalar_type = self.backend.access.options["scalar_type"] - pre_loop.append(L.ArrayDecl(scalar_type, t, blockdims[0])) + pre_loop.append(L.ArrayDecl(t, sizes=blockdims[0])) keep[indices].append(L.float_product([statement, t[B_indices[0]]])) hoist.append(L.Assign(t[B_indices[i - 1]], sum)) else: keep[indices] = rhs_expressions[indices] 
- hoist_code: List[CNode] = [L.ForRange(B_indices[0], 0, blockdims[0], body=hoist)] if hoist else [] + hoist_code: List[LNode] = [L.ForRange(B_indices[0], 0, blockdims[0], body=hoist)] if hoist else [] - body: List[CNode] = [] + body: List[LNode] = [] for indices in keep: - sum = L.Sum(keep[indices]) - body.append(L.AssignAdd(A[indices], sum)) + body.append(L.AssignAdd(A[indices], L.Sum(keep[indices]))) for i in reversed(range(block_rank)): body = [L.ForRange(B_indices[i], 0, blockdims[i], body=body)] @@ -626,8 +587,6 @@ def fuse_loops(self, definitions): determine how many loops should fuse at a time. """ - L = self.backend.language - loops = collections.defaultdict(list) pre_loop = [] for access, definition in definitions.items(): diff --git a/ffcx/codegeneration/symbols.py b/ffcx/codegeneration/symbols.py index 7630a0d68..41d6dd20b 100644 --- a/ffcx/codegeneration/symbols.py +++ b/ffcx/codegeneration/symbols.py @@ -7,6 +7,7 @@ import logging import ufl +import ffcx.codegeneration.lnodes as L logger = logging.getLogger("ffcx") @@ -60,10 +61,8 @@ def format_mt_name(basename, mt): class FFCXBackendSymbols(object): """FFCx specific symbol definitions. 
Provides non-ufl symbols.""" - def __init__(self, language, coefficient_numbering, coefficient_offsets, + def __init__(self, coefficient_numbering, coefficient_offsets, original_constant_offsets): - self.L = language - self.S = self.L.Symbol self.coefficient_numbering = coefficient_numbering self.coefficient_offsets = coefficient_offsets @@ -71,71 +70,71 @@ def __init__(self, language, coefficient_numbering, coefficient_offsets, def element_tensor(self): """Symbol for the element tensor itself.""" - return self.S("A") + return L.Symbol("A") def entity(self, entitytype, restriction): """Entity index for lookup in element tables.""" if entitytype == "cell": # Always 0 for cells (even with restriction) - return self.L.LiteralInt(0) + return L.LiteralInt(0) elif entitytype == "facet": postfix = "[0]" if restriction == "-": postfix = "[1]" - return self.S("entity_local_index" + postfix) + return L.Symbol("entity_local_index" + postfix, dtype=L.DataType.INT) elif entitytype == "vertex": - return self.S("entity_local_index[0]") + return L.Symbol("entity_local_index[0]", dtype=L.DataType.INT) else: logging.exception(f"Unknown entitytype {entitytype}") def argument_loop_index(self, iarg): """Loop index for argument #iarg.""" indices = ["i", "j", "k", "l"] - return self.S(indices[iarg]) + return L.Symbol(indices[iarg], dtype=L.DataType.INT) def coefficient_dof_sum_index(self): """Index for loops over coefficient dofs, assumed to never be used in two nested loops.""" - return self.S("ic") + return L.Symbol("ic", dtype=L.DataType.INT) def quadrature_loop_index(self): """Reusing a single index name for all quadrature loops, assumed not to be nested.""" - return self.S("iq") + return L.Symbol("iq", dtype=L.DataType.INT) def quadrature_permutation(self, index): """Quadrature permutation, as input to the function.""" - return self.S("quadrature_permutation")[index] + return L.Symbol("quadrature_permutation", dtype=L.DataType.INT)[index] def custom_weights_table(self): """Table for 
chunk of custom quadrature weights (including cell measure scaling).""" - return self.S("weights_chunk") + return L.Symbol("weights_chunk", dtype=L.DataType.REAL) def custom_points_table(self): """Table for chunk of custom quadrature points (physical coordinates).""" - return self.S("points_chunk") + return L.Symbol("points_chunk", dtype=L.DataType.REAL) def weights_table(self, quadrature_rule): """Table of quadrature weights.""" - return self.S(f"weights_{quadrature_rule.id()}") + return L.Symbol(f"weights_{quadrature_rule.id()}", dtype=L.DataType.REAL) def points_table(self, quadrature_rule): """Table of quadrature points (points on the reference integration entity).""" - return self.S(f"points_{quadrature_rule.id()}") + return L.Symbol(f"points_{quadrature_rule.id()}", dtype=L.DataType.REAL) def x_component(self, mt): """Physical coordinate component.""" - return self.S(format_mt_name("x", mt)) + return L.Symbol(format_mt_name("x", mt), dtype=L.DataType.REAL) def J_component(self, mt): """Jacobian component.""" # FIXME: Add domain number! - return self.S(format_mt_name("J", mt)) + return L.Symbol(format_mt_name("J", mt), dtype=L.DataType.REAL) def domain_dof_access(self, dof, component, gdim, num_scalar_dofs, restriction): # FIXME: Add domain number or offset! 
offset = 0 if restriction == "-": offset = num_scalar_dofs * 3 - vc = self.S("coordinate_dofs") + vc = L.Symbol("coordinate_dofs", dtype=L.DataType.REAL) return vc[3 * dof + component + offset] def domain_dofs_access(self, gdim, num_scalar_dofs, restriction): @@ -147,14 +146,14 @@ def domain_dofs_access(self, gdim, num_scalar_dofs, restriction): def coefficient_dof_access(self, coefficient, dof_index): offset = self.coefficient_offsets[coefficient] - w = self.S("w") + w = L.Symbol("w", dtype=L.DataType.SCALAR) return w[offset + dof_index] def coefficient_dof_access_blocked(self, coefficient: ufl.Coefficient, index, block_size, dof_offset): coeff_offset = self.coefficient_offsets[coefficient] - w = self.S("w") - _w = self.S(f"_w_{coeff_offset}_{dof_offset}") + w = L.Symbol("w", dtype=L.DataType.SCALAR) + _w = L.Symbol(f"_w_{coeff_offset}_{dof_offset}", dtype=L.DataType.SCALAR) unit_stride_access = _w[index] original_access = w[coeff_offset + index * block_size + dof_offset] return unit_stride_access, original_access @@ -162,17 +161,14 @@ def coefficient_dof_access_blocked(self, coefficient: ufl.Coefficient, index, def coefficient_value(self, mt): """Symbol for variable holding value or derivative component of coefficient.""" c = self.coefficient_numbering[mt.terminal] - return self.S(format_mt_name("w%d" % (c, ), mt)) + return L.Symbol(format_mt_name("w%d" % (c, ), mt), dtype=L.DataType.SCALAR) def constant_index_access(self, constant, index): offset = self.original_constant_offsets[constant] - c = self.S("c") + c = L.Symbol("c", dtype=L.DataType.SCALAR) return c[offset + index] - def named_table(self, name): - return self.S(name) - def element_table(self, tabledata, entitytype, restriction): entity = self.entity(entitytype, restriction) @@ -194,4 +190,4 @@ def element_table(self, tabledata, entitytype, restriction): qp = 0 # Return direct access to element table - return self.named_table(tabledata.name)[qp][entity][iq] + return L.Symbol(tabledata.name, 
dtype=L.DataType.REAL)[qp][entity][iq] diff --git a/ffcx/codegeneration/utils.py b/ffcx/codegeneration/utils.py new file mode 100644 index 000000000..06497c172 --- /dev/null +++ b/ffcx/codegeneration/utils.py @@ -0,0 +1,34 @@ +# Copyright (C) 2020-2023 Michal Habera and Chris Richardson +# +# This file is part of FFCx.(https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later + +def cdtype_to_numpy(cdtype: str): + """Map a C data type string to a NumPy datatype string.""" + if cdtype == "double": + return "float64" + elif cdtype == "double _Complex": + return "complex128" + elif cdtype == "float": + return "float32" + elif cdtype == "float _Complex": + return "complex64" + elif cdtype == "long double": + return "longdouble" + else: + raise RuntimeError(f"Unknown NumPy type for: {cdtype}") + + +def scalar_to_value_type(scalar_type: str) -> str: + """The C value type associated with a C scalar type. + + Args: + scalar_type: A C type. + + Returns: + The value type associated with ``scalar_type``. E.g., if + ``scalar_type`` is ``float _Complex`` the return value is 'float'. + + """ + return scalar_type.replace(' _Complex', '')