Skip to content

Commit

Permalink
Fix test.
Browse files Browse the repository at this point in the history
  • Loading branch information
jhale committed May 8, 2024
1 parent 56d7128 commit 9955d16
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions test/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def generate_kernel(forms, scalar_type, options):
@pytest.mark.parametrize(
"dtype",
[
np.float32,
np.float64,
"float32",
"float64",
pytest.param(
np.complex64,
"complex64",
marks=pytest.mark.xfail(
sys.platform.startswith("win32"),
raises=NotImplementedError,
reason="missing _Complex",
),
),
pytest.param(
np.complex128,
"complex128",
marks=pytest.mark.xfail(
sys.platform.startswith("win32"),
raises=NotImplementedError,
Expand All @@ -66,9 +66,6 @@ def test_numba_kernel_signature(dtype):
except ImportError:
pytest.skip("Numba not installed")

# Convert to numpy dtype
dtype = np.dtype(dtype)

# Create a simple form
mesh = ufl.Mesh(basix.ufl.element("P", "triangle", 2, shape=(2,)))
e = basix.ufl.element("Lagrange", "triangle", 2)
Expand All @@ -82,9 +79,12 @@ def test_numba_kernel_signature(dtype):
# Generate and compile the kernel
kernel, code, module = generate_kernel([a], dtype, {})

# Convert to numpy dtype
np_dtype = np.dtype(dtype)

# Generate the Numba signature
xtype = utils.dtype_to_scalar_dtype(dtype)
signature = utils.numba_ufcx_kernel_signature(dtype, xtype)
signature = utils.numba_ufcx_kernel_signature(np_dtype, xtype)
assert isinstance(signature, numba.core.typing.templates.Signature)

# Get the signature from the compiled kernel
Expand All @@ -94,6 +94,6 @@ def test_numba_kernel_signature(dtype):
# check that the signature is equivalent to the one in the generated code
assert len(args) == len(signature.args)
for i, (arg, sig) in enumerate(zip(args, signature.args)):
type_name = sig.name.replace(str(dtype), utils.dtype_to_c_type(dtype))
type_name = sig.name.replace(str(np_dtype), utils.dtype_to_c_type(np_dtype))
ctypes_name = type_name.replace(" *", "*")
assert ctypes_name == type_name

0 comments on commit 9955d16

Please sign in to comment.