Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get value_shape from FunctionSpace #3862

Merged
merged 15 commits into from
Nov 15, 2024
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch tsfc mscroggs/gdim \
--package-branch finat mscroggs/gdim \
--package-branch ufl ksagiyam/merge_upstream \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
6 changes: 3 additions & 3 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,7 +1628,7 @@ def _integral_type(self):

@cached_property
def _mesh(self):
return self._form.ufl_domains()[self._kinfo.domain_number]
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
return tuple(self._form.ufl_domains())[self._kinfo.domain_number]

@cached_property
def _needs_subset(self):
Expand Down Expand Up @@ -1749,7 +1749,7 @@ def _as_global_kernel_arg_coefficient(_, self):

ufl_element = V.ufl_element()
if ufl_element.family() == "Real":
return op2.GlobalKernelArg((ufl_element.value_size,))
return op2.GlobalKernelArg((V.value_size,))
else:
return self._make_dat_global_kernel_arg(V, index=index)

Expand Down Expand Up @@ -1970,7 +1970,7 @@ def _indexed_function_spaces(self):

@cached_property
def _mesh(self):
return self._form.ufl_domains()[self._kinfo.domain_number]
return tuple(self._form.ufl_domains())[self._kinfo.domain_number]

@cached_property
def _iterset(self):
Expand Down
11 changes: 6 additions & 5 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,28 +342,29 @@ def function_arg(self, g):
del self._function_arg_update
except AttributeError:
pass
V = self.function_space()
if isinstance(g, firedrake.Function) and g.ufl_element().family() != "Real":
if g.function_space() != self.function_space():
if g.function_space() != V:
raise RuntimeError("%r is defined on incompatible FunctionSpace!" % g)
self._function_arg = g
elif isinstance(g, ufl.classes.Zero):
if g.ufl_shape and g.ufl_shape != self.function_space().ufl_element().value_shape:
if g.ufl_shape and g.ufl_shape != V.value_shape:
raise ValueError(f"Provided boundary value {g} does not match shape of space")
# Special case. Scalar zero for direct Function.assign.
self._function_arg = g
elif isinstance(g, ufl.classes.Expr):
if g.ufl_shape != self.function_space().ufl_element().value_shape:
if g.ufl_shape != V.value_shape:
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
try:
self._function_arg = firedrake.Function(self.function_space())
self._function_arg = firedrake.Function(V)
# Use `Interpolator` instead of assembling an `Interpolate` form
# as the expression compilation needs to happen at this stage to
# determine if we should use interpolation or projection
# -> e.g. interpolation may not be supported for the element.
self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate
except (NotImplementedError, AttributeError):
# Element doesn't implement interpolation
self._function_arg = firedrake.Function(self.function_space()).project(g)
self._function_arg = firedrake.Function(V).project(g)
self._function_arg_update = firedrake.Projector(g, self._function_arg).project
else:
try:
Expand Down
7 changes: 3 additions & 4 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
self._update_function_name_function_space_name_map(tmesh.name, mesh.name, {f.name(): V_name})
# Embed if necessary
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element)
_element = get_embedding_element_for_checkpointing(element, V.value_shape)
if _element != element:
path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name())
self.require_group(path)
Expand Down Expand Up @@ -1337,7 +1337,7 @@ def load_function(self, mesh, name, idx=None):
_name = self.get_attr(path, PREFIX_EMBEDDED + "_function")
_f = self.load_function(mesh, _name, idx=idx)
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element)
_element = get_embedding_element_for_checkpointing(element, V.value_shape)
method = get_embedding_method_for_checkpointing(element)
assert _element == _f.function_space().ufl_element()
f = Function(V, name=name)
Expand Down Expand Up @@ -1436,8 +1436,7 @@ def _get_shared_data_key_for_checkpointing(self, mesh, ufl_element):
shape = ufl_element.reference_value_shape
block_size = np.prod(shape)
elif isinstance(ufl_element, finat.ufl.VectorElement):
shape = ufl_element.value_shape[:1]
block_size = np.prod(shape)
block_size = ufl_element.reference_value_shape[0]
else:
block_size = 1
return (nodes_per_entity, real_tensorproduct, block_size)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __new__(cls, value, domain=None, name=None, count=None):

if not isinstance(domain, ufl.AbstractDomain):
cell = ufl.as_cell(domain)
coordinate_element = finat.ufl.VectorElement("Lagrange", cell, 1, gdim=cell.geometric_dimension)
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
coordinate_element = finat.ufl.VectorElement("Lagrange", cell, 1, dim=cell.topological_dimension())
domain = ufl.Mesh(coordinate_element)

cell = domain.ufl_cell()
Expand Down
8 changes: 4 additions & 4 deletions firedrake/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ufl


def get_embedding_dg_element(element, broken_cg=False):
def get_embedding_dg_element(element, value_shape, broken_cg=False):
cell = element.cell
family = lambda c: "DG" if c.is_simplex() else "DQ"
if isinstance(cell, ufl.TensorProductCell):
Expand All @@ -19,7 +19,7 @@ def get_embedding_dg_element(element, broken_cg=False):
scalar_element = finat.ufl.FiniteElement(family(cell), cell=cell, degree=degree)
if broken_cg:
scalar_element = finat.ufl.BrokenElement(scalar_element.reconstruct(family="Lagrange"))
shape = element.value_shape
shape = value_shape
if len(shape) == 0:
DG = scalar_element
elif len(shape) == 1:
Expand All @@ -37,12 +37,12 @@ def get_embedding_dg_element(element, broken_cg=False):
native_elements_for_checkpointing = {"Lagrange", "Discontinuous Lagrange", "Q", "DQ", "Real"}


def get_embedding_element_for_checkpointing(element):
def get_embedding_element_for_checkpointing(element, value_shape):
"""Convert the given UFL element to an element that :class:`~.CheckpointFile` can handle."""
if element.family() in native_elements_for_checkpointing:
return element
else:
return get_embedding_dg_element(element)
return get_embedding_dg_element(element, value_shape)


def get_embedding_method_for_checkpointing(element):
Expand Down
2 changes: 1 addition & 1 deletion firedrake/external_operators/point_expr_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
if not isinstance(operator_data["func"], types.FunctionType):
raise TypeError("Expecting a FunctionType pointwise expression")
expr_shape = operator_data["func"](*operands).ufl_shape
if expr_shape != function_space.ufl_element().value_shape:
if expr_shape != function_space.value_shape:
raise ValueError("The dimension does not match with the dimension of the function space %s" % function_space)

@property
Expand Down
3 changes: 1 addition & 2 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ def argument(self, o):
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
else:
args += [Zero()
for j in numpy.ndindex(
V_is[i].ufl_element().value_shape)]
for j in numpy.ndindex(V_is[i].value_shape)]
return self._arg_cache.setdefault(o, as_vector(args))


Expand Down
2 changes: 1 addition & 1 deletion firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def at(self, arg, *args, **kwargs):
raise NotImplementedError("Point evaluation not implemented for variable layers")

# Validate geometric dimension
gdim = mesh.ufl_cell().geometric_dimension()
gdim = mesh.geometric_dimension()
if arg.shape[-1] == gdim:
pass
elif len(arg.shape) == 1 and gdim == 1:
Expand Down
4 changes: 2 additions & 2 deletions firedrake/functionspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def VectorFunctionSpace(mesh, family, degree=None, dim=None,
"""
sub_element = make_scalar_element(mesh, family, degree, vfamily, vdegree, variant)
if dim is None:
dim = mesh.ufl_cell().geometric_dimension()
dim = mesh.geometric_dimension()
if not isinstance(dim, numbers.Integral) and dim > 0:
raise ValueError(f"Can't make VectorFunctionSpace with dim={dim}")
element = finat.ufl.VectorElement(sub_element, dim=dim)
Expand Down Expand Up @@ -237,7 +237,7 @@ def TensorFunctionSpace(mesh, family, degree=None, shape=None,

"""
sub_element = make_scalar_element(mesh, family, degree, vfamily, vdegree, variant)
shape = shape or (mesh.ufl_cell().geometric_dimension(),) * 2
shape = shape or (mesh.geometric_dimension(),) * 2
element = finat.ufl.TensorElement(sub_element, shape=shape, symmetry=symmetry)
return FunctionSpace(mesh, element, name=name)

Expand Down
32 changes: 16 additions & 16 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,13 @@ def split(self):
def _components(self):
if len(self) == 1:
return tuple(type(self).create(self.topological.sub(i), self.mesh())
for i in range(self.value_size))
for i in range(self.block_size))
else:
return self.subfunctions

@PETSc.Log.EventDecorator()
def sub(self, i):
if len(self) == 1:
bound = self.value_size
else:
bound = len(self)
bound = len(self._components)
if i < 0 or i >= bound:
raise IndexError("Invalid component %d, not in [0, %d)" % (i, bound))
return self._components[i]
Expand Down Expand Up @@ -489,14 +486,17 @@ def __init__(self, mesh, element, name=None):
shape_element = element
if isinstance(element, finat.ufl.WithMapping):
shape_element = element.wrapee
sub = shape_element.sub_elements[0].value_shape
sub = shape_element.sub_elements[0].reference_value_shape
self.shape = rvs[:len(rvs) - len(sub)]
else:
self.shape = ()
self._label = ""
self._ufl_function_space = ufl.FunctionSpace(mesh.ufl_mesh(), element, label=self._label)
self._mesh = mesh

self.value_size = self._ufl_function_space.value_size
r"""The number of scalar components of this :class:`FunctionSpace`."""

self.rank = len(self.shape)
r"""The rank of this :class:`FunctionSpace`. Spaces where the
element is scalar-valued (or intrinsically vector-valued) have
Expand All @@ -505,7 +505,7 @@ def __init__(self, mesh, element, name=None):
the number of components of their
:attr:`finat.ufl.finiteelementbase.FiniteElementBase.value_shape`."""

self.value_size = int(numpy.prod(self.shape, dtype=int))
self.block_size = int(numpy.prod(self.shape, dtype=int))
r"""The total number of degrees of freedom at each function
space node."""
self.name = name
Expand Down Expand Up @@ -654,7 +654,7 @@ def __getitem__(self, i):

@utils.cached_property
def _components(self):
return tuple(ComponentFunctionSpace(self, i) for i in range(self.value_size))
return tuple(ComponentFunctionSpace(self, i) for i in range(self.block_size))

def sub(self, i):
r"""Return a view into the ith component."""
Expand Down Expand Up @@ -684,7 +684,7 @@ def node_count(self):
def dof_count(self):
r"""The number of degrees of freedom (includes halo dofs) of this
function space on this process. Cf. :attr:`FunctionSpace.node_count` ."""
return self.node_count*self.value_size
return self.node_count*self.block_size

def dim(self):
r"""The global number of degrees of freedom for this function space.
Expand Down Expand Up @@ -821,7 +821,7 @@ def local_to_global_map(self, bcs, lgmap=None):
else:
indices = lgmap.block_indices.copy()
bsize = lgmap.getBlockSize()
assert bsize == self.value_size
assert bsize == self.block_size
else:
# MatBlock case, LGMap is already unrolled.
indices = lgmap.block_indices.copy()
Expand All @@ -830,11 +830,11 @@ def local_to_global_map(self, bcs, lgmap=None):
nodes = []
for bc in bcs:
if bc.function_space().component is not None:
nodes.append(bc.nodes * self.value_size
nodes.append(bc.nodes * self.block_size
+ bc.function_space().component)
elif unblocked:
tmp = bc.nodes * self.value_size
for i in range(self.value_size):
tmp = bc.nodes * self.block_size
for i in range(self.block_size):
nodes.append(tmp + i)
else:
nodes.append(bc.nodes)
Expand Down Expand Up @@ -1300,9 +1300,9 @@ def ComponentFunctionSpace(parent, component):
"""
element = parent.ufl_element()
assert type(element) in frozenset([finat.ufl.VectorElement, finat.ufl.TensorElement])
if not (0 <= component < parent.value_size):
if not (0 <= component < parent.block_size):
raise IndexError("Invalid component %d. not in [0, %d)" %
(component, parent.value_size))
(component, parent.block_size))
new = ProxyFunctionSpace(parent.mesh(), element.sub_elements[0], name=parent.name)
new.identifier = "component"
new.component = component
Expand Down Expand Up @@ -1346,7 +1346,7 @@ def make_dof_dset(self):
def make_dat(self, val=None, valuetype=None, name=None):
r"""Return a newly allocated :class:`pyop2.types.glob.Global` representing the
data for a :class:`.Function` on this space."""
return op2.Global(self.value_size, val, valuetype, name, self._comm)
return op2.Global(self.block_size, val, valuetype, name, self._comm)

def entity_node_map(self, source_mesh, source_integral_type, source_subdomain_id, source_all_integer_subdomain_ids):
return None
Expand Down
14 changes: 7 additions & 7 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def __init__(
# VectorFunctionSpace equivalent is built from the scalar
# sub-element.
ufl_scalar_element = ufl_scalar_element.sub_elements[0]
if ufl_scalar_element.value_shape != ():
if ufl_scalar_element.reference_value_shape != ():
raise NotImplementedError(
"Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
)
Expand Down Expand Up @@ -614,7 +614,7 @@ def __init__(
# I first point evaluate my expression at these locations, giving a
# P0DG function on the VOM. As described in the manual, this is an
# interpolation operation.
shape = V_dest.ufl_element().value_shape
shape = V_dest.ufl_function_space().value_shape
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
if len(shape) == 0:
fs_type = firedrake.FunctionSpace
elif len(shape) == 1:
Expand Down Expand Up @@ -988,7 +988,7 @@ def callable():
else:
# Make sure we have an expression of the right length i.e. a value for
# each component in the value shape of each function space
dims = [numpy.prod(fs.ufl_element().value_shape, dtype=int)
dims = [numpy.prod(fs.value_shape, dtype=int)
for fs in V]
loops = []
if numpy.prod(expr.ufl_shape, dtype=int) != sum(dims):
Expand Down Expand Up @@ -1024,13 +1024,13 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
if access is op2.READ:
raise ValueError("Can't have READ access for output function")

if len(expr.ufl_shape) != len(V.ufl_element().value_shape):
if len(expr.ufl_shape) != len(V.value_shape):
raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
% (len(expr.ufl_shape), len(V.ufl_element().value_shape)))
% (len(expr.ufl_shape), len(V.value_shape)))

if expr.ufl_shape != V.ufl_element().value_shape:
if expr.ufl_shape != V.value_shape:
raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
% (expr.ufl_shape, V.ufl_element().value_shape))
% (expr.ufl_shape, V.value_shape))

# NOTE: The par_loop is always over the target mesh cells.
target_mesh = as_domain(V)
Expand Down
Loading
Loading