Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trefftz support for Firedrake #3775

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions firedrake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
from firedrake.randomfunctiongen import *
from firedrake.external_operators import *
from firedrake.progress_bar import ProgressBar # noqa: F401
from firedrake.trefftz import *

from firedrake.logging import *
# Set default log level
Expand Down
231 changes: 231 additions & 0 deletions firedrake/trefftz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""
This module provides a class to compute the Trefftz embedding of a given function space.
It is also used to compute aggregation embedding of a given function space.
"""
from firedrake.petsc import PETSc
from firedrake.cython.dmcommon import FACE_SETS_LABEL, CELL_SETS_LABEL
from firedrake.assemble import assemble
from firedrake.mesh import Mesh
from firedrake.functionspace import FunctionSpace
from firedrake.function import Function
from firedrake.ufl_expr import TestFunction, TrialFunction
from firedrake.constant import Constant
from ufl import dx, dS, inner, jump, grad, dot, CellDiameter, FacetNormal
import scipy.sparse as sp


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please define __all__ to avoid pollution of the namespace

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or equivalently you can replace from firedrake.trefftz import * with from firedrake.trefftz import OnlyWhatIWant inside __init__.py.

class TrefftzEmbedding(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class TrefftzEmbedding(object):
class TrefftzEmbedding:

"""
This class computes the Trefftz embedding of a given function space
Parameters
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
This class computes the Trefftz embedding of a given function space
Parameters
"""Compute the Trefftz embedding of a given function space.
Parameters

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly for other docstrings

----------
V : :class:`.FunctionSpace`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove types from here and add type-hinting, see policy.

Ambient function space.
b : :class:`.ufl.form.Form`
Bilinear form defining the Trefftz operator.
dim : int, optional
Dimension of the embedding.
Default is the dimension of the function space.
tol : float, optional
Tolerance for the singular values cutoff.
Default is 1e-12.
backend : str, optional
Backend to use for the computation of the SVD.
Default is "scipy".
"""
def __init__(self, V, b, dim=None, tol=1e-12, backend="scipy"):
self.V = V
self.b = b
self.dim = V.dim() if not dim else dim + 1
self.tol = tol
self.backend = backend

def assemble(self):
"""
Assemble the embedding, compute the SVD and return the embedding matrix
"""
self.B = assemble(self.b).M.handle
if self.backend == "scipy":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work in parallel? We should have a parallel test.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure having a backend way we want to do this either...

indptr, indices, data = self.B.getValuesCSR()
Bsp = sp.csr_matrix((data, indices, indptr), shape=self.B.getSize())
_, sig, VT = sp.linalg.svds(Bsp, k=self.dim-1, which="SM")
QT = sp.csr_matrix(VT[0:sum(sig < self.tol), :])
QTpsc = PETSc.Mat().createAIJ(size=QT.shape, csr=(QT.indptr, QT.indices, QT.data))
self.dimT = QT.shape[0]
self.sig = sig
else:
raise NotImplementedError("Only scipy backend is supported")
return QTpsc, sig


class trefftz_ksp(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class trefftz_ksp(object):
class TrefftzKSP:

"""
This class wraps a PETSc KSP object to solve the reduced
system obtained by the Trefftz embedding.
"""
def __init__(self):
pass

@staticmethod
def get_appctx(ksp):
"""
Get the application context from the KSP
Parameters
----------
ksp : :class:`PETSc.KSP`
The KSP object
"""
from firedrake.dmhooks import get_appctx
return get_appctx(ksp.getDM()).appctx

def setUp(self, ksp):
"""
Set up the Trefftz KSP
Parameters
----------
ksp : :class:`PETSc.KSP`
The KSP object
"""
appctx = self.get_appctx(ksp)
self.QT, _ = appctx["trefftz_embedding"].assemble()

def solve(self, ksp, b, x):
"""
Solve the Trefftz KSP
Parameters
----------
ksp : :class:`PETSc.KSP`
The KSP object
b : :class:`PETSc.Vec`
The right-hand side
x : :class:`PETSc.Vec`
The solution
"""
A, P = ksp.getOperators()
self.Q = PETSc.Mat().createTranspose(self.QT)
ATF = self.QT @ A @ self.Q
PTF = self.QT @ P @ self.Q
bTF = self.QT.createVecLeft()
self.QT.mult(b, bTF)

tiny_ksp = PETSc.KSP().create()
tiny_ksp.setOperators(ATF, PTF)
tiny_ksp.setOptionsPrefix("trefftz_")
tiny_ksp.setFromOptions()
xTF = ATF.createVecRight()
tiny_ksp.solve(bTF, xTF)
self.QT.multTranspose(xTF, x)
ksp.setConvergedReason(tiny_ksp.getConvergedReason())


class AggregationEmbedding(TrefftzEmbedding):
"""
This class computes the aggregation embedding of a given function space.
Parameters
----------
V : :class:`.FunctionSpace`
Ambient function space.
mesh : :class:`.Mesh`
The mesh on which the aggregation is defined.
polyMesh : :class:`.Function`
The function defining the aggregation.
dim : int
Dimension of the embedding.
Default is the dimension of the function space.
tol : float
Tolerance for the singular values cutoff.
Default is 1e-12.
"""
def __init__(self, V, mesh, polyMesh, dim=None, tol=1e-12):
# Relabel facets that are inside an aggregated region
offset = 1 + mesh.topology_dm.getLabelSize(FACE_SETS_LABEL)
offset += mesh.topology_dm.getLabelSize(CELL_SETS_LABEL)
nPoly = int(max(polyMesh.dat.data[:])) # Number of aggregates
getIdx = mesh._cell_numbering.getOffset
plex = mesh.topology_dm
pStart, pEnd = plex.getDepthStratum(2)
self.facet_index = []
for poly in range(nPoly+1):
facets = []
for i in range(pStart, pEnd):
if polyMesh.dat.data[getIdx(i)] == poly:
for f in plex.getCone(i):
if f in facets:
plex.setLabelValue(FACE_SETS_LABEL, f, offset+poly)
if offset+poly not in self.facet_index:
self.facet_index = self.facet_index + [offset+poly]
facets = facets + list(plex.getCone(i))
self.mesh = Mesh(plex)
h = CellDiameter(self.mesh)
n = FacetNormal(self.mesh)
W = FunctionSpace(self.mesh, V.ufl_element())
u = TrialFunction(W)
v = TestFunction(W)
self.b = Constant(0)*inner(u, v)*dx
for i in self.facet_index:
self.b += inner(jump(u), jump(v))*dS(i)
for k in range(1, V.ufl_element().degree()+1):
for i in self.facet_index:
self.b += ((0.5 * h("+") + 0.5 * h("-"))**(2*k)) *\
inner(jump_normal(u, n("+"), k), jump_normal(v, n("+"), k))*dS(i)
super().__init__(W, self.b, dim, tol)


def jump_normal(u, n, k):
"""
Compute the jump of the normal derivative of a function u
Parameters
----------
u : :class:`.Function`
The function.
n : :class:`.ufc.Normal`
The normal vector.
k : int
The order of the normal derivative we aim to compute.
"""
j = 0.5*dot(n, (grad(u)("+")-grad(u)("-")))
for _ in range(1, k):
j = 0.5*dot(n, (grad(j)-grad(j)))
return j


def dumb_aggregation(mesh):
"""
Compute a dumb aggregation of the mesh
Parameters
----------
mesh : :class:`.Mesh`
The mesh we aim to aggregate.
"""
if mesh.comm.size > 1:
raise NotImplementedError("Parallel mesh aggregation not supported")
plex = mesh.topology_dm
pStart, pEnd = plex.getDepthStratum(2)
_, eEnd = plex.getDepthStratum(1)
adjacency = []
for i in range(pStart, pEnd):
ad = plex.getAdjacency(i)
local = []
for a in ad:
supp = plex.getSupport(a)
supp = supp[supp < eEnd]
for s in supp:
if s < pEnd and s != ad[0]:
local = local + [s]
adjacency = adjacency + [(i, local)]
adjacency = sorted(adjacency, key=lambda x: len(x[1]))[::-1]
u = Function(FunctionSpace(mesh, "DG", 0))

getIdx = mesh._cell_numbering.getOffset
av = list(range(pStart, pEnd))
col = 0
for a in adjacency:
if a[0] in av:
for k in a[1]:
if k in av:
av.remove(k)
u.dat.data[getIdx(k)] = col
av.remove(a[0])
u.dat.data[getIdx(a[0])] = col
col = col + 1
return u
80 changes: 80 additions & 0 deletions tests/regression/test_trefftz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from firedrake import *

Check failure on line 1 in tests/regression/test_trefftz.py

View workflow job for this annotation

GitHub Actions / Firedrake complex

test_trefftz.tests.regression.test_trefftz

tests.regression.test_trefftz

Check failure on line 1 in tests/regression/test_trefftz.py

View workflow job for this annotation

GitHub Actions / Firedrake real

test_trefftz.tests.regression.test_trefftz

tests.regression.test_trefftz
from firedrake.trefftz import TrefftzEmbedding, AggregationEmbedding, dumb_aggregation


@pytest.mark.skipcomplex
def test_trefftz_laplace():
order = 6
mesh = UnitSquareMesh(2, 2)
x, y = SpatialCoordinate(mesh)
h = CellDiameter(mesh)
n = FacetNormal(mesh)
V = FunctionSpace(mesh, "DG", order)
u = TrialFunction(V)
v = TestFunction(V)

def delta(u):
return div(grad(u))

a = inner(delta(u), delta(v)) * dx
alpha = 4
mean_dudn = 0.5 * dot(grad(u("+"))+grad(u("-")), n("+"))
mean_dvdn = 0.5 * dot(grad(v("+"))+grad(v("-")), n("+"))
aDG = inner(grad(u), grad(v)) * dx
aDG += inner((alpha*order**2/(h("+")+h("-")))*jump(u), jump(v))*dS
aDG += inner(-mean_dudn, jump(v))*dS-inner(mean_dvdn, jump(u))*dS
aDG += alpha*order**2/h*inner(u, v)*ds
aDG += -inner(dot(n, grad(u)), v)*ds - inner(dot(n, grad(v)), u)*ds
f = Function(V).interpolate(exp(x)*sin(y))
L = alpha*order**2/h*inner(f, v)*ds - inner(dot(n, grad(v)), f)*ds
# Solve the problem
uDG = Function(V)
uDG.rename("uDG")
embd = TrefftzEmbedding(V, a, tol=1e-8)
appctx = {"trefftz_embedding": embd}
uDG = Function(V)
solve(aDG == L, uDG, solver_parameters={"ksp_type": "python",
"ksp_python_type": "firedrake.trefftz.trefftz_ksp"},
appctx=appctx)
assert (assemble(inner(uDG-f, uDG-f)*dx) < 1e-6)
assert (embd.dimT < V.dim()/2)


@pytest.mark.skipcomplex
def test_trefftz_aggregation():
from netgen.occ import WorkPlane, OCCGeometry
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test needs to be skipped if ngsPETSc is not installed


Rectangle = WorkPlane().Rectangle(1, 1).Face()
geo = OCCGeometry(Rectangle, dim=2)
ngmesh = geo.GenerateMesh(maxh=0.3)
mesh = Mesh(ngmesh)

polymesh = dumb_aggregation(mesh)

order = 3
x, y = SpatialCoordinate(mesh)
h = CellDiameter(mesh)
n = FacetNormal(mesh)
V = FunctionSpace(mesh, "DG", order)
u = TrialFunction(V)
v = TestFunction(V)

alpha = 1e3
mean_dudn = 0.5 * dot(grad(u("+"))+grad(u("-")), n("+"))
mean_dvdn = 0.5 * dot(grad(v("+"))+grad(v("-")), n("+"))
aDG = inner(grad(u), grad(v)) * dx
aDG += inner((alpha*order**2/(h("+")+h("-")))*jump(u), jump(v))*dS
aDG += inner(-mean_dudn, jump(v))*dS-inner(mean_dvdn, jump(u))*dS
aDG += alpha*order**2/h*inner(u, v)*ds
aDG += -inner(dot(n, grad(u)), v)*ds - inner(dot(n, grad(v)), u)*ds
f = Function(V).interpolate(exp(x)*sin(y))
L = alpha*order**2/h*inner(f, v)*ds - inner(dot(n, grad(v)), f)*ds
agg_embd = AggregationEmbedding(V, mesh, polymesh)
appctx = {"trefftz_embedding": agg_embd}

uDG = Function(V)
solve(aDG == L, uDG, solver_parameters={"ksp_type": "python",
"ksp_python_type": "firedrake.trefftz.trefftz_ksp"},
appctx=appctx)

assert (assemble(inner(uDG-f, uDG-f)*dx) < 1e-9)
Loading