-
-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into mscroggs/move-QuadratureElement
- Loading branch information
Showing
16 changed files
with
282 additions
and
372 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright (C) 2019 Michal Habera | ||
# | ||
# This file is part of FFCx.(https://www.fenicsproject.org) | ||
# | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
import logging | ||
|
||
from ffcx.codegeneration.C import expressions_template | ||
from ffcx.codegeneration.expression_generator import ExpressionGenerator | ||
from ffcx.codegeneration.backend import FFCXBackend | ||
from ffcx.codegeneration.C.format_lines import format_indented_lines | ||
from ffcx.naming import cdtype_to_numpy, scalar_to_value_type | ||
|
||
logger = logging.getLogger("ffcx") | ||
|
||
|
||
def generator(ir, options): | ||
"""Generate UFC code for an expression.""" | ||
logger.info("Generating code for expression:") | ||
logger.info(f"--- points: {ir.points}") | ||
logger.info(f"--- name: {ir.name}") | ||
|
||
factory_name = ir.name | ||
|
||
# Format declaration | ||
declaration = expressions_template.declaration.format( | ||
factory_name=factory_name, name_from_uflfile=ir.name_from_uflfile) | ||
|
||
backend = FFCXBackend(ir, options) | ||
eg = ExpressionGenerator(ir, backend) | ||
|
||
d = {} | ||
d["name_from_uflfile"] = ir.name_from_uflfile | ||
d["factory_name"] = ir.name | ||
|
||
parts = eg.generate() | ||
|
||
body = format_indented_lines(parts.cs_format(), 1) | ||
d["tabulate_expression"] = body | ||
|
||
if len(ir.original_coefficient_positions) > 0: | ||
d["original_coefficient_positions"] = f"original_coefficient_positions_{ir.name}" | ||
values = ", ".join(str(i) for i in ir.original_coefficient_positions) | ||
sizes = len(ir.original_coefficient_positions) | ||
d["original_coefficient_positions_init"] = \ | ||
f"static int original_coefficient_positions_{ir.name}[{sizes}] = {{{values}}};" | ||
else: | ||
d["original_coefficient_positions"] = "NULL" | ||
d["original_coefficient_positions_init"] = "" | ||
|
||
values = ", ".join(str(p) for p in ir.points.flatten()) | ||
sizes = ir.points.size | ||
d["points_init"] = f"static double points_{ir.name}[{sizes}] = {{{values}}};" | ||
d["points"] = f"points_{ir.name}" | ||
|
||
if len(ir.expression_shape) > 0: | ||
values = ", ".join(str(i) for i in ir.expression_shape) | ||
sizes = len(ir.expression_shape) | ||
d["value_shape_init"] = f"static int value_shape_{ir.name}[{sizes}] = {{{values}}};" | ||
d["value_shape"] = f"value_shape_{ir.name}" | ||
else: | ||
d["value_shape_init"] = "" | ||
d["value_shape"] = "NULL" | ||
|
||
d["num_components"] = len(ir.expression_shape) | ||
d["num_coefficients"] = len(ir.coefficient_numbering) | ||
d["num_constants"] = len(ir.constant_names) | ||
d["num_points"] = ir.points.shape[0] | ||
d["topological_dimension"] = ir.points.shape[1] | ||
d["scalar_type"] = options["scalar_type"] | ||
d["geom_type"] = scalar_to_value_type(options["scalar_type"]) | ||
d["np_scalar_type"] = cdtype_to_numpy(options["scalar_type"]) | ||
|
||
d["rank"] = len(ir.tensor_shape) | ||
|
||
if len(ir.coefficient_names) > 0: | ||
values = ", ".join(f'"{name}"' for name in ir.coefficient_names) | ||
sizes = len(ir.coefficient_names) | ||
d["coefficient_names_init"] = f"static const char* coefficient_names_{ir.name}[{sizes}] = {{{values}}};" | ||
d["coefficient_names"] = f"coefficient_names_{ir.name}" | ||
else: | ||
d["coefficient_names_init"] = "" | ||
d["coefficient_names"] = "NULL" | ||
|
||
if len(ir.constant_names) > 0: | ||
values = ", ".join(f'"{name}"' for name in ir.constant_names) | ||
sizes = len(ir.constant_names) | ||
d["constant_names_init"] = f"static const char* constant_names_{ir.name}[{sizes}] = {{{values}}};" | ||
d["constant_names"] = f"constant_names_{ir.name}" | ||
else: | ||
d["constant_names_init"] = "" | ||
d["constant_names"] = "NULL" | ||
|
||
code = [] | ||
|
||
# FIXME: Should be handled differently, revise how | ||
# ufcx_function_space is generated (also for ufcx_form) | ||
for (name, (element, dofmap, cmap_family, cmap_degree)) in ir.function_spaces.items(): | ||
code += [f"static ufcx_function_space function_space_{name}_{ir.name_from_uflfile} ="] | ||
code += ["{"] | ||
code += [f".finite_element = &{element},"] | ||
code += [f".dofmap = &{dofmap},"] | ||
code += [f".geometry_family = \"{cmap_family}\","] | ||
code += [f".geometry_degree = {cmap_degree}"] | ||
code += ["};"] | ||
|
||
d["function_spaces_alloc"] = "\n".join(code) | ||
d["function_spaces"] = "" | ||
|
||
if len(ir.function_spaces) > 0: | ||
d["function_spaces"] = f"function_spaces_{ir.name}" | ||
values = ", ".join(f"&function_space_{name}_{ir.name_from_uflfile}" | ||
for (name, _) in ir.function_spaces.items()) | ||
sizes = len(ir.function_spaces) | ||
d["function_spaces_init"] = f"ufcx_function_space* function_spaces_{ir.name}[{sizes}] = {{{values}}};" | ||
else: | ||
d["function_spaces"] = "NULL" | ||
d["function_spaces_init"] = "" | ||
|
||
# Check that no keys are redundant or have been missed | ||
from string import Formatter | ||
fields = [fname for _, fname, _, _ in Formatter().parse(expressions_template.factory) if fname] | ||
assert set(fields) == set(d.keys()), "Mismatch between keys in template and in formatting dict" | ||
|
||
# Format implementation code | ||
implementation = expressions_template.factory.format_map(d) | ||
|
||
return declaration, implementation |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Oops, something went wrong.