Skip to content

Commit

Permalink
Get value_shape from FunctionSpace
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Nov 13, 2024
1 parent 98e4ea3 commit adb93f4
Show file tree
Hide file tree
Showing 28 changed files with 74 additions and 86 deletions.
4 changes: 2 additions & 2 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,12 @@ def function_arg(self, g):
raise RuntimeError("%r is defined on incompatible FunctionSpace!" % g)
self._function_arg = g
elif isinstance(g, ufl.classes.Zero):
if g.ufl_shape and g.ufl_shape != V.ufl_element().value_shape(V.mesh()):
if g.ufl_shape and g.ufl_shape != V.value_shape:
raise ValueError(f"Provided boundary value {g} does not match shape of space")
# Special case. Scalar zero for direct Function.assign.
self._function_arg = g
elif isinstance(g, ufl.classes.Expr):
if g.ufl_shape != V.ufl_element().value_shape(V.mesh()):
if g.ufl_shape != V.value_shape:
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
try:
self._function_arg = firedrake.Function(V)
Expand Down
4 changes: 2 additions & 2 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
self._update_function_name_function_space_name_map(tmesh.name, mesh.name, {f.name(): V_name})
# Embed if necessary
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element, element.value_shape(mesh))
_element = get_embedding_element_for_checkpointing(element, V.value_shape)
if _element != element:
path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name())
self.require_group(path)
Expand Down Expand Up @@ -1337,7 +1337,7 @@ def load_function(self, mesh, name, idx=None):
_name = self.get_attr(path, PREFIX_EMBEDDED + "_function")
_f = self.load_function(mesh, _name, idx=idx)
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element, element.value_shape(mesh))
_element = get_embedding_element_for_checkpointing(element, V.value_shape)
method = get_embedding_method_for_checkpointing(element)
assert _element == _f.function_space().ufl_element()
f = Function(V, name=name)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/external_operators/point_expr_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
if not isinstance(operator_data["func"], types.FunctionType):
raise TypeError("Expecting a FunctionType pointwise expression")
expr_shape = operator_data["func"](*operands).ufl_shape
if expr_shape != function_space.ufl_element().value_shape(function_space.mesh()):
if expr_shape != function_space.value_shape:
raise ValueError("The dimension does not match with the dimension of the function space %s" % function_space)

@property
Expand Down
3 changes: 1 addition & 2 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ def argument(self, o):
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
else:
args += [Zero()
for j in numpy.ndindex(
V_is[i].ufl_element().value_shape(V_is[i].mesh()))]
for j in numpy.ndindex(V_is[i].value_shape)]
return self._arg_cache.setdefault(o, as_vector(args))


Expand Down
9 changes: 3 additions & 6 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,13 @@ def split(self):
def _components(self):
if len(self) == 1:
return tuple(type(self).create(self.topological.sub(i), self.mesh())
for i in range(self.value_size))
for i in range(numpy.prod(self.shape)))
else:
return self.subfunctions

@PETSc.Log.EventDecorator()
def sub(self, i):
if len(self) == 1:
bound = self.value_size
else:
bound = len(self)
bound = len(self._components)
if i < 0 or i >= bound:
raise IndexError("Invalid component %d, not in [0, %d)" % (i, bound))
return self._components[i]
Expand Down Expand Up @@ -654,7 +651,7 @@ def __getitem__(self, i):

@utils.cached_property
def _components(self):
return tuple(ComponentFunctionSpace(self, i) for i in range(self.value_size))
return tuple(ComponentFunctionSpace(self, i) for i in range(numpy.prod(self.shape)))

def sub(self, i):
r"""Return a view into the ith component."""
Expand Down
15 changes: 7 additions & 8 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,7 @@ def __init__(
# For a VectorElement or TensorElement the correct
# VectorFunctionSpace equivalent is built from the scalar
# sub-element.
ufl_scalar_element = ufl_scalar_element.sub_elements[0]
if ufl_scalar_element.value_shape(V_dest.mesh()) != ():
if V_dest.sub(0).value_shape != ():
raise NotImplementedError(
"Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
)
Expand Down Expand Up @@ -614,7 +613,7 @@ def __init__(
# I first point evaluate my expression at these locations, giving a
# P0DG function on the VOM. As described in the manual, this is an
# interpolation operation.
shape = V_dest.ufl_element().value_shape(V_dest.mesh())
shape = V_dest.value_shape
if len(shape) == 0:
fs_type = firedrake.FunctionSpace
elif len(shape) == 1:
Expand Down Expand Up @@ -988,7 +987,7 @@ def callable():
else:
# Make sure we have an expression of the right length i.e. a value for
# each component in the value shape of each function space
dims = [numpy.prod(fs.ufl_element().value_shape(fs.mesh()), dtype=int)
dims = [numpy.prod(fs.value_shape, dtype=int)
for fs in V]
loops = []
if numpy.prod(expr.ufl_shape, dtype=int) != sum(dims):
Expand Down Expand Up @@ -1024,13 +1023,13 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
if access is op2.READ:
raise ValueError("Can't have READ access for output function")

if len(expr.ufl_shape) != len(V.ufl_element().value_shape(V.mesh())):
if len(expr.ufl_shape) != len(V.value_shape):
raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
% (len(expr.ufl_shape), len(V.ufl_element().value_shape(V.mesh()))))
% (len(expr.ufl_shape), len(V.value_shape)))

if expr.ufl_shape != V.ufl_element().value_shape(V.mesh()):
if expr.ufl_shape != V.value_shape:
raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
% (expr.ufl_shape, V.ufl_element().value_shape(V.mesh())))
% (expr.ufl_shape, V.value_shape))

# NOTE: The par_loop is always over the target mesh cells.
target_mesh = as_domain(V)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mg/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def cache(self, key):

def get_cache_key(self, V):
elem = V.ufl_element()
value_shape = elem.value_shape(V.mesh())
value_shape = V.value_shape
return elem, value_shape

def V_dof_weights(self, V):
Expand Down
19 changes: 9 additions & 10 deletions firedrake/mg/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,10 @@ def compile_element(expression, dual_space=None, parameters=None,
return_variable = gem.Indexed(gem.Variable('R', finat_elem.index_shape), argument_multiindex)
result = gem.Indexed(result, tensor_indices)
if dual_space:
elem = create_element(dual_space.ufl_element())
if elem.value_shape:
var = gem.Indexed(gem.Variable("b", elem.value_shape),
tensor_indices)
b_arg = [lp.GlobalArg("b", dtype=ScalarType, shape=elem.value_shape)]
value_shape = dual_space.value_shape
if value_shape:
var = gem.Indexed(gem.Variable("b", value_shape), tensor_indices)
b_arg = [lp.GlobalArg("b", dtype=ScalarType, shape=value_shape)]
else:
var = gem.Indexed(gem.Variable("b", (1, )), (0, ))
b_arg = [lp.GlobalArg("b", dtype=ScalarType, shape=(1,))]
Expand Down Expand Up @@ -220,7 +219,7 @@ def prolong_kernel(expression):
assert hierarchy._meshes[int(idx)].cell_set._extruded
V = expression.function_space()
key = (("prolong",)
+ expression.ufl_element().value_shape(meshc)
+ V.value_shape
+ entity_dofs_key(V.finat_element.complex.get_topology())
+ entity_dofs_key(V.finat_element.entity_dofs())
+ entity_dofs_key(coordinates.function_space().finat_element.entity_dofs()))
Expand Down Expand Up @@ -284,7 +283,7 @@ def prolong_kernel(expression):
"evaluate": eval_code,
"spacedim": element.cell.get_spatial_dimension(),
"ncandidate": hierarchy.fine_to_coarse_cells[levelf].shape[1],
"Rdim": numpy.prod(element.value_shape),
"Rdim": numpy.prod(V.value_shape),
"inside_cell": inside_check(element.cell, eps=1e-8, X="Xref"),
"celldist_l1_c_expr": celldist_l1_c_expr(element.cell, X="Xref"),
"Xc_cell_inc": coords_element.space_dimension(),
Expand All @@ -302,7 +301,7 @@ def restrict_kernel(Vf, Vc):
if Vf.extruded:
assert Vc.extruded
key = (("restrict",)
+ Vf.ufl_element().value_shape(Vf.mesh())
+ Vf.value_shape
+ entity_dofs_key(Vf.finat_element.complex.get_topology())
+ entity_dofs_key(Vc.finat_element.complex.get_topology())
+ entity_dofs_key(Vf.finat_element.entity_dofs())
Expand Down Expand Up @@ -390,7 +389,7 @@ def inject_kernel(Vf, Vc):
else:
level_ratio = 1
key = (("inject", level_ratio)
+ Vf.ufl_element().value_shape(Vf.mesh())
+ Vf.value_shape
+ entity_dofs_key(Vc.finat_element.complex.get_topology())
+ entity_dofs_key(Vf.finat_element.complex.get_topology())
+ entity_dofs_key(Vc.finat_element.entity_dofs())
Expand Down Expand Up @@ -465,7 +464,7 @@ def inject_kernel(Vf, Vc):
"celldist_l1_c_expr": celldist_l1_c_expr(Vc.finat_element.cell, X="Xref"),
"tdim": Vc.mesh().topological_dimension(),
"ncandidate": ncandidate,
"Rdim": numpy.prod(Vf_element.value_shape),
"Rdim": numpy.prod(Vf.value_shape),
"Xf_cell_inc": coords_element.space_dimension(),
"f_cell_inc": Vf_element.space_dimension()
}
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def coarse_cell_to_fine_node_map(Vc, Vf):

def physical_node_locations(V):
element = V.ufl_element()
if element.value_shape(V.mesh()):
if V.value_shape:
assert isinstance(element, (finat.ufl.VectorElement, finat.ufl.TensorElement))
element = element.sub_elements[0]
mesh = V.mesh()
Expand Down
34 changes: 17 additions & 17 deletions firedrake/preconditioners/hiptmair.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ def coarsen(self, pc):
element = V.ufl_element()
formdegree = V.finat_element.formdegree
if formdegree == 1:
celement = curl_to_grad(element, mesh)
celement = curl_to_grad(element)
elif formdegree == 2:
celement = div_to_curl(element, mesh)
celement = div_to_curl(element)
else:
raise ValueError("Hiptmair decomposition not available for", element)

Expand Down Expand Up @@ -211,15 +211,15 @@ def coarsen(self, pc):
return coarse_operator, coarse_space_bcs, interp_petscmat


def curl_to_grad(ele, mesh):
def curl_to_grad(ele):
if isinstance(ele, finat.ufl.VectorElement):
return type(ele)(curl_to_grad(ele._sub_element, mesh), dim=ele.num_sub_elements)
return type(ele)(curl_to_grad(ele._sub_element), dim=ele.num_sub_elements)
elif isinstance(ele, finat.ufl.TensorElement):
return type(ele)(curl_to_grad(ele._sub_element, mesh), shape=ele.value_shape(mesh), symmetry=ele.symmetry())
return type(ele)(curl_to_grad(ele._sub_element), shape=ele._shape, symmetry=ele.symmetry())
elif isinstance(ele, finat.ufl.MixedElement):
return type(ele)(*(curl_to_grad(e, mesh) for e in ele.sub_elements))
return type(ele)(*(curl_to_grad(e) for e in ele.sub_elements))
elif isinstance(ele, finat.ufl.RestrictedElement):
return finat.ufl.RestrictedElement(curl_to_grad(ele._element, mesh), ele.restriction_domain())
return finat.ufl.RestrictedElement(curl_to_grad(ele._element), ele.restriction_domain())
else:
cell = ele.cell
family = ele.family()
Expand All @@ -238,25 +238,25 @@ def curl_to_grad(ele, mesh):
return finat.ufl.FiniteElement(family, cell=cell, degree=degree, variant=variant)


def div_to_curl(ele, mesh):
def div_to_curl(ele):
if isinstance(ele, finat.ufl.VectorElement):
return type(ele)(div_to_curl(ele._sub_element, mesh), dim=ele.num_sub_elements)
return type(ele)(div_to_curl(ele._sub_element), dim=ele.num_sub_elements)
elif isinstance(ele, finat.ufl.TensorElement):
return type(ele)(div_to_curl(ele._sub_element, mesh), shape=ele.value_shape(mesh), symmetry=ele.symmetry())
return type(ele)(div_to_curl(ele._sub_element), shape=ele._shape, symmetry=ele.symmetry())
elif isinstance(ele, finat.ufl.MixedElement):
return type(ele)(*(div_to_curl(e, mesh) for e in ele.sub_elements))
return type(ele)(*(div_to_curl(e) for e in ele.sub_elements))
elif isinstance(ele, finat.ufl.RestrictedElement):
return finat.ufl.RestrictedElement(div_to_curl(ele._element, mesh), ele.restriction_domain())
return finat.ufl.RestrictedElement(div_to_curl(ele._element), ele.restriction_domain())
elif isinstance(ele, finat.ufl.EnrichedElement):
return type(ele)(*(div_to_curl(e, mesh) for e in reversed(ele._elements)))
return type(ele)(*(div_to_curl(e) for e in reversed(ele._elements)))
elif isinstance(ele, finat.ufl.TensorProductElement):
return type(ele)(*(div_to_curl(e, mesh) for e in ele.sub_elements), cell=ele.cell)
return type(ele)(*(div_to_curl(e) for e in ele.sub_elements), cell=ele.cell)
elif isinstance(ele, finat.ufl.WithMapping):
return type(ele)(div_to_curl(ele.wrapee, mesh), ele.mapping())
return type(ele)(div_to_curl(ele.wrapee), ele.mapping())
elif isinstance(ele, finat.ufl.BrokenElement):
return type(ele)(div_to_curl(ele._element, mesh))
return type(ele)(div_to_curl(ele._element))
elif isinstance(ele, finat.ufl.HDivElement):
return finat.ufl.HCurlElement(div_to_curl(ele._element, mesh))
return finat.ufl.HCurlElement(div_to_curl(ele._element))
elif isinstance(ele, finat.ufl.HCurlElement):
raise ValueError("Expecting an H(div) element")
else:
Expand Down
10 changes: 4 additions & 6 deletions firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ def initialize(self, obj):
elements = [ele]
while True:
try:
ele_ = self.coarsen_element(ele)
assert ele_.value_shape(V.mesh()) == ele.value_shape(V.mesh())
ele = ele_
ele = self.coarsen_element(ele)
except ValueError:
break
elements.append(ele)
Expand Down Expand Up @@ -1098,7 +1096,7 @@ def make_mapping_code(Q, cmapping, fmapping, t_in, t_out):
if B:
tensor = ufl.dot(B, tensor) if tensor else B
if tensor is None:
tensor = ufl.Identity(Q.ufl_element().value_shape(Q.mesh())[0])
tensor = ufl.Identity(Q.value_shape[0])

u = ufl.Coefficient(Q)
expr = ufl.dot(tensor, u)
Expand Down Expand Up @@ -1347,8 +1345,8 @@ def make_blas_kernels(self, Vf, Vc):
in_place_mapping = True
except Exception:
qelem = finat.ufl.FiniteElement("DQ", cell=felem.cell, degree=PMGBase.max_degree(felem))
if felem.value_shape(Vf.mesh()):
qelem = finat.ufl.TensorElement(qelem, shape=felem.value_shape(Vf.mesh()), symmetry=felem.symmetry())
if Vf.value_shape:
qelem = finat.ufl.TensorElement(qelem, shape=Vf.value_shape, symmetry=felem.symmetry())
Qf = firedrake.FunctionSpace(Vf.mesh(), qelem)
mapping_output = make_mapping_code(Qf, cmapping, fmapping, "t0", "t1")

Expand Down
4 changes: 2 additions & 2 deletions firedrake/pyplot/pgf.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@ def pgfplot(f, filename, degree=1, complex_component='real', print_latex_example
raise NotImplementedError(f"Not yet implemented for functions in spatial dimension {dim}")
if mesh.extruded:
raise NotImplementedError("Not yet implemented for functions on extruded meshes")
if elem.value_shape(mesh):
if V.value_shape:
raise NotImplementedError("Currently only implemeted for scalar functions")
coordelem = get_embedding_dg_element(mesh.coordinates.function_space().ufl_element(), (dim, )).reconstruct(degree=degree, variant="equispaced")
coordV = FunctionSpace(mesh, coordelem)
coords = Function(coordV).interpolate(SpatialCoordinate(mesh))
elemdg = get_embedding_dg_element(elem, elem.value_shape(mesh)).reconstruct(degree=degree, variant="equispaced")
elemdg = get_embedding_dg_element(elem, V.value_shape).reconstruct(degree=degree, variant="equispaced")
Vdg = FunctionSpace(mesh, elemdg)
fdg = Function(Vdg)
method = get_embedding_method_for_checkpointing(elem)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/slate/static_condensation/hybridization.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def initialize(self, pc):
if len(V) != 2:
raise ValueError("Expecting two function spaces.")

if all(Vi.ufl_element().value_shape(Vi.mesh()) for Vi in V):
if all(Vi.value_shape for Vi in V):
raise ValueError("Expecting an H(div) x L2 pair of spaces.")

# Automagically determine which spaces are vector and scalar
Expand Down
6 changes: 2 additions & 4 deletions firedrake/ufl_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def reconstruct(self, function_space=None,
return self
if not isinstance(number, int):
raise TypeError(f"Expecting an int, not {number}")
mesh = self.function_space().mesh()
if function_space.ufl_element().value_shape(mesh) != self.ufl_element().value_shape(mesh):
if function_space.value_shape != self.function_space().value_shape:
raise ValueError("Cannot reconstruct an Argument with a different value shape.")
return Argument(function_space, number, part=part)

Expand Down Expand Up @@ -141,8 +140,7 @@ def reconstruct(self, function_space=None,
return self
if not isinstance(number, int):
raise TypeError(f"Expecting an int, not {number}")
mesh = self.function_space().mesh()
if function_space.ufl_element().value_shape(mesh) != self.ufl_element().value_shape(mesh):
if function_space.value_shape != self.function_space().value_shape:
raise ValueError("Cannot reconstruct an Coargument with a different value shape.")
return Coargument(function_space, number, part=part)

Expand Down
2 changes: 1 addition & 1 deletion tests/output/test_io_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _get_mesh(cell_type, comm):
def _get_expr(V):
mesh = V.mesh()
dim = mesh.geometric_dimension()
shape = V.ufl_element().value_shape(mesh)
shape = V.value_shape
if dim == 2:
x, y = SpatialCoordinate(mesh)
z = x * y
Expand Down
2 changes: 1 addition & 1 deletion tests/output/test_io_timestepping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _get_expr(V, i):
mesh = V.mesh()
element = V.ufl_element()
x, y = SpatialCoordinate(mesh)
shape = element.value_shape(mesh)
shape = V.value_shape
if element.family() == "Real":
return 7. + i * i
elif shape == ():
Expand Down
6 changes: 3 additions & 3 deletions tests/regression/test_ensembleparallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,13 @@ def test_ensemble_reduce(ensemble, mesh, W, urank, urank_sum, root, blocking):
parallel_assert(
lambda: error < 1e-12,
subset=root_ranks,
msg=f"{error = :.5f}"
msg=f"{error=:.5f}"
)
error = errornorm(Function(W).assign(10), u_reduce)
parallel_assert(
lambda: error < 1e-12,
subset={range(COMM_WORLD.size)} - root_ranks,
msg=f"{error = :.5f}"
msg=f"{error=:.5f}"
)

# check that u_reduce dat vector is still synchronised
Expand Down Expand Up @@ -347,7 +347,7 @@ def test_send_and_recv(ensemble, mesh, W, blocking):
parallel_assert(
lambda: error < 1e-12,
subset=root_ranks,
msg=f"{error = :.5f}"
msg=f"{error=:.5f}"
)


Expand Down
Loading

0 comments on commit adb93f4

Please sign in to comment.