From 5b3dbeccd08f4ed85095024ead1c0712a7788dcd Mon Sep 17 00:00:00 2001 From: Chris Richardson Date: Tue, 15 Aug 2023 15:35:18 +0100 Subject: [PATCH] Remove switch/case from generated code (#591) * Work on removing switch/case from integrals/ids * Fix typos * Fix for zero integrals case * Typos * Improve documentation * Remove from form.py * Fix order * Remove for integrals * Remove all remaining cases * Remove from Cnodes * Fix tests * Remove break/continue * fix * Reset files in tests --- ffcx/codegeneration/C/cnodes.py | 304 ++++++++++++------------- ffcx/codegeneration/dofmap.py | 67 ------ ffcx/codegeneration/dofmap_template.py | 18 -- ffcx/codegeneration/form.py | 53 +---- ffcx/codegeneration/form_template.py | 20 -- ffcx/codegeneration/ufcx.h | 9 - test/test_add_mode.py | 21 +- test/test_blocked_elements.py | 50 ++-- test/test_jit_forms.py | 73 +++--- 9 files changed, 214 insertions(+), 401 deletions(-) diff --git a/ffcx/codegeneration/C/cnodes.py b/ffcx/codegeneration/C/cnodes.py index 860945923..3d6eb7666 100644 --- a/ffcx/codegeneration/C/cnodes.py +++ b/ffcx/codegeneration/C/cnodes.py @@ -10,8 +10,7 @@ import numpy as np from ffcx.codegeneration.C.format_lines import Indented, format_indented_lines -from ffcx.codegeneration.C.format_value import (format_float, format_int, - format_value) +from ffcx.codegeneration.C.format_value import format_float, format_int, format_value from ffcx.codegeneration.C.precedence import PRECEDENCE logger = logging.getLogger("ffcx") @@ -32,18 +31,21 @@ def is_zero_cexpr(cexpr): - return ((isinstance(cexpr, LiteralFloat) and cexpr.value == 0.0) - or (isinstance(cexpr, LiteralInt) and cexpr.value == 0)) + return (isinstance(cexpr, LiteralFloat) and cexpr.value == 0.0) or ( + isinstance(cexpr, LiteralInt) and cexpr.value == 0 + ) def is_one_cexpr(cexpr): - return ((isinstance(cexpr, LiteralFloat) and cexpr.value == 1.0) - or (isinstance(cexpr, LiteralInt) and cexpr.value == 1)) + return (isinstance(cexpr, LiteralFloat) and cexpr.value == 1.0) or ( + isinstance(cexpr, LiteralInt) and cexpr.value == 1 + ) def is_negative_one_cexpr(cexpr): - return ((isinstance(cexpr, LiteralFloat) and cexpr.value == -1.0) - or (isinstance(cexpr, LiteralInt) and cexpr.value == -1)) + return (isinstance(cexpr, LiteralFloat) and cexpr.value == -1.0) or ( + isinstance(cexpr, LiteralInt) and cexpr.value == -1 + ) def float_product(factors): @@ -58,6 +60,8 @@ def float_product(factors): if is_zero_cexpr(f): return f return Product(factors) + + # CNode core @@ -77,6 +81,7 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + # CExpr base classes @@ -261,7 +266,7 @@ def __eq__(self, other): class LiteralFloat(CExprLiteral): """A floating point literal value.""" - __slots__ = ("value", ) + __slots__ = ("value",) precedence = PRECEDENCE.LITERAL def __init__(self, value): @@ -289,7 +294,7 @@ def flops(self): class LiteralInt(CExprLiteral): """An integer literal value.""" - __slots__ = ("value", ) + __slots__ = ("value",) precedence = PRECEDENCE.LITERAL def __init__(self, value): @@ -323,11 +328,11 @@ def __hash__(self): class LiteralBool(CExprLiteral): """A boolean literal value.""" - __slots__ = ("value", ) + __slots__ = ("value",) precedence = PRECEDENCE.LITERAL def __init__(self, value): - assert isinstance(value, (bool, )) + assert isinstance(value, (bool,)) self.value = value def ce_format(self, precision=None): @@ -345,16 +350,16 @@ def __bool__(self): class LiteralString(CExprLiteral): """A boolean literal value.""" - __slots__ = ("value", ) + __slots__ = ("value",) precedence = PRECEDENCE.LITERAL def __init__(self, value): - assert isinstance(value, (str, )) + assert isinstance(value, (str,)) assert '"' not in value self.value = value def ce_format(self, precision=None): - return '"%s"' % (self.value, ) + return '"%s"' % (self.value,) def __eq__(self, other): return isinstance(other, LiteralString) and self.value == other.value @@ -363,7 +368,7 @@ def __eq__(self, other): class Symbol(CExprTerminal): """A named symbol.""" - __slots__ = ("name", ) + __slots__ = ("name",) precedence = PRECEDENCE.SYMBOL def __init__(self, name): @@ -385,10 +390,11 @@ def __hash__(self): # CExprOperator base classes + class UnaryOp(CExprOperator): """Base class for unary operators.""" - __slots__ = ("arg", ) + __slots__ = ("arg",) def __init__(self, arg): self.arg = as_cexpr(arg) @@ -408,7 +414,7 @@ class PrefixUnaryOp(UnaryOp): def ce_format(self, precision=None): arg = self.arg.ce_format(precision) if self.arg.precedence >= self.precedence: - arg = '(' + arg + ')' + arg = "(" + arg + ")" return self.op + arg def __eq__(self, other): @@ -423,7 +429,7 @@ class PostfixUnaryOp(UnaryOp): def ce_format(self, precision=None): arg = self.arg.ce_format(precision) if self.arg.precedence >= self.precedence: - arg = '(' + arg + ')' + arg = "(" + arg + ")" return arg + self.op def __eq__(self, other): @@ -444,15 +450,19 @@ def ce_format(self, precision=None): # Apply parentheses if self.lhs.precedence >= self.precedence: - lhs = '(' + lhs + ')' + lhs = "(" + lhs + ")" if self.rhs.precedence >= self.precedence: - rhs = '(' + rhs + ')' + rhs = "(" + rhs + ")" # Return combined string return lhs + (" " + self.op + " ") + rhs def __eq__(self, other): - return (isinstance(other, type(self)) and self.lhs == other.lhs and self.rhs == other.rhs) + return ( + isinstance(other, type(self)) + and self.lhs == other.lhs + and self.rhs == other.rhs + ) def __hash__(self): return hash(self.ce_format()) @@ -464,7 +474,7 @@ def flops(self): class NaryOp(CExprOperator): """Base class for special n-ary operators.""" - __slots__ = ("args", ) + __slots__ = ("args",) def __init__(self, args): self.args = [as_cexpr(arg) for arg in args] @@ -476,7 +486,7 @@ def ce_format(self, precision=None): # Apply parentheses for i in range(len(args)): if self.args[i].precedence >= self.precedence: - args[i] = '(' + args[i] + ')' + args[i] = "(" + args[i] + ")" # Return combined string op = " " + self.op + " " @@ -486,8 +496,11 @@ def ce_format(self, precision=None): return s def __eq__(self, other): - return (isinstance(other, type(self)) and len(self.args) == len(other.args) - and all(a == b for a, b in zip(self.args, other.args))) + return ( + isinstance(other, type(self)) + and len(self.args) == len(other.args) + and all(a == b for a, b in zip(self.args, other.args)) + ) def flops(self): flops = len(self.args) - 1 @@ -716,6 +729,7 @@ class AssignDiv(AssignOp): __slots__ = () op = "/=" + # CExpr operators @@ -763,7 +777,7 @@ def __init__(self, array, dummy=None, dims=None, strides=None, offset=None): def __getitem__(self, indices): if not isinstance(indices, (list, tuple)): - indices = (indices, ) + indices = (indices,) n = len(indices) if n == 0: # Handle scalar case, allowing dims=() and indices=() for A[0] @@ -773,7 +787,7 @@ def __getitem__(self, indices): else: i, s = (indices[0], self.strides[0]) literal_one = LiteralInt(1) - flat = (i if s == literal_one else s * i) + flat = i if s == literal_one else s * i if self.offset is not None: flat = self.offset + flat for i, s in zip(indices[1:n], self.strides[1:n]): @@ -796,11 +810,11 @@ def __init__(self, array, indices): elif isinstance(array, ArrayDecl): self.array = array.symbol else: - raise ValueError("Unexpected array type %s." % (type(array).__name__, )) + raise ValueError("Unexpected array type %s." % (type(array).__name__,)) # Allow expressions or literals as indices if not isinstance(indices, (list, tuple)): - indices = (indices, ) + indices = (indices,) self.indices = tuple(as_cexpr_or_string_symbol(i) for i in indices) # Early error checking for negative array dimensions @@ -812,8 +826,10 @@ def __init__(self, array, indices): if len(self.indices) != len(array.sizes): raise ValueError("Invalid number of indices.") ints = (int, LiteralInt) - if any((isinstance(i, ints) and isinstance(d, ints) and int(i) >= int(d)) - for i, d in zip(self.indices, array.sizes)): + if any( + (isinstance(i, ints) and isinstance(d, ints) and int(i) >= int(d)) + for i, d in zip(self.indices, array.sizes) + ): raise ValueError("Index value >= array dimension.") def __getitem__(self, indices): @@ -821,7 +837,7 @@ def __getitem__(self, indices): if isinstance(indices, list): indices = tuple(indices) elif not isinstance(indices, tuple): - indices = (indices, ) + indices = (indices,) return ArrayAccess(self.array, self.indices + indices) def ce_format(self, precision=None): @@ -831,8 +847,11 @@ def ce_format(self, precision=None): return s def __eq__(self, other): - return (isinstance(other, type(self)) and self.array == other.array - and self.indices == other.indices) + return ( + isinstance(other, type(self)) + and self.array == other.array + and self.indices == other.indices + ) def __hash__(self): return hash(self.ce_format()) @@ -858,18 +877,22 @@ def ce_format(self, precision=None): # Apply parentheses if self.condition.precedence >= self.precedence: - c = '(' + c + ')' + c = "(" + c + ")" if self.true.precedence >= self.precedence: - t = '(' + t + ')' + t = "(" + t + ")" if self.false.precedence >= self.precedence: - f = '(' + f + ')' + f = "(" + f + ")" # Return combined string return c + " ? " + t + " : " + f def __eq__(self, other): - return (isinstance(other, type(self)) and self.condition == other.condition - and self.true == other.true and self.false == other.false) + return ( + isinstance(other, type(self)) + and self.condition == other.condition + and self.true == other.true + and self.false == other.false + ) def flops(self): raise NotImplementedError("Flop count is not implemented for conditionals") @@ -887,7 +910,7 @@ def __init__(self, function, arguments=None): if arguments is None: arguments = () elif not isinstance(arguments, (tuple, list)): - arguments = (arguments, ) + arguments = (arguments,) self.arguments = [as_cexpr(arg) for arg in arguments] def ce_format(self, precision=None): @@ -895,8 +918,11 @@ def ce_format(self, precision=None): return self.function.ce_format(precision) + "(" + args + ")" def __eq__(self, other): - return (isinstance(other, type(self)) and self.function == other.function - and self.arguments == other.arguments) + return ( + isinstance(other, type(self)) + and self.function == other.function + and self.arguments == other.arguments + ) def flops(self): return 1 @@ -934,7 +960,7 @@ def as_cexpr(node): elif isinstance(node, numbers.Real): return LiteralFloat(node) elif isinstance(node, str): - raise RuntimeError("Got string for CExpr, this is ambiguous: %s" % (node, )) + raise RuntimeError("Got string for CExpr, this is ambiguous: %s" % (node,)) else: raise RuntimeError("Unexpected CExpr type %s:\n%s" % (type(node), str(node))) @@ -1005,7 +1031,9 @@ class CStatement(CNode): def cs_format(self, precision=None): """Return S: string | list(S) | Indented(S).""" - raise NotImplementedError("Missing implementation of cs_format() in CStatement.") + raise NotImplementedError( + "Missing implementation of cs_format() in CStatement." + ) def __str__(self): try: @@ -1025,7 +1053,7 @@ def flops(self): class VerbatimStatement(CStatement): """Wraps a source code string to be pasted verbatim into the source code.""" - __slots__ = ("codestring", ) + __slots__ = ("codestring",) is_scoped = False def __init__(self, codestring): @@ -1036,13 +1064,13 @@ def cs_format(self, precision=None): return self.codestring def __eq__(self, other): - return (isinstance(other, type(self)) and self.codestring == other.codestring) + return isinstance(other, type(self)) and self.codestring == other.codestring class Statement(CStatement): """Make an expression into a statement.""" - __slots__ = ("expr", ) + __slots__ = ("expr",) is_scoped = False def __init__(self, expr): @@ -1052,7 +1080,7 @@ def cs_format(self, precision=None): return self.expr.ce_format(precision) + ";" def __eq__(self, other): - return (isinstance(other, type(self)) and self.expr == other.expr) + return isinstance(other, type(self)) and self.expr == other.expr def flops(self): # print(self.expr.rhs.flops()) @@ -1062,7 +1090,7 @@ def flops(self): class StatementList(CStatement): """A simple sequence of statements. No new scopes are introduced.""" - __slots__ = ("statements", ) + __slots__ = ("statements",) def __init__(self, statements): self.statements = [as_cstatement(st) for st in statements] @@ -1075,7 +1103,7 @@ def cs_format(self, precision=None): return [st.cs_format(precision) for st in self.statements] def __eq__(self, other): - return (isinstance(other, type(self)) and self.statements == other.statements) + return isinstance(other, type(self)) and self.statements == other.statements def flops(self): flops = 0 @@ -1087,36 +1115,8 @@ def flops(self): # Simple statements -class Break(CStatement): - __slots__ = () - is_scoped = True - - def cs_format(self, precision=None): - return "break;" - - def __eq__(self, other): - return isinstance(other, type(self)) - - def flops(self): - return 0 - - -class Continue(CStatement): - __slots__ = () - is_scoped = True - - def cs_format(self, precision=None): - return "continue;" - - def __eq__(self, other): - return isinstance(other, type(self)) - - def flops(self): - return 0 - - class Return(CStatement): - __slots__ = ("value", ) + __slots__ = ("value",) is_scoped = True def __init__(self, value=None): @@ -1129,10 +1129,10 @@ def cs_format(self, precision=None): if self.value is None: return "return;" else: - return "return %s;" % (self.value.ce_format(precision), ) + return "return %s;" % (self.value.ce_format(precision),) def __eq__(self, other): - return (isinstance(other, type(self)) and self.value == other.value) + return isinstance(other, type(self)) and self.value == other.value def flops(self): return 0 @@ -1141,7 +1141,7 @@ def flops(self): class Comment(CStatement): """Line comment(s) used for annotating the generated code with human readable remarks.""" - __slots__ = ("comment", ) + __slots__ = ("comment",) is_scoped = True def __init__(self, comment): @@ -1153,7 +1153,7 @@ def cs_format(self, precision=None): return ["// " + line.strip() for line in lines] def __eq__(self, other): - return (isinstance(other, type(self)) and self.comment == other.comment) + return isinstance(other, type(self)) and self.comment == other.comment def flops(self): return 0 @@ -1179,7 +1179,7 @@ def commented_code_list(code, comments): class Pragma(CStatement): """Pragma comments used for compiler-specific annotations.""" - __slots__ = ("comment", ) + __slots__ = ("comment",) is_scoped = True def __init__(self, comment): @@ -1191,7 +1191,7 @@ def cs_format(self, precision=None): return "#pragma " + self.comment def __eq__(self, other): - return (isinstance(other, type(self)) and self.comment == other.comment) + return isinstance(other, type(self)) and self.comment == other.comment def flops(self): return 0 @@ -1207,7 +1207,6 @@ class VariableDecl(CStatement): is_scoped = False def __init__(self, typename, symbol, value=None): - # No type system yet, just using strings assert isinstance(typename, str) self.typename = typename @@ -1226,8 +1225,12 @@ def cs_format(self, precision=None): return code + ";" def __eq__(self, other): - return (isinstance(other, type(self)) and self.typename == other.typename - and self.symbol == other.symbol and self.value == other.value) + return ( + isinstance(other, type(self)) + and self.typename == other.typename + and self.symbol == other.symbol + and self.value == other.value + ) def flops(self): if self.value is not None: @@ -1305,13 +1308,18 @@ def formatter(x, p): r = len(sizes) assert r > 0 if r == 1: - return [build_1d_initializer_list(values, formatter, padlen=padlen, precision=precision)] + return [ + build_1d_initializer_list( + values, formatter, padlen=padlen, precision=precision + ) + ] else: # Render all sublists parts = [] for val in values: sublist = build_initializer_lists( - val, sizes[1:], level + 1, formatter, padlen=padlen, precision=precision) + val, sizes[1:], level + 1, formatter, padlen=padlen, precision=precision + ) parts.append(sublist) # Add comma after last line in each part except the last one for part in parts[:-1]: @@ -1357,7 +1365,7 @@ def __init__(self, typename, symbol, sizes=None, values=None, padlen=0): self.symbol = as_symbol(symbol) if isinstance(sizes, int): - sizes = (sizes, ) + sizes = (sizes,) self.sizes = tuple(sizes) # NB! No type checking, assuming nested lists of literal values. Not applying as_cexpr. @@ -1370,13 +1378,15 @@ def __init__(self, typename, symbol, sizes=None, values=None, padlen=0): def cs_format(self, precision=None): if not all(self.sizes): - raise RuntimeError(f"Detected an array {self.symbol} dimension of zero. This is not valid in C.") + raise RuntimeError( + f"Detected an array {self.symbol} dimension of zero. This is not valid in C." + ) # Pad innermost array dimension sizes = pad_innermost_dim(self.sizes, self.padlen) # Add brackets - brackets = ''.join("[%d]" % n for n in sizes) + brackets = "".join("[%d]" % n for n in sizes) # Join declaration decl = self.typename + " " + self.symbol.name + brackets @@ -1398,13 +1408,21 @@ def cs_format(self, precision=None): elif self.values.dtype.kind == "i": formatter = format_int elif self.values.dtype == np.bool_: + def format_bool(x, precision=None): return "true" if x is True else "false" + formatter = format_bool else: formatter = format_value initializer_lists = build_initializer_lists( - self.values, self.sizes, 0, formatter, padlen=self.padlen, precision=precision) + self.values, + self.sizes, + 0, + formatter, + padlen=self.padlen, + precision=precision, + ) if len(initializer_lists) == 1: return decl + " = " + initializer_lists[0] + ";" else: @@ -1413,8 +1431,9 @@ def format_bool(x, precision=None): def __eq__(self, other): attributes = ("typename", "symbol", "sizes", "padlen", "values") - return (isinstance(other, type(self)) - and all(getattr(self, name) == getattr(self, name) for name in attributes)) + return isinstance(other, type(self)) and all( + getattr(self, name) == getattr(self, name) for name in attributes + ) def flops(self): return 0 @@ -1424,7 +1443,7 @@ def flops(self): class Scope(CStatement): - __slots__ = ("body", ) + __slots__ = ("body",) is_scoped = True def __init__(self, body): @@ -1434,7 +1453,7 @@ def cs_format(self, precision=None): return ("{", Indented(self.body.cs_format(precision)), "}") def __eq__(self, other): - return (isinstance(other, type(self)) and self.body == other.body) + return isinstance(other, type(self)) and self.body == other.body def flops(self): return 0 @@ -1444,8 +1463,8 @@ def _is_simple_if_body(body): if isinstance(body, StatementList): if len(body.statements) > 1: return False - body, = body.statements - return isinstance(body, (Return, AssignOp, Break, Continue)) + (body,) = body.statements + return isinstance(body, (Return, AssignOp)) class If(CStatement): @@ -1465,8 +1484,11 @@ def cs_format(self, precision=None): return (statement, "{", body_fmt, "}") def __eq__(self, other): - return (isinstance(other, type(self)) and self.condition == other.condition - and self.body == other.body) + return ( + isinstance(other, type(self)) + and self.condition == other.condition + and self.body == other.body + ) class ElseIf(CStatement): @@ -1486,12 +1508,15 @@ def cs_format(self, precision=None): return (statement, "{", body_fmt, "}") def __eq__(self, other): - return (isinstance(other, type(self)) and self.condition == other.condition - and self.body == other.body) + return ( + isinstance(other, type(self)) + and self.condition == other.condition + and self.body == other.body + ) class Else(CStatement): - __slots__ = ("body", ) + __slots__ = ("body",) is_scoped = True def __init__(self, body): @@ -1506,7 +1531,7 @@ def cs_format(self, precision=None): return (statement, "{", body_fmt, "}") def __eq__(self, other): - return (isinstance(other, type(self)) and self.body == other.body) + return isinstance(other, type(self)) and self.body == other.body def is_simple_inner_loop(code): @@ -1517,56 +1542,6 @@ def is_simple_inner_loop(code): return False -class Switch(CStatement): - __slots__ = ("arg", "cases", "default", "autobreak", "autoscope") - is_scoped = True - - def __init__(self, arg, cases, default=None, autobreak=True, autoscope=True): - self.arg = as_cexpr_or_string_symbol(arg) - self.cases = [(as_cexpr(value), as_cstatement(body)) for value, body in cases] - if default is not None: - default = as_cstatement(default) - defcase = [(None, default)] - else: - defcase = [] - self.default = default - # If this is a switch where every case returns, scopes or breaks are never needed - if all(isinstance(case[1], Return) for case in self.cases + defcase): - autobreak = False - autoscope = False - if all(case[1].is_scoped for case in self.cases + defcase): - autoscope = False - assert autobreak in (True, False) - assert autoscope in (True, False) - self.autobreak = autobreak - self.autoscope = autoscope - - def cs_format(self, precision=None): - cases = [] - for case in self.cases: - caseheader = "case " + case[0].ce_format(precision) + ":" - casebody = case[1].cs_format(precision) - if self.autoscope: - casebody = ("{", Indented(casebody), "}") - if self.autobreak: - casebody = (casebody, "break;") - cases.extend([caseheader, Indented(casebody)]) - - if self.default is not None: - caseheader = "default:" - casebody = self.default.cs_format(precision) - if self.autoscope: - casebody = ("{", Indented(casebody), "}") - cases.extend([caseheader, Indented(casebody)]) - - return ("switch (" + self.arg.ce_format(precision) + ")", "{", cases, "}") - - def __eq__(self, other): - attributes = ("arg", "cases", "default", "autobreak", "autoscope") - return (isinstance(other, type(self)) - and all(getattr(self, name) == getattr(self, name) for name in attributes)) - - class ForRange(CStatement): """Slightly higher-level for loop assuming incrementing an index over a range.""" @@ -1603,8 +1578,9 @@ def cs_format(self, precision=None): def __eq__(self, other): attributes = ("index", "begin", "end", "body", "index_type") - return (isinstance(other, type(self)) - and all(getattr(self, name) == getattr(self, name) for name in attributes)) + return isinstance(other, type(self)) and all( + getattr(self, name) == getattr(self, name) for name in attributes + ) def flops(self): return (self.end.value - self.begin.value) * self.body.flops() @@ -1626,8 +1602,10 @@ def as_cstatement(node): # Special case for using assignment expressions as statements return Statement(node) else: - raise RuntimeError("Trying to create a statement of CExprOperator type %s:\n%s" % - (type(node), str(node))) + raise RuntimeError( + "Trying to create a statement of CExprOperator type %s:\n%s" + % (type(node), str(node)) + ) elif isinstance(node, list): # Convenience case for list of statements if len(node) == 1: @@ -1639,4 +1617,6 @@ def as_cstatement(node): # Backdoor for flexibility in code generation to allow verbatim pasted statements return VerbatimStatement(node) else: - raise RuntimeError("Unexpected CStatement type %s:\n%s" % (type(node), str(node))) + raise RuntimeError( + "Unexpected CStatement type %s:\n%s" % (type(node), str(node)) + ) diff --git a/ffcx/codegeneration/dofmap.py b/ffcx/codegeneration/dofmap.py index 3ca6e3174..de3d21dc4 100644 --- a/ffcx/codegeneration/dofmap.py +++ b/ffcx/codegeneration/dofmap.py @@ -8,54 +8,12 @@ # old implementation in FFC import logging -import typing import ffcx.codegeneration.dofmap_template as ufcx_dofmap logger = logging.getLogger("ffcx") -def tabulate_entity_dofs( - L, - entity_dofs: typing.List[typing.List[typing.List[int]]], - num_dofs_per_entity: typing.List[int], -): - # Output argument array - dofs = L.Symbol("dofs") - - # Input arguments - d = L.Symbol("d") - i = L.Symbol("i") - - # TODO: Removed check for (d <= tdim + 1) - tdim = len(num_dofs_per_entity) - 1 - - # Generate cases for each dimension: - all_cases = [] - for dim in range(tdim + 1): - # Ignore if no entities for this dimension - if num_dofs_per_entity[dim] == 0: - continue - - # Generate cases for each mesh entity - cases = [] - for entity in range(len(entity_dofs[dim])): - casebody = [] - for j, dof in enumerate(entity_dofs[dim][entity]): - casebody += [L.Assign(dofs[j], dof)] - cases.append((entity, L.StatementList(casebody))) - - # Generate inner switch - # TODO: Removed check for (i <= num_entities-1) - inner_switch = L.Switch(i, cases, autoscope=False) - all_cases.append((dim, inner_switch)) - - if all_cases: - return L.Switch(d, all_cases, autoscope=False) - else: - return L.NoOp() - - def generator(ir, options): """Generate UFC code for a dofmap.""" logger.info("Generating code for dofmap:") @@ -73,23 +31,6 @@ def generator(ir, options): import ffcx.codegeneration.C.cnodes as L - num_entity_dofs = ir.num_entity_dofs + [0, 0, 0, 0] - num_entity_dofs = num_entity_dofs[:4] - d["num_entity_dofs"] = f"num_entity_dofs_{ir.name}" - d["num_entity_dofs_init"] = L.ArrayDecl( - "int", f"num_entity_dofs_{ir.name}", values=num_entity_dofs, sizes=4 - ) - - num_entity_closure_dofs = ir.num_entity_closure_dofs + [0, 0, 0, 0] - num_entity_closure_dofs = num_entity_closure_dofs[:4] - d["num_entity_closure_dofs"] = f"num_entity_closure_dofs_{ir.name}" - d["num_entity_closure_dofs_init"] = L.ArrayDecl( - "int", - f"num_entity_closure_dofs_{ir.name}", - values=num_entity_closure_dofs, - sizes=4, - ) - flattened_entity_dofs = [] entity_dof_offsets = [0] for dim in ir.entity_dofs: @@ -137,14 +78,6 @@ def generator(ir, options): d["block_size"] = ir.block_size - # Functions - d["tabulate_entity_dofs"] = tabulate_entity_dofs( - L, ir.entity_dofs, ir.num_entity_dofs - ) - d["tabulate_entity_closure_dofs"] = tabulate_entity_dofs( - L, ir.entity_closure_dofs, ir.num_entity_closure_dofs - ) - if len(ir.sub_dofmaps) > 0: d["sub_dofmaps_initialization"] = L.ArrayDecl( "ufcx_dofmap*", diff --git a/ffcx/codegeneration/dofmap_template.py b/ffcx/codegeneration/dofmap_template.py index 0088d9a5f..abe5563f7 100644 --- a/ffcx/codegeneration/dofmap_template.py +++ b/ffcx/codegeneration/dofmap_template.py @@ -12,20 +12,6 @@ {sub_dofmaps_initialization} -void tabulate_entity_dofs_{factory_name}(int* restrict dofs, int d, int i) -{{ -{tabulate_entity_dofs} -}} - -void tabulate_entity_closure_dofs_{factory_name}(int* restrict dofs, int d, int i) -{{ -{tabulate_entity_closure_dofs} -}} - -{num_entity_dofs_init} - -{num_entity_closure_dofs_init} - {entity_dofs_init} {entity_dof_offsets_init} @@ -44,10 +30,6 @@ .entity_dof_offsets = {entity_dof_offsets}, .entity_closure_dofs = {entity_closure_dofs}, .entity_closure_dof_offsets = {entity_closure_dof_offsets}, - .num_entity_dofs = {num_entity_dofs}, - .tabulate_entity_dofs = tabulate_entity_dofs_{factory_name}, - .num_entity_closure_dofs = {num_entity_closure_dofs}, - .tabulate_entity_closure_dofs = tabulate_entity_closure_dofs_{factory_name}, .num_sub_dofmaps = {num_sub_dofmaps}, .sub_dofmaps = {sub_dofmaps} }}; diff --git a/ffcx/codegeneration/form.py b/ffcx/codegeneration/form.py index 14a806de1..2e95edadd 100644 --- a/ffcx/codegeneration/form.py +++ b/ffcx/codegeneration/form.py @@ -30,13 +30,6 @@ def generator(ir, options): d["num_coefficients"] = ir.num_coefficients d["num_constants"] = ir.num_constants - code = [] - cases = [] - for itg_type in ("cell", "interior_facet", "exterior_facet"): - cases += [(L.Symbol(itg_type), L.Return(len(ir.subdomain_ids[itg_type])))] - code += [L.Switch("integral_type", cases, default=L.Return(0))] - d["num_integrals"] = L.StatementList(code) - if len(ir.original_coefficient_position) > 0: d["original_coefficient_position_init"] = L.ArrayDecl( "int", @@ -95,6 +88,7 @@ def generator(ir, options): integrals = [] integral_ids = [] integral_offsets = [0] + # Note: the order of this list is defined by the enum ufcx_integral_type in ufcx.h for itg_type in ("cell", "exterior_facet", "interior_facet"): integrals += [L.AddressOf(L.Symbol(itg)) for itg in ir.integral_names[itg_type]] integral_ids += ir.subdomain_ids[itg_type] @@ -128,51 +122,6 @@ def generator(ir, options): sizes=len(integral_offsets), ) - code = [] - cases = [] - code_ids = [] - cases_ids = [] - for itg_type in ("cell", "interior_facet", "exterior_facet"): - if len(ir.integral_names[itg_type]) > 0: - code += [ - L.ArrayDecl( - "static ufcx_integral*", - f"integrals_{itg_type}_{ir.name}", - values=[ - L.AddressOf(L.Symbol(itg)) - for itg in ir.integral_names[itg_type] - ], - sizes=len(ir.integral_names[itg_type]), - ) - ] - cases.append( - ( - L.Symbol(itg_type), - L.Return(L.Symbol(f"integrals_{itg_type}_{ir.name}")), - ) - ) - - code_ids += [ - L.ArrayDecl( - "static int", - f"integral_ids_{itg_type}_{ir.name}", - values=ir.subdomain_ids[itg_type], - sizes=len(ir.subdomain_ids[itg_type]), - ) - ] - cases_ids.append( - ( - L.Symbol(itg_type), - L.Return(L.Symbol(f"integral_ids_{itg_type}_{ir.name}")), - ) - ) - - code += [L.Switch("integral_type", cases, default=L.Return(L.Null()))] - code_ids += [L.Switch("integral_type", cases_ids, default=L.Return(L.Null()))] - d["integrals"] = L.StatementList(code) - - d["integral_ids"] = L.StatementList(code_ids) - code = [] function_name = L.Symbol("function_name") diff --git a/ffcx/codegeneration/form_template.py b/ffcx/codegeneration/form_template.py index 5c0b1d22c..02ad31d6f 100644 --- a/ffcx/codegeneration/form_template.py +++ b/ffcx/codegeneration/form_template.py @@ -40,21 +40,6 @@ {constant_name_map} }} -int* integral_ids_{factory_name}(ufcx_integral_type integral_type) -{{ -{integral_ids} -}} - -int num_integrals_{factory_name}(ufcx_integral_type integral_type) -{{ -{num_integrals} -}} - -ufcx_integral** integrals_{factory_name}(ufcx_integral_type integral_type) -{{ -{integrals} -}} - ufcx_form {factory_name} = {{ @@ -70,11 +55,6 @@ .finite_elements = {finite_elements}, .dofmaps = {dofmaps}, - .integral_ids = integral_ids_{factory_name}, - .num_integrals = num_integrals_{factory_name}, - - .integrals = integrals_{factory_name}, - .form_integrals = {form_integrals}, .form_integral_ids = {form_integral_ids}, .form_integral_offsets = form_integral_offsets_{factory_name} diff --git a/ffcx/codegeneration/ufcx.h b/ffcx/codegeneration/ufcx.h index 63567e305..f3e3d1e00 100644 --- a/ffcx/codegeneration/ufcx.h +++ b/ffcx/codegeneration/ufcx.h @@ -456,15 +456,6 @@ extern "C" /// Coefficient number j=i-r if r+j <= i < r+n ufcx_dofmap** dofmaps; - /// All ids for integrals - int* (*integral_ids)(ufcx_integral_type); - - /// Number of integrals - int (*num_integrals)(ufcx_integral_type); - - /// Get an integral on sub domain subdomain_id - ufcx_integral** (*integrals)(ufcx_integral_type); - /// List of cell, interior facet and exterior facet integrals ufcx_integral** form_integrals; diff --git a/test/test_add_mode.py b/test/test_add_mode.py index 7ee4aba31..87d159a78 100644 --- a/test/test_add_mode.py +++ b/test/test_add_mode.py @@ -35,11 +35,13 @@ def test_additive_facet_integral(mode, compile_args): ffi = module.ffi form0 = compiled_forms[0] - assert form0.num_integrals(module.lib.exterior_facet) == 1 - ids = form0.integral_ids(module.lib.exterior_facet) - assert ids[0] == -1 + integral_offsets = form0.form_integral_offsets + ex = module.lib.exterior_facet + assert integral_offsets[ex + 1] - integral_offsets[ex] == 1 + integral_id = form0.form_integral_ids[integral_offsets[ex]] + assert integral_id == -1 - default_integral = form0.integrals(module.lib.exterior_facet)[0] + default_integral = form0.form_integrals[integral_offsets[ex]] np_type = cdtype_to_numpy(mode) A = np.zeros((3, 3), dtype=np_type) @@ -83,11 +85,14 @@ def test_additive_cell_integral(mode, compile_args): ffi = module.ffi form0 = compiled_forms[0] - assert form0.num_integrals(module.lib.cell) == 1 - ids = form0.integral_ids(module.lib.cell) - assert ids[0] == -1 + cell = module.lib.cell + offsets = form0.form_integral_offsets + num_integrals = offsets[cell + 1] - offsets[cell] + assert num_integrals == 1 + integral_id = form0.form_integral_ids[offsets[cell]] + assert integral_id == -1 - default_integral = form0.integrals(0)[0] + default_integral = form0.form_integrals[offsets[cell]] np_type = cdtype_to_numpy(mode) A = np.zeros((3, 3), dtype=np_type) diff --git a/test/test_blocked_elements.py b/test/test_blocked_elements.py index fd4db7b48..c67660a54 100644 --- a/test/test_blocked_elements.py +++ b/test/test_blocked_elements.py @@ -32,15 +32,11 @@ def test_finite_element(compile_args): assert ufcx_dofmap.num_global_support_dofs == 0 assert ufcx_dofmap.num_global_support_dofs == 0 assert ufcx_dofmap.num_element_support_dofs == 3 - assert ufcx_dofmap.num_entity_dofs[0] == 1 - assert ufcx_dofmap.num_entity_dofs[1] == 0 - assert ufcx_dofmap.num_entity_dofs[2] == 0 - assert ufcx_dofmap.num_entity_dofs[3] == 0 + off = np.array([ufcx_dofmap.entity_dof_offsets[i] for i in range(8)]) + assert np.all(np.diff(off) == [1, 1, 1, 0, 0, 0, 0]) + for v in range(3): - vals = np.zeros(1, dtype=np.int32) - vals_ptr = module.ffi.cast("int *", module.ffi.from_buffer(vals)) - ufcx_dofmap.tabulate_entity_dofs(vals_ptr, 0, v) - assert vals[0] == v + assert ufcx_dofmap.entity_dofs[v] == v assert ufcx_dofmap.num_sub_dofmaps == 0 @@ -66,15 +62,11 @@ def test_vector_element(compile_args): assert ufcx_dofmap.num_global_support_dofs == 0 assert ufcx_dofmap.num_global_support_dofs == 0 assert ufcx_dofmap.num_element_support_dofs == 3 - assert ufcx_dofmap.num_entity_dofs[0] == 1 - assert ufcx_dofmap.num_entity_dofs[1] == 0 - assert ufcx_dofmap.num_entity_dofs[2] == 0 - assert ufcx_dofmap.num_entity_dofs[3] == 0 + off = np.array([ufcx_dofmap.entity_dof_offsets[i] for i in range(8)]) + assert np.all(np.diff(off) == [1, 1, 1, 0, 0, 0, 0]) + for v in range(3): - vals = np.zeros(1, dtype=np.int32) - vals_ptr = module.ffi.cast("int *", module.ffi.from_buffer(vals)) - ufcx_dofmap.tabulate_entity_dofs(vals_ptr, 0, v) - assert vals[0] == v + assert ufcx_dofmap.entity_dofs[v] == v assert ufcx_dofmap.num_sub_dofmaps == 2 @@ -102,15 +94,11 @@ def test_tensor_element(compile_args): assert ufcx_dofmap.num_global_support_dofs == 0 assert ufcx_dofmap.num_global_support_dofs == 0 assert ufcx_dofmap.num_element_support_dofs == 3 - assert ufcx_dofmap.num_entity_dofs[0] == 1 - assert ufcx_dofmap.num_entity_dofs[1] == 0 - assert ufcx_dofmap.num_entity_dofs[2] == 0 - assert ufcx_dofmap.num_entity_dofs[3] == 0 + off = np.array([ufcx_dofmap.entity_dof_offsets[i] for i in range(8)]) + assert np.all(np.diff(off) == [1, 1, 1, 0, 0, 0, 0]) + for v in range(3): - vals = np.zeros(1, dtype=np.int32) - vals_ptr = module.ffi.cast("int *", module.ffi.from_buffer(vals)) - ufcx_dofmap.tabulate_entity_dofs(vals_ptr, 0, v) - assert vals[0] == v + assert ufcx_dofmap.entity_dofs[v] == v assert ufcx_dofmap.num_sub_dofmaps == 4 @@ -136,14 +124,10 @@ def test_vector_quadrature_element(compile_args): assert ufcx_dofmap.num_global_support_dofs == 0 assert ufcx_dofmap.num_global_support_dofs == 0 assert ufcx_dofmap.num_element_support_dofs == 4 - assert ufcx_dofmap.num_entity_dofs[0] == 0 - assert ufcx_dofmap.num_entity_dofs[1] == 0 - assert ufcx_dofmap.num_entity_dofs[2] == 0 - assert ufcx_dofmap.num_entity_dofs[3] == 4 - - vals = np.zeros(4, dtype=np.int32) - vals_ptr = module.ffi.cast("int *", module.ffi.from_buffer(vals)) - ufcx_dofmap.tabulate_entity_dofs(vals_ptr, 3, 0) - assert (vals == [0, 1, 2, 3]).all() + off = np.array([ufcx_dofmap.entity_dof_offsets[i] for i in range(16)]) + assert np.all(np.diff(off) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4]) + + for i in range(4): + assert ufcx_dofmap.entity_dofs[i] == i assert ufcx_dofmap.num_sub_dofmaps == 3 diff --git a/test/test_jit_forms.py b/test/test_jit_forms.py index b6fe9518b..caa085d35 100644 --- a/test/test_jit_forms.py +++ b/test/test_jit_forms.py @@ -39,11 +39,13 @@ def test_laplace_bilinear_form_2d(mode, expected_result, compile_args): ffi = module.ffi form0 = compiled_forms[0] - assert form0.num_integrals(module.lib.cell) == 1 - ids = form0.integral_ids(module.lib.cell) - assert ids[0] == -1 + offsets = form0.form_integral_offsets + cell = module.lib.cell + assert offsets[cell + 1] - offsets[cell] == 1 + integral_id = form0.form_integral_ids[offsets[cell]] + assert integral_id == -1 - default_integral = form0.integrals(module.lib.cell)[0] + default_integral = form0.form_integrals[offsets[cell]] np_type = cdtype_to_numpy(mode) A = np.zeros((3, 3), dtype=np_type) @@ -107,8 +109,8 @@ def test_mass_bilinear_form_2d(mode, expected_result, compile_args): for f, compiled_f in zip(forms, compiled_forms): assert compiled_f.rank == len(f.arguments()) - form0 = compiled_forms[0].integrals(module.lib.cell)[0] - form1 = compiled_forms[1].integrals(module.lib.cell)[0] + form0 = compiled_forms[0].form_integrals[0] + form1 = compiled_forms[1].form_integrals[0] np_type = cdtype_to_numpy(mode) A = np.zeros((3, 3), dtype=np_type) @@ -165,7 +167,7 @@ def test_helmholtz_form_2d(mode, expected_result, compile_args): for f, compiled_f in zip(forms, compiled_forms): assert compiled_f.rank == len(f.arguments()) - form0 = compiled_forms[0].integrals(module.lib.cell)[0] + form0 = compiled_forms[0].form_integrals[0] np_type = cdtype_to_numpy(mode) A = np.zeros((3, 3), dtype=np_type) @@ -213,7 +215,7 @@ def test_laplace_bilinear_form_3d(mode, expected_result, compile_args): for f, compiled_f in zip(forms, compiled_forms): assert compiled_f.rank == len(f.arguments()) - form0 = compiled_forms[0].integrals(module.lib.cell)[0] + form0 = compiled_forms[0].form_integrals[0] np_type = cdtype_to_numpy(mode) A = np.zeros((4, 4), dtype=np_type) @@ -249,7 +251,7 @@ def test_form_coefficient(compile_args): for f, compiled_f in zip(forms, compiled_forms): assert compiled_f.rank == len(f.arguments()) - form0 = compiled_forms[0].integrals(module.lib.cell)[0] + form0 = compiled_forms[0].form_integrals[0] A = np.zeros((3, 3), dtype=np.float64) w = np.array([1.0, 1.0, 1.0], dtype=np.float64) c = np.array([], dtype=np.float64) @@ -288,21 +290,26 @@ def test_subdomains(compile_args): assert compiled_f.rank == len(f.arguments()) form0 = compiled_forms[0] - ids = form0.integral_ids(module.lib.cell) + offsets = form0.form_integral_offsets + cell = module.lib.cell + ids = [form0.form_integral_ids[j] for j in range(offsets[cell], offsets[cell + 1])] assert ids[0] == -1 and ids[1] == 2 form1 = compiled_forms[1] - ids = form1.integral_ids(module.lib.cell) + offsets = form1.form_integral_offsets + ids = [form1.form_integral_ids[j] for j in range(offsets[cell], offsets[cell + 1])] assert ids[0] == -1 and ids[1] == 2 form2 = compiled_forms[2] - ids = form2.integral_ids(module.lib.cell) + offsets = form2.form_integral_offsets + ids = [form2.form_integral_ids[j] for j in range(offsets[cell], offsets[cell + 1])] assert ids[0] == 1 and ids[1] == 2 form3 = compiled_forms[3] - assert form3.num_integrals(module.lib.cell) == 0 - - ids = form3.integral_ids(module.lib.exterior_facet) + offsets = form3.form_integral_offsets + assert offsets[cell + 1] - offsets[cell] == 0 + exf = module.lib.exterior_facet + ids = [form3.form_integral_ids[j] for j in range(offsets[exf], offsets[exf + 1])] assert ids[0] == 0 and ids[1] == 210 @@ -325,7 +332,7 @@ def test_interior_facet_integral(mode, compile_args): ffi = module.ffi np_type = cdtype_to_numpy(mode) - integral0 = form0.integrals(module.lib.interior_facet)[0] + integral0 = form0.form_integrals[0] A = np.zeros((6, 6), dtype=np_type) w = np.array([], dtype=np_type) c = np.array([], dtype=np.float64) @@ -370,8 +377,8 @@ def test_conditional(mode, compile_args): compiled_forms, module, code = ffcx.codegeneration.jit.compile_forms( forms, options={'scalar_type': mode}, cffi_extra_compile_args=compile_args) - form0 = compiled_forms[0].integrals(module.lib.cell)[0] - form1 = compiled_forms[1].integrals(module.lib.cell)[0] + form0 = compiled_forms[0].form_integrals[0] + form1 = compiled_forms[1].form_integrals[0] ffi = module.ffi np_type = cdtype_to_numpy(mode) @@ -427,7 +434,7 @@ def test_custom_quadrature(compile_args): ffi = module.ffi form = compiled_forms[0] - default_integral = form.integrals(module.lib.cell)[0] + default_integral = form.form_integrals[0] A = np.zeros((6, 6), dtype=np.float64) w = np.array([], dtype=np.float64) @@ -512,8 +519,8 @@ def test_lagrange_triangle(compile_args, order, mode, sym_fun, ufl_fun): ffi = module.ffi form0 = compiled_forms[0] - assert form0.num_integrals(module.lib.cell) == 1 - default_integral = form0.integrals(module.lib.cell)[0] + assert form0.form_integral_offsets[module.lib.cell + 1] == 1 + default_integral = form0.form_integrals[0] np_type = cdtype_to_numpy(mode) b = np.zeros((order + 2) * (order + 1) // 2, dtype=np_type) @@ -603,9 +610,9 @@ def test_lagrange_tetrahedron(compile_args, order, mode, sym_fun, ufl_fun): ffi = module.ffi form0 = compiled_forms[0] - assert form0.num_integrals(module.lib.cell) == 1 + assert form0.form_integral_offsets[module.lib.cell + 1] == 1 - default_integral = form0.integrals(module.lib.cell)[0] + default_integral = form0.form_integrals[0] np_type = cdtype_to_numpy(mode) b = np.zeros((order + 3) * (order + 2) * (order + 1) // 6, dtype=np_type) @@ -640,9 +647,9 @@ def test_prism(compile_args): ffi = module.ffi form0 = compiled_forms[0] - assert form0.num_integrals(module.lib.cell) == 1 + assert form0.form_integral_offsets[module.lib.cell + 1] == 1 - default_integral = form0.integrals(module.lib.cell)[0] + default_integral = form0.form_integrals[0] b = np.zeros(6, dtype=np.float64) coords = np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, @@ -675,8 +682,8 @@ def test_complex_operations(compile_args): compiled_forms, module, code = ffcx.codegeneration.jit.compile_forms( forms, options={'scalar_type': mode}, cffi_extra_compile_args=compile_args) - form0 = compiled_forms[0].integrals(module.lib.cell)[0] - form1 = compiled_forms[1].integrals(module.lib.cell)[0] + form0 = compiled_forms[0].form_integrals[0] + form1 = compiled_forms[1].form_integrals[0] ffi = module.ffi np_type = cdtype_to_numpy(mode) @@ -752,9 +759,9 @@ def test_interval_vertex_quadrature(compile_args): ffi = module.ffi form0 = compiled_forms[0] - assert form0.num_integrals(module.lib.cell) == 1 + assert form0.form_integral_offsets[module.lib.cell + 1] == 1 - default_integral = form0.integrals(module.lib.cell)[0] + default_integral = form0.form_integrals[0] J = np.zeros(1, dtype=np.float64) a = np.pi b = np.exp(1) @@ -799,9 +806,11 @@ def test_facet_vertex_quadrature(compile_args): assert len(compiled_forms) == 2 solutions = [] for form in compiled_forms: - assert form.num_integrals(module.lib.exterior_facet) == 1 + offsets = form.form_integral_offsets + exf = module.lib.exterior_facet + assert offsets[exf + 1] - offsets[exf] == 1 - default_integral = form.integrals(module.lib.exterior_facet)[0] + default_integral = form.form_integrals[offsets[exf]] J = np.zeros(1, dtype=np.float64) a = np.pi b = np.exp(1) @@ -850,7 +859,7 @@ def test_manifold_derivatives(compile_args): compiled_forms, module, _ = ffcx.codegeneration.jit.compile_forms( [J], cffi_extra_compile_args=compile_args) - default_integral = compiled_forms[0].integrals(module.lib.cell)[0] + default_integral = compiled_forms[0].form_integrals[0] scale = 2.5 coords = np.array([0.0, 0.0, 0.0, 0.0, scale, 0.0], dtype=np.float64) dof_coords = el.element.points.reshape(-1)