From 551f6735fa238a6edc70a4de48178872190e6074 Mon Sep 17 00:00:00 2001 From: Yang Zongze Date: Tue, 12 Sep 2023 17:03:57 +0800 Subject: [PATCH] Add function determinant_expr_nxn (#101) * Add function determinant_expr_nxn * Change the order of the assert expression --------- Co-authored-by: Matthew Scroggs --- test/test_apply_algebra_lowering.py | 2 +- ufl/compound_expressions.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/test/test_apply_algebra_lowering.py b/test/test_apply_algebra_lowering.py index a9a097b15..a66332313 100755 --- a/test/test_apply_algebra_lowering.py +++ b/test/test_apply_algebra_lowering.py @@ -55,7 +55,7 @@ def test_determinant2(A2): def test_determinant3(A3): assert determinant_expr(A3) == (A3[0, 0]*(A3[1, 1]*A3[2, 2] - A3[1, 2]*A3[2, 1]) - + A3[0, 1]*(A3[1, 2]*A3[2, 0] - A3[1, 0]*A3[2, 2]) + + (A3[1, 0]*A3[2, 2] - A3[1, 2]*A3[2, 0])*(-A3[0, 1]) + A3[0, 2]*(A3[1, 0]*A3[2, 1] - A3[1, 1]*A3[2, 0])) diff --git a/ufl/compound_expressions.py b/ufl/compound_expressions.py index a25307298..af3e05fd9 100644 --- a/ufl/compound_expressions.py +++ b/ufl/compound_expressions.py @@ -93,6 +93,8 @@ def determinant_expr(A): return determinant_expr_2x2(A) elif sh[0] == 3: return determinant_expr_3x3(A) + else: + return determinant_expr_nxn(A) else: return pseudo_determinant_expr(A) @@ -116,6 +118,12 @@ def determinant_expr_3x3(A): return codeterminant_expr_nxn(A, [0, 1, 2], [0, 1, 2]) +def determinant_expr_nxn(A): + nrow, ncol = A.ufl_shape + assert nrow == ncol + return codeterminant_expr_nxn(A, list(range(nrow)), list(range(ncol))) + + def codeterminant_expr_nxn(A, rows, cols): if len(rows) == 2: return _det_2x2(A, rows[0], rows[1], cols[0], cols[1]) @@ -123,8 +131,8 @@ def codeterminant_expr_nxn(A, rows, cols): r = rows[0] subrows = rows[1:] for i, c in enumerate(cols): - subcols = cols[i + 1:] + cols[:i] - codet += A[r, c] * codeterminant_expr_nxn(A, subrows, subcols) + subcols = cols[:i] + cols[i + 1:] + codet += (-1)**i * A[r, c] * codeterminant_expr_nxn(A, subrows, subcols) return codet