From 9955d164e2dada401edb1de2e0cb724a972ebfc0 Mon Sep 17 00:00:00 2001 From: "Jack S. Hale" Date: Wed, 8 May 2024 10:20:14 +0200 Subject: [PATCH] Fix test. --- test/test_signatures.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/test_signatures.py b/test/test_signatures.py index 8c068cac9..68391e2c2 100644 --- a/test/test_signatures.py +++ b/test/test_signatures.py @@ -40,10 +40,10 @@ def generate_kernel(forms, scalar_type, options): @pytest.mark.parametrize( "dtype", [ - np.float32, - np.float64, + "float32", + "float64", pytest.param( - np.complex64, + "complex64", marks=pytest.mark.xfail( sys.platform.startswith("win32"), raises=NotImplementedError, @@ -51,7 +51,7 @@ def generate_kernel(forms, scalar_type, options): ), ), pytest.param( - np.complex128, + "complex128", marks=pytest.mark.xfail( sys.platform.startswith("win32"), raises=NotImplementedError, @@ -66,9 +66,6 @@ def test_numba_kernel_signature(dtype): except ImportError: pytest.skip("Numba not installed") - # Convert to numpy dtype - dtype = np.dtype(dtype) - # Create a simple form mesh = ufl.Mesh(basix.ufl.element("P", "triangle", 2, shape=(2,))) e = basix.ufl.element("Lagrange", "triangle", 2) @@ -82,9 +79,12 @@ def test_numba_kernel_signature(dtype): # Generate and compile the kernel kernel, code, module = generate_kernel([a], dtype, {}) + # Convert to numpy dtype + np_dtype = np.dtype(dtype) + # Generate the Numba signature xtype = utils.dtype_to_scalar_dtype(dtype) - signature = utils.numba_ufcx_kernel_signature(dtype, xtype) + signature = utils.numba_ufcx_kernel_signature(np_dtype, xtype) assert isinstance(signature, numba.core.typing.templates.Signature) # Get the signature from the compiled kernel @@ -94,6 +94,6 @@ def test_numba_kernel_signature(dtype): # check that the signature is equivalent to the one in the generated code assert len(args) == len(signature.args) for i, (arg, sig) in enumerate(zip(args, signature.args)): - type_name = sig.name.replace(str(dtype), utils.dtype_to_c_type(dtype)) + type_name = sig.name.replace(str(np_dtype), utils.dtype_to_c_type(np_dtype)) ctypes_name = type_name.replace(" *", "*") assert ctypes_name == type_name