Skip to content

Commit

Permalink
merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam authored and pbrubeck committed Nov 12, 2024
1 parent c0cd874 commit 98e4ea3
Show file tree
Hide file tree
Showing 37 changed files with 138 additions and 136 deletions.
11 changes: 6 additions & 5 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,28 +342,29 @@ def function_arg(self, g):
del self._function_arg_update
except AttributeError:
pass
V = self.function_space()
if isinstance(g, firedrake.Function) and g.ufl_element().family() != "Real":
if g.function_space() != self.function_space():
if g.function_space() != V:
raise RuntimeError("%r is defined on incompatible FunctionSpace!" % g)
self._function_arg = g
elif isinstance(g, ufl.classes.Zero):
if g.ufl_shape and g.ufl_shape != self.function_space().ufl_element().value_shape:
if g.ufl_shape and g.ufl_shape != V.ufl_element().value_shape(V.mesh()):
raise ValueError(f"Provided boundary value {g} does not match shape of space")
# Special case. Scalar zero for direct Function.assign.
self._function_arg = g
elif isinstance(g, ufl.classes.Expr):
if g.ufl_shape != self.function_space().ufl_element().value_shape:
if g.ufl_shape != V.ufl_element().value_shape(V.mesh()):
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
try:
self._function_arg = firedrake.Function(self.function_space())
self._function_arg = firedrake.Function(V)
# Use `Interpolator` instead of assembling an `Interpolate` form
# as the expression compilation needs to happen at this stage to
# determine if we should use interpolation or projection
# -> e.g. interpolation may not be supported for the element.
self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate
except (NotImplementedError, AttributeError):
# Element doesn't implement interpolation
self._function_arg = firedrake.Function(self.function_space()).project(g)
self._function_arg = firedrake.Function(V).project(g)
self._function_arg_update = firedrake.Projector(g, self._function_arg).project
else:
try:
Expand Down
7 changes: 3 additions & 4 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
self._update_function_name_function_space_name_map(tmesh.name, mesh.name, {f.name(): V_name})
# Embed if necessary
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element)
_element = get_embedding_element_for_checkpointing(element, element.value_shape(mesh))
if _element != element:
path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name())
self.require_group(path)
Expand Down Expand Up @@ -1337,7 +1337,7 @@ def load_function(self, mesh, name, idx=None):
_name = self.get_attr(path, PREFIX_EMBEDDED + "_function")
_f = self.load_function(mesh, _name, idx=idx)
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element)
_element = get_embedding_element_for_checkpointing(element, element.value_shape(mesh))
method = get_embedding_method_for_checkpointing(element)
assert _element == _f.function_space().ufl_element()
f = Function(V, name=name)
Expand Down Expand Up @@ -1436,8 +1436,7 @@ def _get_shared_data_key_for_checkpointing(self, mesh, ufl_element):
shape = ufl_element.reference_value_shape
block_size = np.prod(shape)
elif isinstance(ufl_element, finat.ufl.VectorElement):
shape = ufl_element.value_shape[:1]
block_size = np.prod(shape)
block_size = ufl_element.reference_value_shape[0]
else:
block_size = 1
return (nodes_per_entity, real_tensorproduct, block_size)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __new__(cls, value, domain=None, name=None, count=None):

if not isinstance(domain, ufl.AbstractDomain):
cell = ufl.as_cell(domain)
coordinate_element = finat.ufl.VectorElement("Lagrange", cell, 1, gdim=cell.geometric_dimension)
coordinate_element = finat.ufl.VectorElement("Lagrange", cell, 1, dim=cell.topological_dimension())
domain = ufl.Mesh(coordinate_element)

cell = domain.ufl_cell()
Expand Down
8 changes: 4 additions & 4 deletions firedrake/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ufl


def get_embedding_dg_element(element, broken_cg=False):
def get_embedding_dg_element(element, value_shape, broken_cg=False):
cell = element.cell
family = lambda c: "DG" if c.is_simplex() else "DQ"
if isinstance(cell, ufl.TensorProductCell):
Expand All @@ -19,7 +19,7 @@ def get_embedding_dg_element(element, broken_cg=False):
scalar_element = finat.ufl.FiniteElement(family(cell), cell=cell, degree=degree)
if broken_cg:
scalar_element = finat.ufl.BrokenElement(scalar_element.reconstruct(family="Lagrange"))
shape = element.value_shape
shape = value_shape
if len(shape) == 0:
DG = scalar_element
elif len(shape) == 1:
Expand All @@ -37,12 +37,12 @@ def get_embedding_dg_element(element, broken_cg=False):
native_elements_for_checkpointing = {"Lagrange", "Discontinuous Lagrange", "Q", "DQ", "Real"}


def get_embedding_element_for_checkpointing(element):
def get_embedding_element_for_checkpointing(element, value_shape):
"""Convert the given UFL element to an element that :class:`~.CheckpointFile` can handle."""
if element.family() in native_elements_for_checkpointing:
return element
else:
return get_embedding_dg_element(element)
return get_embedding_dg_element(element, value_shape)


def get_embedding_method_for_checkpointing(element):
Expand Down
2 changes: 1 addition & 1 deletion firedrake/external_operators/point_expr_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
if not isinstance(operator_data["func"], types.FunctionType):
raise TypeError("Expecting a FunctionType pointwise expression")
expr_shape = operator_data["func"](*operands).ufl_shape
if expr_shape != function_space.ufl_element().value_shape:
if expr_shape != function_space.ufl_element().value_shape(function_space.mesh()):
raise ValueError("The dimension does not match with the dimension of the function space %s" % function_space)

@property
Expand Down
2 changes: 1 addition & 1 deletion firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def argument(self, o):
else:
args += [Zero()
for j in numpy.ndindex(
V_is[i].ufl_element().value_shape)]
V_is[i].ufl_element().value_shape(V_is[i].mesh()))]
return self._arg_cache.setdefault(o, as_vector(args))


Expand Down
4 changes: 2 additions & 2 deletions firedrake/functionspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def VectorFunctionSpace(mesh, family, degree=None, dim=None,
"""
sub_element = make_scalar_element(mesh, family, degree, vfamily, vdegree, variant)
if dim is None:
dim = mesh.ufl_cell().geometric_dimension()
dim = mesh.geometric_dimension()
if not isinstance(dim, numbers.Integral) and dim > 0:
raise ValueError(f"Can't make VectorFunctionSpace with dim={dim}")
element = finat.ufl.VectorElement(sub_element, dim=dim)
Expand Down Expand Up @@ -237,7 +237,7 @@ def TensorFunctionSpace(mesh, family, degree=None, shape=None,
"""
sub_element = make_scalar_element(mesh, family, degree, vfamily, vdegree, variant)
shape = shape or (mesh.ufl_cell().geometric_dimension(),) * 2
shape = shape or (mesh.geometric_dimension(),) * 2
element = finat.ufl.TensorElement(sub_element, shape=shape, symmetry=symmetry)
return FunctionSpace(mesh, element, name=name)

Expand Down
2 changes: 1 addition & 1 deletion firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def __init__(self, mesh, element, name=None):
shape_element = element
if isinstance(element, finat.ufl.WithMapping):
shape_element = element.wrapee
sub = shape_element.sub_elements[0].value_shape
sub = shape_element.sub_elements[0].reference_value_shape
self.shape = rvs[:len(rvs) - len(sub)]
else:
self.shape = ()
Expand Down
14 changes: 7 additions & 7 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def __init__(
# VectorFunctionSpace equivalent is built from the scalar
# sub-element.
ufl_scalar_element = ufl_scalar_element.sub_elements[0]
if ufl_scalar_element.value_shape != ():
if ufl_scalar_element.value_shape(V_dest.mesh()) != ():
raise NotImplementedError(
"Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
)
Expand Down Expand Up @@ -614,7 +614,7 @@ def __init__(
# I first point evaluate my expression at these locations, giving a
# P0DG function on the VOM. As described in the manual, this is an
# interpolation operation.
shape = V_dest.ufl_element().value_shape
shape = V_dest.ufl_element().value_shape(V_dest.mesh())
if len(shape) == 0:
fs_type = firedrake.FunctionSpace
elif len(shape) == 1:
Expand Down Expand Up @@ -988,7 +988,7 @@ def callable():
else:
# Make sure we have an expression of the right length i.e. a value for
# each component in the value shape of each function space
dims = [numpy.prod(fs.ufl_element().value_shape, dtype=int)
dims = [numpy.prod(fs.ufl_element().value_shape(fs.mesh()), dtype=int)
for fs in V]
loops = []
if numpy.prod(expr.ufl_shape, dtype=int) != sum(dims):
Expand Down Expand Up @@ -1024,13 +1024,13 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
if access is op2.READ:
raise ValueError("Can't have READ access for output function")

if len(expr.ufl_shape) != len(V.ufl_element().value_shape):
if len(expr.ufl_shape) != len(V.ufl_element().value_shape(V.mesh())):
raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
% (len(expr.ufl_shape), len(V.ufl_element().value_shape)))
% (len(expr.ufl_shape), len(V.ufl_element().value_shape(V.mesh()))))

if expr.ufl_shape != V.ufl_element().value_shape:
if expr.ufl_shape != V.ufl_element().value_shape(V.mesh()):
raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
% (expr.ufl_shape, V.ufl_element().value_shape))
% (expr.ufl_shape, V.ufl_element().value_shape(V.mesh())))

# NOTE: The par_loop is always over the target mesh cells.
target_mesh = as_domain(V)
Expand Down
47 changes: 21 additions & 26 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@
}


_supported_embedded_cell_types = [ufl.Cell('interval', 2),
ufl.Cell('triangle', 3),
ufl.Cell("quadrilateral", 3),
ufl.TensorProductCell(ufl.Cell('interval'), ufl.Cell('interval'), geometric_dimension=3)]
_supported_embedded_cell_types_and_gdims = [(ufl.Cell('interval'), 2),
(ufl.Cell('triangle'), 3),
(ufl.Cell("quadrilateral"), 3),
(ufl.TensorProductCell(ufl.Cell('interval'), ufl.Cell('interval')), 3)]


unmarked = -1
Expand Down Expand Up @@ -1477,8 +1477,8 @@ def mark_entities(self, tf, label_value, label_name=None):
elem = tV.ufl_element()
if tV.mesh() is not self:
raise RuntimeError(f"tf must be defined on {self}: {tf.mesh()} is not {self}")
if elem.value_shape != ():
raise RuntimeError(f"tf must be scalar: {elem.value_shape} != ()")
if elem.reference_value_shape != ():
raise RuntimeError(f"tf must be scalar: {elem.reference_value_shape} != ()")
if elem.family() in {"Discontinuous Lagrange", "DQ"} and elem.degree() == 0:
# cells
height = 0
Expand Down Expand Up @@ -2303,7 +2303,7 @@ def callback(self):
self.topology.init()
coordinates_fs = functionspace.FunctionSpace(self.topology, self.ufl_coordinate_element())
coordinates_data = dmcommon.reordered_coords(topology.topology_dm, coordinates_fs.dm.getDefaultSection(),
(self.num_vertices(), self.ufl_coordinate_element().cell.geometric_dimension()))
(self.num_vertices(), self.geometric_dimension()))
coordinates = function.CoordinatelessFunction(coordinates_fs,
val=coordinates_data,
name=_generate_default_mesh_coordinates_name(self.name))
Expand Down Expand Up @@ -2470,7 +2470,7 @@ def spatial_index(self):
from firedrake import function, functionspace
from firedrake.parloops import par_loop, READ, MIN, MAX

gdim = self.ufl_cell().geometric_dimension()
gdim = self.geometric_dimension()
if gdim <= 1:
info_red("libspatialindex does not support 1-dimension, falling back on brute force.")
return None
Expand Down Expand Up @@ -2765,7 +2765,7 @@ def init_cell_orientations(self, expr):
import firedrake.function as function
import firedrake.functionspace as functionspace

if self.ufl_cell() not in _supported_embedded_cell_types:
if (self.ufl_cell(), self.geometric_dimension()) not in _supported_embedded_cell_types_and_gdims:
raise NotImplementedError('Only implemented for intervals embedded in 2d and triangles and quadrilaterals embedded in 3d')

if hasattr(self, '_cell_orientations'):
Expand All @@ -2774,8 +2774,8 @@ def init_cell_orientations(self, expr):
if not isinstance(expr, ufl.classes.Expr):
raise TypeError("UFL expression expected!")

if expr.ufl_shape != (self.ufl_cell().geometric_dimension(), ):
raise ValueError(f"Mismatching shapes: expr.ufl_shape ({expr.ufl_shape}) != (self.ufl_cell().geometric_dimension(), ) (({self.ufl_cell().geometric_dimension}, ))")
if expr.ufl_shape != (self.geometric_dimension(), ):
raise ValueError(f"Mismatching shapes: expr.ufl_shape ({expr.ufl_shape}) != (self.geometric_dimension(), ) (({self.geometric_dimension}, ))")

fs = functionspace.FunctionSpace(self, 'DG', 0)
x = ufl.SpatialCoordinate(self)
Expand Down Expand Up @@ -2848,12 +2848,9 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5):

V = coordinates.function_space()
element = coordinates.ufl_element()
if V.rank != 1 or len(element.value_shape) != 1:
if V.rank != 1 or len(element.reference_value_shape) != 1:
raise ValueError("Coordinates must be from a rank-1 FunctionSpace with rank-1 value_shape.")
assert V.mesh().ufl_cell().topological_dimension() <= V.value_size
# Build coordinate element
cell = element.cell.reconstruct(geometric_dimension=V.value_size)
element = element.reconstruct(cell=cell)

mesh = MeshGeometry.__new__(MeshGeometry, element, coordinates.comm)
mesh.__init__(coordinates)
Expand Down Expand Up @@ -2886,11 +2883,10 @@ def make_mesh_from_mesh_topology(topology, name, tolerance=0.5):
# TODO: meshfile might indicates higher-order coordinate element
cell = topology.ufl_cell()
geometric_dim = topology.topology_dm.getCoordinateDim()
cell = cell.reconstruct(geometric_dimension=geometric_dim)
if not topology.topology_dm.getCoordinatesLocalized():
element = finat.ufl.VectorElement("Lagrange", cell, 1)
element = finat.ufl.VectorElement("Lagrange", cell, 1, dim=geometric_dim)
else:
element = finat.ufl.VectorElement("DQ" if cell in [ufl.quadrilateral, ufl.hexahedron] else "DG", cell, 1, variant="equispaced")
element = finat.ufl.VectorElement("DQ" if cell in [ufl.quadrilateral, ufl.hexahedron] else "DG", cell, 1, dim=geometric_dim, variant="equispaced")
# Create mesh object
mesh = MeshGeometry.__new__(MeshGeometry, element, topology.comm)
mesh._init_topology(topology)
Expand Down Expand Up @@ -2922,10 +2918,9 @@ def make_vom_from_vom_topology(topology, name, tolerance=0.5):
import firedrake.function as function

gdim = topology.topology_dm.getCoordinateDim()
tcell = topology.ufl_cell()
cell = tcell.reconstruct(geometric_dimension=gdim)
element = finat.ufl.VectorElement("DG", cell, 0)
vmesh = MeshGeometry.__new__(MeshGeometry, element, topology.comm)
cell = topology.ufl_cell()
element = finat.ufl.VectorElement("DG", cell, 0, dim=gdim)
vmesh = MeshGeometry.__new__(MeshGeometry, element)
vmesh._init_topology(topology)
# Save vertex reference coordinate (within reference cell) in function
parent_tdim = topology._parent_mesh.ufl_cell().topological_dimension()
Expand Down Expand Up @@ -3214,7 +3209,7 @@ def ExtrudedMesh(mesh, layers, layer_height=None, extrusion_type='uniform', peri
pass
elif extrusion_type in ("radial", "radial_hedgehog"):
# do not allow radial extrusion if tdim = gdim
if mesh.ufl_cell().geometric_dimension() == mesh.ufl_cell().topological_dimension():
if mesh.geometric_dimension() == mesh.topological_dimension():
raise RuntimeError("Cannot radially-extrude a mesh with equal geometric and topological dimension")
else:
# check for kernel
Expand All @@ -3234,7 +3229,7 @@ def ExtrudedMesh(mesh, layers, layer_height=None, extrusion_type='uniform', peri
element = finat.ufl.TensorProductElement(helement, velement)

if gdim is None:
gdim = mesh.ufl_cell().geometric_dimension() + (extrusion_type == "uniform")
gdim = mesh.geometric_dimension() + (extrusion_type == "uniform")
coordinates_fs = functionspace.VectorFunctionSpace(topology, element, dim=gdim)

coordinates = function.CoordinatelessFunction(coordinates_fs, name=_generate_default_mesh_coordinates_name(name))
Expand Down Expand Up @@ -4537,8 +4532,8 @@ def RelabeledMesh(mesh, indicator_functions, subdomain_ids, **kwargs):
plex1.createLabel(label_name)
for f, subid in zip(indicator_functions, subdomain_ids):
elem = f.topological.function_space().ufl_element()
if elem.value_shape != ():
raise RuntimeError(f"indicator functions must be scalar: got {elem.value_shape} != ()")
if elem.reference_value_shape != ():
raise RuntimeError(f"indicator functions must be scalar: got {elem.reference_value_shape} != ()")
if elem.family() in {"Discontinuous Lagrange", "DQ"} and elem.degree() == 0:
# cells
height = 0
Expand Down
Loading

0 comments on commit 98e4ea3

Please sign in to comment.