Skip to content

Commit

Permalink
Type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
garth-wells committed Jun 5, 2024
1 parent d324f38 commit 6c13c63
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions ffcx/element_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import basix.ufl
import numpy as np
import numpy.typing as npt
from basix import CellType as _CellType
from basix import QuadratureType as _QuadratureType


def basix_index(indices: tuple[int]) -> int:
Expand All @@ -18,33 +20,31 @@ def basix_index(indices: tuple[int]) -> int:

def create_quadrature(
cellname: str, degree: int, rule: str, elements: list[basix.ufl._ElementBase]
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
) -> tuple[npt.ArrayLike, npt.ArrayLike]:
"""Create a quadrature rule."""
if cellname == "vertex":
return (np.ones((1, 0), dtype=np.float64), np.ones(1, dtype=np.float64))
else:
celltype = basix.cell.string_to_type(cellname)
celltype = _CellType[cellname]
polyset_type = basix.PolysetType.standard
for e in elements:
polyset_type = basix.polyset_superset(celltype, polyset_type, e.polyset_type)
return basix.make_quadrature(
celltype, degree, rule=basix.quadrature.string_to_type(rule), polyset_type=polyset_type
celltype, degree, rule=_QuadratureType[rule], polyset_type=polyset_type
)


def reference_cell_vertices(cellname: str) -> npt.NDArray[np.float64]:
"""Get the vertices of a reference cell."""
return basix.geometry(basix.cell.string_to_type(cellname))
return np.asarray(basix.geometry(_CellType[cellname]))


def map_facet_points(
points: npt.NDArray[np.float64], facet: int, cellname: str
) -> npt.NDArray[np.float64]:
"""Map points from a reference facet to a physical facet."""
geom = basix.geometry(basix.cell.string_to_type(cellname))
facet_vertices = [
geom[i] for i in basix.topology(basix.cell.string_to_type(cellname))[-2][facet]
]
geom = np.asarray(basix.geometry(_CellType[cellname]))
facet_vertices = [geom[i] for i in basix.topology(_CellType[cellname])[-2][facet]]
return np.asarray(
[
facet_vertices[0]
Expand Down

0 comments on commit 6c13c63

Please sign in to comment.