diff --git a/festim/__init__.py b/festim/__init__.py index 853092014..be0baf70c 100644 --- a/festim/__init__.py +++ b/festim/__init__.py @@ -28,11 +28,15 @@ from .hydrogen_transport_problem import HydrogenTransportProblem +from .settings import Settings + from .species import Species, Trap, ImplicitSpecies from .subdomain.surface_subdomain import SurfaceSubdomain1D from .subdomain.volume_subdomain import VolumeSubdomain1D +from .stepsize import Stepsize + from .exports.vtx import VTXExport from .exports.xdmf import XDMFExport diff --git a/festim/hydrogen_transport_problem.py b/festim/hydrogen_transport_problem.py index 581b5ed35..63b3cf37c 100644 --- a/festim/hydrogen_transport_problem.py +++ b/festim/hydrogen_transport_problem.py @@ -78,7 +78,7 @@ def __init__( temperature=None, sources=[], boundary_conditions=[], - solver_parameters=None, + settings=None, exports=[], ) -> None: self.mesh = mesh @@ -87,7 +87,7 @@ def __init__( self.temperature = temperature self.sources = sources self.boundary_conditions = boundary_conditions - self.solver_parameters = solver_parameters + self.settings = settings self.exports = exports self.dx = None @@ -116,6 +116,7 @@ def initialise(self): self.assign_functions_to_species() self.t = fem.Constant(self.mesh.mesh, 0.0) + self.dt = self.settings.stepsize.get_dt(self.mesh.mesh) self.define_boundary_conditions() self.create_formulation() @@ -224,11 +225,6 @@ def create_formulation(self): if len(self.species) > 1: raise NotImplementedError("Multiple species not implemented yet") - # TODO expose dt as parameter of the model - dt = fem.Constant(self.mesh.mesh, 1 / 20) - - self.dt = dt # TODO remove this - self.formulation = 0 for spe in self.species: @@ -242,7 +238,7 @@ def create_formulation(self): ) self.formulation += dot(D * grad(u), grad(v)) * self.dx(vol.id) - self.formulation += ((u - u_n) / dt) * v * self.dx(vol.id) + self.formulation += ((u - u_n) / self.dt) * v * self.dx(vol.id) # add sources # TODO implement this @@ -264,15 +260,14 @@ def create_solver(self): self.species[0].solution, bcs=self.bc_forms, ) - solver = NewtonSolver(MPI.COMM_WORLD, problem) - self.solver = solver + self.solver = NewtonSolver(MPI.COMM_WORLD, problem) + self.solver.atol = self.settings.atol + self.solver.rtol = self.settings.rtol + self.solver.max_it = self.settings.max_iterations - def run(self, final_time: float): + def run(self): """Runs the model for a given time - Args: - final_time (float): the final time of the simulation - Returns: list of float: the times of the simulation list of float: the fluxes of the simulation @@ -285,9 +280,11 @@ def run(self, final_time: float): ) cm = self.species[0].solution progress = tqdm.autonotebook.tqdm( - desc="Solving H transport problem", total=final_time + desc="Solving H transport problem", + total=self.settings.final_time, + unit_scale=True, ) - while self.t.value < final_time: + while self.t.value < self.settings.final_time: progress.update(self.dt.value) self.t.value += self.dt.value diff --git a/festim/settings.py b/festim/settings.py new file mode 100644 index 000000000..fe8e1aa4e --- /dev/null +++ b/festim/settings.py @@ -0,0 +1,53 @@ +import festim as F + + +class Settings: + """Settings for a festim simulation. + + Args: + atol (float): Absolute tolerance for the solver. + rtol (float): Relative tolerance for the solver. + max_iterations (int, optional): Maximum number of iterations for the + solver. Defaults to 30. + final_time (float, optional): Final time for a transient simulation. + Defaults to None + stepsize (festim.Stepsize, optional): stepsize for a transient + simulation. Defaults to None + + Attributes: + atol (float): Absolute tolerance for the solver. + rtol (float): Relative tolerance for the solver. + max_iterations (int): Maximum number of iterations for the solver. + final_time (float): Final time for a transient simulation. + stepsize (festim.Stepsize): stepsize for a transient + simulation. + """ + + def __init__( + self, + atol, + rtol, + max_iterations=30, + final_time=None, + stepsize=None, + ) -> None: + self.atol = atol + self.rtol = rtol + self.max_iterations = max_iterations + self.final_time = final_time + self.stepsize = stepsize + + @property + def stepsize(self): + return self._stepsize + + @stepsize.setter + def stepsize(self, value): + if value is None: + self._stepsize = None + elif isinstance(value, (float, int)): + self._stepsize = F.Stepsize(initial_value=value) + elif isinstance(value, F.Stepsize): + self._stepsize = value + else: + raise TypeError("stepsize must be an of type int, float or festim.Stepsize") diff --git a/festim/stepsize.py b/festim/stepsize.py new file mode 100644 index 000000000..2dae7071c --- /dev/null +++ b/festim/stepsize.py @@ -0,0 +1,28 @@ +import festim as F + + +class Stepsize: + """ + A class for evaluating the stepsize of transient simulations. + + Args: + initial_value (float, int): initial stepsize. + + Attributes: + initial_value (float, int): initial stepsize. + """ + + def __init__( + self, + initial_value, + ) -> None: + self.initial_value = initial_value + + def get_dt(self, mesh): + """Defines the dt value + Args: + mesh (dolfinx.mesh.Mesh): the domain mesh + Returns: + fem.Constant: the dt value + """ + return F.as_fenics_constant(self.initial_value, mesh) diff --git a/test/benchmark.py b/test/benchmark.py index 4aade08f4..e86360126 100644 --- a/test/benchmark.py +++ b/test/benchmark.py @@ -136,7 +136,7 @@ def siverts_law(T, S_0, E_S, pressure): times = [] t = 0 progress = tqdm.autonotebook.tqdm( - desc="Solving H transport problem", total=final_time + desc="Solving H transport problem", total=final_time, unit_scale=True ) while t < final_time: progress.update(float(dt)) diff --git a/test/test_dirichlet_bc.py b/test/test_dirichlet_bc.py index cb299ac39..48e7ac36f 100644 --- a/test/test_dirichlet_bc.py +++ b/test/test_dirichlet_bc.py @@ -281,13 +281,16 @@ def test_integration_with_HTransportProblem(value): my_model.temperature = fem.Constant(my_model.mesh.mesh, 550.0) + my_model.settings = F.Settings(atol=1, rtol=0.1, final_time=2) + my_model.settings.stepsize = F.Stepsize(initial_value=1) + # RUN my_model.initialise() assert my_bc.value_fenics is not None - my_model.run(final_time=2) + my_model.run() # TEST diff --git a/test/test_permeation_problem.py b/test/test_permeation_problem.py index ea2ea041a..2911cf753 100644 --- a/test/test_permeation_problem.py +++ b/test/test_permeation_problem.py @@ -35,13 +35,18 @@ def test_permeation_problem(mesh_size=1001): ] my_model.exports = [F.XDMFExport("mobile_concentration.xdmf", field=mobile_H)] + my_model.settings = F.Settings( + atol=1e10, + rtol=1e-10, + max_iterations=30, + final_time=50, + ) + + my_model.settings.stepsize = F.Stepsize(initial_value=1 / 20) + my_model.initialise() my_model.solver.convergence_criterion = "incremental" - my_model.solver.rtol = 1e-10 - my_model.solver.atol = 1e10 - - my_model.solver.report = True ksp = my_model.solver.krylov_solver opts = PETSc.Options() option_prefix = ksp.getOptionsPrefix() @@ -50,9 +55,7 @@ def test_permeation_problem(mesh_size=1001): opts[f"{option_prefix}pc_factor_mat_solver_type"] = "mumps" ksp.setFromOptions() - final_time = 50 - - times, flux_values = my_model.run(final_time=final_time) + times, flux_values = my_model.run() # -------------------------- analytical solution ------------------------------------- @@ -127,13 +130,18 @@ def test_permeation_problem_multi_volume(): ] my_model.exports = [F.VTXExport("test.bp", field=mobile_H)] + my_model.settings = F.Settings( + atol=1e10, + rtol=1e-10, + max_iterations=30, + final_time=50, + ) + + my_model.settings.stepsize = F.Stepsize(initial_value=1 / 20) + my_model.initialise() my_model.solver.convergence_criterion = "incremental" - my_model.solver.rtol = 1e-10 - my_model.solver.atol = 1e10 - - my_model.solver.report = True ksp = my_model.solver.krylov_solver opts = PETSc.Options() option_prefix = ksp.getOptionsPrefix() @@ -142,9 +150,7 @@ def test_permeation_problem_multi_volume(): opts[f"{option_prefix}pc_factor_mat_solver_type"] = "mumps" ksp.setFromOptions() - final_time = 50 - - times, flux_values = my_model.run(final_time=final_time) + times, flux_values = my_model.run() # -------------------------- analytical solution ------------------------------------- D = my_mat.get_diffusion_coefficient(my_mesh.mesh, temperature) diff --git a/test/test_settings.py b/test/test_settings.py new file mode 100644 index 000000000..5ffeeaa60 --- /dev/null +++ b/test/test_settings.py @@ -0,0 +1,22 @@ +import festim as F +import numpy as np +import pytest + + +@pytest.mark.parametrize("test_type", [int, F.Stepsize, float]) +def test_stepsize_value(test_type): + """Test that the stepsize is correctly set""" + test_value = 23.0 + my_settings = F.Settings(atol=1, rtol=0.1) + my_settings.stepsize = test_type(test_value) + + assert isinstance(my_settings.stepsize, F.Stepsize) + assert np.isclose(my_settings.stepsize.initial_value, test_value) + + +def test_stepsize_value_wrong_type(): + """Checks that an error is raised when the wrong type is given""" + my_settings = F.Settings(atol=1, rtol=0.1) + + with pytest.raises(TypeError): + my_settings.stepsize = "coucou" diff --git a/test/test_sievertsbc.py b/test/test_sievertsbc.py index b1fd252df..2e73780d5 100644 --- a/test/test_sievertsbc.py +++ b/test/test_sievertsbc.py @@ -93,10 +93,13 @@ def test_integration_with_HTransportProblem(pressure): my_model.temperature = fem.Constant(my_model.mesh.mesh, 550.0) + my_model.settings = F.Settings(atol=1, rtol=0.1, final_time=2) + my_model.settings.stepsize = F.Stepsize(initial_value=1) + # RUN my_model.initialise() assert my_bc.value_fenics is not None - my_model.run(final_time=2) + my_model.run() diff --git a/test/test_vtx.py b/test/test_vtx.py index 71634720f..15c7ded6e 100644 --- a/test/test_vtx.py +++ b/test/test_vtx.py @@ -60,6 +60,8 @@ def test_vtx_integration_with_h_transport_problem(tmpdir): filename = str(tmpdir.join("my_export.bp")) my_export = F.VTXExport(filename, field=my_model.species[0]) my_model.exports = [my_export] + my_model.settings = F.Settings(atol=1, rtol=0.1) + my_model.settings.stepsize = F.Stepsize(initial_value=1) my_model.initialise() diff --git a/test/test_xdmf.py b/test/test_xdmf.py index 154967d87..cc1e57a25 100644 --- a/test/test_xdmf.py +++ b/test/test_xdmf.py @@ -47,8 +47,10 @@ def test_integration_with_HTransportProblem(tmp_path): filename = os.path.join(tmp_path, "test.xdmf") my_model.exports = [F.XDMFExport(filename=filename, field=my_model.species)] + my_model.settings = F.Settings(atol=1, rtol=0.1, final_time=1) + my_model.settings.stepsize = F.Stepsize(initial_value=0.5) my_model.initialise() - my_model.run(1) + my_model.run() # checks that filename exists assert os.path.exists(filename)