Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ksagiyam/submesh core #3478

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch tsfc ksagiyam/introduce_mixed_map \
--package-branch ufl ksagiyam/introduce_mixed_map \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
12 changes: 6 additions & 6 deletions demos/saddle_point_pc/saddle_point_systems.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ Finally, at each mesh size, we print out the number of cells in the
mesh and the number of iterations the solver took to converge ::

#
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())

The resulting convergence is unimpressive:

Expand Down Expand Up @@ -289,7 +289,7 @@ applying the action of blocks, so we can use a block matrix format. ::
for n in range(8):
solver, w = build_problem(n, parameters, block_matrix=True)
solver.solve()
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())

The resulting convergence is algorithmically good, however, the larger
problems still take a long time.
Expand Down Expand Up @@ -374,7 +374,7 @@ Let's see what happens. ::
for n in range(8):
solver, w = build_problem(n, parameters, block_matrix=True)
solver.solve()
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())

This is much better, the problem takes much less time to solve and
when observing the iteration counts for inverting :math:`S` we can see
Expand Down Expand Up @@ -429,7 +429,7 @@ and so we no longer need a flexible Krylov method. ::
for n in range(8):
solver, w = build_problem(n, parameters, block_matrix=True)
solver.solve()
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())

This results in the following GMRES iteration counts

Expand Down Expand Up @@ -494,7 +494,7 @@ variable. We can provide it as an :class:`~.AuxiliaryOperatorPC` via a python pr
for n in range(8):
solver, w = build_problem(n, parameters, aP=None, block_matrix=False)
solver.solve()
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())

This actually results in slightly worse convergence than the diagonal
approximation we used above.
Expand Down Expand Up @@ -578,7 +578,7 @@ Let's see what the iteration count looks like now. ::
for n in range(8):
solver, w = build_problem(n, parameters, aP=riesz, block_matrix=True)
solver.solve()
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())

============== ==================
Mesh elements GMRES iterations
Expand Down
227 changes: 180 additions & 47 deletions firedrake/assemble.py

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,6 @@ def hermite_stride(bcnodes):
# take intersection of facet nodes, and add it to bcnodes
# i, j, k can also be strings.
bcnodes1 = []
if len(s) > 1 and not isinstance(self._function_space.finat_element, (finat.Lagrange, finat.GaussLobattoLegendre)):
raise TypeError("Currently, edge conditions have only been tested with CG Lagrange elements")
for ss in s:
# intersection of facets
# Edge conditions have only been tested with Lagrange elements.
Expand Down
12 changes: 11 additions & 1 deletion firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
:kwarg distribution_name: the name under which distribution is saved; if `None`, auto-generated name will be used.
:kwarg permutation_name: the name under which permutation is saved; if `None`, auto-generated name will be used.
"""
# TODO: Add general MixedMesh support.
mesh = mesh.unique()
mesh.init()
# Handle extruded mesh
tmesh = mesh.topology
Expand Down Expand Up @@ -836,6 +838,8 @@ def get_timestepping_history(self, mesh, name):
@PETSc.Log.EventDecorator("SaveFunctionSpace")
def _save_function_space(self, V):
mesh = V.mesh()
# TODO: Add general MixedMesh support.
mesh = mesh.unique()
if isinstance(V.topological, impl.MixedFunctionSpace):
V_name = self._generate_function_space_name(V)
base_path = self._path_to_mixed_function_space(mesh.name, V_name)
Expand Down Expand Up @@ -911,10 +915,12 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
each index.
"""
V = f.function_space()
mesh = V.mesh()
if name:
g = Function(V, val=f.dat, name=name)
return self.save_function(g, idx=idx, timestepping_info=timestepping_info)
mesh = V.mesh()
# TODO: Add general MixedMesh support.
mesh = mesh.unique()
# -- Save function space --
self._save_function_space(V)
# -- Save function --
Expand Down Expand Up @@ -1233,6 +1239,8 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):

@PETSc.Log.EventDecorator("LoadFunctionSpace")
def _load_function_space(self, mesh, name):
# TODO: Add general MixedMesh support.
mesh = mesh.unique()
mesh.init()
mesh_key = self._generate_mesh_key_from_names(mesh.name,
mesh.topology._distribution_name,
Expand Down Expand Up @@ -1310,6 +1318,8 @@ def load_function(self, mesh, name, idx=None):
be loaded with idx only when it was saved with idx.
:returns: the loaded :class:`~.Function`.
"""
# TODO: Add general MixedMesh support.
mesh = mesh.unique()
tmesh = mesh.topology
if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name):
V_name = self._get_mixed_function_name_mixed_function_space_name_map(mesh.name)[name]
Expand Down
17 changes: 11 additions & 6 deletions firedrake/dmhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

import firedrake
from firedrake.petsc import PETSc
from firedrake.mesh import MixedMeshGeometry


@PETSc.Log.EventDecorator()
Expand All @@ -53,8 +54,11 @@ def get_function_space(dm):
:raises RuntimeError: if no function space was found.
"""
info = dm.getAttr("__fs_info__")
meshref, element, indices, (name, names) = info
mesh = meshref()
meshref_tuple, element, indices, (name, names) = info
if len(meshref_tuple) == 1:
mesh = meshref_tuple[0]()
else:
mesh = MixedMeshGeometry([meshref() for meshref in meshref_tuple])
if mesh is None:
raise RuntimeError("Somehow your mesh was collected, this should never happen")
V = firedrake.FunctionSpace(mesh, element, name=name)
Expand All @@ -78,8 +82,6 @@ def set_function_space(dm, V):
This stores the information necessary to make a function space given a DM.

"""
mesh = V.mesh()

indices = []
names = []
while V.parent is not None:
Expand All @@ -90,11 +92,12 @@ def set_function_space(dm, V):
assert V.index is None
indices.append(V.component)
V = V.parent
mesh = V.mesh()
if len(V) > 1:
names = tuple(V_.name for V_ in V)
element = V.ufl_element()

info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names))
info = (tuple(weakref.ref(m) for m in mesh), element, tuple(reversed(indices)), (V.name, names))
dm.setAttr("__fs_info__", info)


Expand Down Expand Up @@ -412,7 +415,9 @@ def coarsen(dm, comm):
"""
from firedrake.mg.utils import get_level
V = get_function_space(dm)
hierarchy, level = get_level(V.mesh())
# TODO: Think harder.
m, = set(m_ for m_ in V.mesh())
hierarchy, level = get_level(m)
if level < 1:
raise RuntimeError("Cannot coarsen coarsest DM")
coarsen = get_ctx_coarsener(dm)
Expand Down
10 changes: 7 additions & 3 deletions firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,16 +620,20 @@ def at(self, arg, *args, **kwargs):

tolerance = kwargs.get('tolerance', None)
mesh = self.function_space().mesh()
if len(set(mesh)) == 1:
mesh_unique = mesh.unique()
else:
raise NotImplementedError("Not implemented for general mixed meshes")
if tolerance is None:
tolerance = mesh.tolerance
tolerance = mesh_unique.tolerance
else:
mesh.tolerance = tolerance
mesh_unique.tolerance = tolerance

# Handle f.at(0.3)
if not arg.shape:
arg = arg.reshape(-1)

if mesh.variable_layers:
if mesh_unique.variable_layers:
raise NotImplementedError("Point evaluation not implemented for variable layers")

# Validate geometric dimension
Expand Down
21 changes: 8 additions & 13 deletions firedrake/functionspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
API is functional, rather than object-based, to allow for simple
backwards-compatibility, argument checking, and dispatch.
"""
import itertools
import ufl
import finat.ufl

Expand Down Expand Up @@ -258,6 +259,8 @@ def MixedFunctionSpace(spaces, name=None, mesh=None):
:class:`finat.ufl.mixedelement.MixedElement`, ignored otherwise.

"""
from firedrake.mesh import MixedMeshGeometry

if isinstance(spaces, finat.ufl.FiniteElementBase):
# Build the spaces if we got a mixed element
assert type(spaces) is finat.ufl.MixedElement and mesh is not None
Expand All @@ -272,22 +275,15 @@ def rec(eles):
sub_elements.append(ele)
rec(spaces.sub_elements)
spaces = [FunctionSpace(mesh, element) for element in sub_elements]

# Check that function spaces are on the same mesh
meshes = [space.mesh() for space in spaces]
for i in range(1, len(meshes)):
if meshes[i] is not meshes[0]:
raise ValueError("All function spaces must be defined on the same mesh!")

# Flatten MixedMeshes.
meshes = list(itertools.chain(*[space.mesh() for space in spaces]))
try:
cls, = set(type(s) for s in spaces)
except ValueError:
# Neither primal nor dual
# We had not implemented something in between, so let's make it primal
cls = impl.WithGeometry

# Select mesh
mesh = meshes[0]
# Get topological spaces
spaces = tuple(s.topological for s in flatten(spaces))
# Error checking
Expand All @@ -301,10 +297,9 @@ def rec(eles):
else:
raise ValueError("Can't make mixed space with %s" % type(space))

new = impl.MixedFunctionSpace(spaces, name=name)
if mesh is not mesh.topology:
new = cls.create(new, mesh)
return new
mixed_mesh_geometry = MixedMeshGeometry(meshes)
new = impl.MixedFunctionSpace(spaces, mixed_mesh_geometry.topology, name=name)
return cls.create(new, mixed_mesh_geometry)


@PETSc.Log.EventDecorator("CreateFunctionSpace")
Expand Down
Loading
Loading