Skip to content

Commit

Permalink
Add function determinant_expr_nxn (#101)
Browse files Browse the repository at this point in the history
* Add function determinant_expr_nxn

* Change the order of the assert expression

---------

Co-authored-by: Matthew Scroggs <[email protected]>
  • Loading branch information
lrtfm and mscroggs authored Sep 12, 2023
1 parent 5d6df80 commit 551f673
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/test_apply_algebra_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_determinant2(A2):

def test_determinant3(A3):
assert determinant_expr(A3) == (A3[0, 0]*(A3[1, 1]*A3[2, 2] - A3[1, 2]*A3[2, 1])
+ A3[0, 1]*(A3[1, 2]*A3[2, 0] - A3[1, 0]*A3[2, 2])
+ (A3[1, 0]*A3[2, 2] - A3[1, 2]*A3[2, 0])*(-A3[0, 1])
+ A3[0, 2]*(A3[1, 0]*A3[2, 1] - A3[1, 1]*A3[2, 0]))


Expand Down
12 changes: 10 additions & 2 deletions ufl/compound_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def determinant_expr(A):
return determinant_expr_2x2(A)
elif sh[0] == 3:
return determinant_expr_3x3(A)
else:
return determinant_expr_nxn(A)
else:
return pseudo_determinant_expr(A)

Expand All @@ -116,15 +118,21 @@ def determinant_expr_3x3(A):
return codeterminant_expr_nxn(A, [0, 1, 2], [0, 1, 2])


def determinant_expr_nxn(A):
nrow, ncol = A.ufl_shape
assert nrow == ncol
return codeterminant_expr_nxn(A, list(range(nrow)), list(range(ncol)))


def codeterminant_expr_nxn(A, rows, cols):
if len(rows) == 2:
return _det_2x2(A, rows[0], rows[1], cols[0], cols[1])
codet = 0.0
r = rows[0]
subrows = rows[1:]
for i, c in enumerate(cols):
subcols = cols[i + 1:] + cols[:i]
codet += A[r, c] * codeterminant_expr_nxn(A, subrows, subcols)
subcols = cols[:i] + cols[i + 1:]
codet += (-1)**i * A[r, c] * codeterminant_expr_nxn(A, subrows, subcols)
return codet


Expand Down

0 comments on commit 551f673

Please sign in to comment.