From e0429bb44552180382125819ab72037ca674c6db Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Fri, 10 May 2024 14:07:01 +0200 Subject: [PATCH] PINN variants addition and Solvers Update (#263) * gpinn/basepinn new classes, pinn restructure * codacy fix gpinn/basepinn/pinn * inverse problem fix * Causal PINN (#267) * fix GPU training in inverse problem (#283) * Create a `compute_residual` attribute for `PINNInterface` * Modify dataloading in solvers (#286) * Modify PINNInterface by removing _loss_phys, _loss_data * Adding in PINNInterface a variable to track the current condition during training * Modify GPINN,PINN,CausalPINN to match changes in PINNInterface * Competitive Pinn Addition (#288) * fixing after rebase/ fix loss * fixing final issues --------- Co-authored-by: Dario Coscia * Modify min max formulation to max min for paper consistency * Adding SAPINN solver (#291) * rom solver * fix import --------- Co-authored-by: Dario Coscia Co-authored-by: Anna Ivagnes <75523024+annaivagnes@users.noreply.github.com> Co-authored-by: valc89 <103250118+valc89@users.noreply.github.com> Co-authored-by: Monthly Tag bot Co-authored-by: Nicola Demo --- docs/source/_rst/_code.rst | 6 + docs/source/_rst/solvers/basepinn.rst | 7 + docs/source/_rst/solvers/causalpinn.rst | 7 + docs/source/_rst/solvers/competitivepinn.rst | 7 + docs/source/_rst/solvers/gpinn.rst | 7 + docs/source/_rst/solvers/pinn.rst | 2 +- docs/source/_rst/solvers/rom.rst | 7 + docs/source/_rst/solvers/sapinn.rst | 7 + pina/model/avno.py | 4 +- pina/solvers/__init__.py | 21 +- pina/solvers/garom.py | 9 +- pina/solvers/pinn.py | 232 --------- pina/solvers/pinns/__init__.py | 15 + pina/solvers/pinns/basepinn.py | 247 ++++++++++ pina/solvers/pinns/causalpinn.py | 221 +++++++++ pina/solvers/pinns/competitive_pinn.py | 360 ++++++++++++++ pina/solvers/pinns/gpinn.py | 134 +++++ pina/solvers/pinns/pinn.py | 170 +++++++ pina/solvers/pinns/sapinn.py | 494 +++++++++++++++++++ pina/solvers/rom.py | 190 +++++++ pina/solvers/solver.py | 15 + pina/solvers/supervised.py | 54 +- pina/trainer.py | 7 + tests/test_solvers/test_causalpinn.py | 266 ++++++++++ tests/test_solvers/test_competitive_pinn.py | 418 ++++++++++++++++ tests/test_solvers/test_gpinn.py | 432 ++++++++++++++++ tests/test_solvers/test_pinn.py | 315 ++++++++---- tests/test_solvers/test_rom_solver.py | 105 ++++ tests/test_solvers/test_sapinn.py | 437 ++++++++++++++++ 29 files changed, 3838 insertions(+), 358 deletions(-) create mode 100644 docs/source/_rst/solvers/basepinn.rst create mode 100644 docs/source/_rst/solvers/causalpinn.rst create mode 100644 docs/source/_rst/solvers/competitivepinn.rst create mode 100644 docs/source/_rst/solvers/gpinn.rst create mode 100644 docs/source/_rst/solvers/rom.rst create mode 100644 docs/source/_rst/solvers/sapinn.rst delete mode 100644 pina/solvers/pinn.py create mode 100644 pina/solvers/pinns/__init__.py create mode 100644 pina/solvers/pinns/basepinn.py create mode 100644 pina/solvers/pinns/causalpinn.py create mode 100644 pina/solvers/pinns/competitive_pinn.py create mode 100644 pina/solvers/pinns/gpinn.py create mode 100644 pina/solvers/pinns/pinn.py create mode 100644 pina/solvers/pinns/sapinn.py create mode 100644 pina/solvers/rom.py create mode 100644 tests/test_solvers/test_causalpinn.py create mode 100644 tests/test_solvers/test_competitive_pinn.py create mode 100644 tests/test_solvers/test_gpinn.py create mode 100644 tests/test_solvers/test_rom_solver.py create mode 100644 tests/test_solvers/test_sapinn.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 77072a50..d954920e 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -35,8 +35,14 @@ Solvers :titlesonly: SolverInterface + PINNInterface PINN + GPINN + CausalPINN + CompetitivePINN + SAPINN Supervised solver + ReducedOrderModelSolver GAROM diff --git a/docs/source/_rst/solvers/basepinn.rst b/docs/source/_rst/solvers/basepinn.rst new file mode 100644 index 00000000..c6507953 --- /dev/null +++ b/docs/source/_rst/solvers/basepinn.rst @@ -0,0 +1,7 @@ +PINNInterface +================= +.. currentmodule:: pina.solvers.pinns.basepinn + +.. autoclass:: PINNInterface + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/solvers/causalpinn.rst b/docs/source/_rst/solvers/causalpinn.rst new file mode 100644 index 00000000..28f7f15e --- /dev/null +++ b/docs/source/_rst/solvers/causalpinn.rst @@ -0,0 +1,7 @@ +CausalPINN +============== +.. currentmodule:: pina.solvers.pinns.causalpinn + +.. autoclass:: CausalPINN + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/solvers/competitivepinn.rst b/docs/source/_rst/solvers/competitivepinn.rst new file mode 100644 index 00000000..2bbe242b --- /dev/null +++ b/docs/source/_rst/solvers/competitivepinn.rst @@ -0,0 +1,7 @@ +CompetitivePINN +================= +.. currentmodule:: pina.solvers.pinns.competitive_pinn + +.. autoclass:: CompetitivePINN + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/solvers/gpinn.rst b/docs/source/_rst/solvers/gpinn.rst new file mode 100644 index 00000000..ee076a5d --- /dev/null +++ b/docs/source/_rst/solvers/gpinn.rst @@ -0,0 +1,7 @@ +GPINN +====== +.. currentmodule:: pina.solvers.pinns.gpinn + +.. autoclass:: GPINN + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/solvers/pinn.rst b/docs/source/_rst/solvers/pinn.rst index 3e9b2ef0..e1c2b59c 100644 --- a/docs/source/_rst/solvers/pinn.rst +++ b/docs/source/_rst/solvers/pinn.rst @@ -1,6 +1,6 @@ PINN ====== -.. currentmodule:: pina.solvers.pinn +.. currentmodule:: pina.solvers.pinns.pinn .. autoclass:: PINN :members: diff --git a/docs/source/_rst/solvers/rom.rst b/docs/source/_rst/solvers/rom.rst new file mode 100644 index 00000000..3ee534bb --- /dev/null +++ b/docs/source/_rst/solvers/rom.rst @@ -0,0 +1,7 @@ +ReducedOrderModelSolver +========================== +.. currentmodule:: pina.solvers.rom + +.. autoclass:: ReducedOrderModelSolver + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/solvers/sapinn.rst b/docs/source/_rst/solvers/sapinn.rst new file mode 100644 index 00000000..b20891ff --- /dev/null +++ b/docs/source/_rst/solvers/sapinn.rst @@ -0,0 +1,7 @@ +SAPINN +====== +.. currentmodule:: pina.solvers.pinns.sapinn + +.. autoclass:: SAPINN + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/avno.py b/pina/model/avno.py index 878185bc..2ac3b3f7 100644 --- a/pina/model/avno.py +++ b/pina/model/avno.py @@ -110,9 +110,9 @@ def forward(self, x): """ points_tmp = x.extract(self.coordinates_indices) new_batch = x.extract(self.field_indices) - new_batch = concatenate((new_batch, points_tmp), dim=2) + new_batch = concatenate((new_batch, points_tmp), dim=-1) new_batch = self._lifting_operator(new_batch) new_batch = self._integral_kernels(new_batch) - new_batch = concatenate((new_batch, points_tmp), dim=2) + new_batch = concatenate((new_batch, points_tmp), dim=-1) new_batch = self._projection_operator(new_batch) return new_batch diff --git a/pina/solvers/__init__.py b/pina/solvers/__init__.py index 0562dc2d..2751e481 100644 --- a/pina/solvers/__init__.py +++ b/pina/solvers/__init__.py @@ -1,6 +1,19 @@ -__all__ = ["PINN", "GAROM", "SupervisedSolver", "SolverInterface"] +__all__ = [ + "SolverInterface", + "PINNInterface", + "PINN", + "GPINN", + "CausalPINN", + "CompetitivePINN", + "SAPINN", + "SupervisedSolver", + "ReducedOrderModelSolver", + "GAROM", + ] -from .garom import GAROM -from .pinn import PINN -from .supervised import SupervisedSolver from .solver import SolverInterface +from .pinns import * +from .supervised import SupervisedSolver +from .rom import ReducedOrderModelSolver +from .garom import GAROM + diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py index 08856704..d6cd6246 100644 --- a/pina/solvers/garom.py +++ b/pina/solvers/garom.py @@ -253,18 +253,11 @@ def training_step(self, batch, batch_idx): :rtype: LabelTensor """ - dataloader = self.trainer.train_dataloader condition_idx = batch["condition"] for condition_id in range(condition_idx.min(), condition_idx.max() + 1): - if sys.version_info >= (3, 8): - condition_name = dataloader.condition_names[condition_id] - else: - condition_name = dataloader.loaders.condition_names[ - condition_id - ] - + condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] pts = batch["pts"].detach() out = batch["output"] diff --git a/pina/solvers/pinn.py b/pina/solvers/pinn.py deleted file mode 100644 index 008034f3..00000000 --- a/pina/solvers/pinn.py +++ /dev/null @@ -1,232 +0,0 @@ -""" Module for PINN """ - -import torch - -try: - from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 -except ImportError: - from torch.optim.lr_scheduler import ( - _LRScheduler as LRScheduler, - ) # torch < 2.0 - -import sys -from torch.optim.lr_scheduler import ConstantLR - -from .solver import SolverInterface -from ..label_tensor import LabelTensor -from ..utils import check_consistency -from ..loss import LossInterface -from ..problem import InverseProblem -from torch.nn.modules.loss import _Loss - -torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 - - -class PINN(SolverInterface): - """ - PINN solver class. This class implements Physics Informed Neural - Network solvers, using a user specified ``model`` to solve a specific - ``problem``. It can be used for solving both forward and inverse problems. - - .. seealso:: - - **Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L., - Perdikaris, P., Wang, S., & Yang, L. (2021). - Physics-informed machine learning. Nature Reviews Physics, 3(6), 422-440. - `_. - """ - - def __init__( - self, - problem, - model, - extra_features=None, - loss=torch.nn.MSELoss(), - optimizer=torch.optim.Adam, - optimizer_kwargs={"lr": 0.001}, - scheduler=ConstantLR, - scheduler_kwargs={"factor": 1, "total_iters": 0}, - ): - """ - :param AbstractProblem problem: The formulation of the problem. - :param torch.nn.Module model: The neural network model to use. - :param torch.nn.Module loss: The loss function used as minimizer, - default :class:`torch.nn.MSELoss`. - :param torch.nn.Module extra_features: The additional input - features to use as augmented input. - :param torch.optim.Optimizer optimizer: The neural network optimizer to - use; default is :class:`torch.optim.Adam`. - :param dict optimizer_kwargs: Optimizer constructor keyword args. - :param torch.optim.LRScheduler scheduler: Learning - rate scheduler. - :param dict scheduler_kwargs: LR scheduler constructor keyword args. - """ - super().__init__( - models=[model], - problem=problem, - optimizers=[optimizer], - optimizers_kwargs=[optimizer_kwargs], - extra_features=extra_features, - ) - - # check consistency - check_consistency(scheduler, LRScheduler, subclass=True) - check_consistency(scheduler_kwargs, dict) - check_consistency(loss, (LossInterface, _Loss), subclass=False) - - # assign variables - self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs) - self._loss = loss - self._neural_net = self.models[0] - - # inverse problem handling - if isinstance(self.problem, InverseProblem): - self._params = self.problem.unknown_parameters - else: - self._params = None - - def forward(self, x): - """ - Forward pass implementation for the PINN - solver. - - :param torch.Tensor x: Input tensor. - :return: PINN solution. - :rtype: torch.Tensor - """ - return self.neural_net(x) - - def configure_optimizers(self): - """ - Optimizer configuration for the PINN - solver. - - :return: The optimizers and the schedulers - :rtype: tuple(list, list) - """ - # if the problem is an InverseProblem, add the unknown parameters - # to the parameters that the optimizer needs to optimize - if isinstance(self.problem, InverseProblem): - self.optimizers[0].add_param_group( - { - "params": [ - self._params[var] - for var in self.problem.unknown_variables - ] - } - ) - return self.optimizers, [self.scheduler] - - def _clamp_inverse_problem_params(self): - for v in self._params: - self._params[v].data.clamp_( - self.problem.unknown_parameter_domain.range_[v][0], - self.problem.unknown_parameter_domain.range_[v][1], - ) - - def _loss_data(self, input, output): - return self.loss(self.forward(input), output) - - def _loss_phys(self, samples, equation): - try: - residual = equation.residual(samples, self.forward(samples)) - except ( - TypeError - ): # this occurs when the function has three inputs, i.e. inverse problem - residual = equation.residual( - samples, self.forward(samples), self._params - ) - return self.loss( - torch.zeros_like(residual, requires_grad=True), residual - ) - - def training_step(self, batch, batch_idx): - """ - PINN solver training step. - - :param batch: The batch element in the dataloader. - :type batch: tuple - :param batch_idx: The batch index. - :type batch_idx: int - :return: The sum of the loss functions. - :rtype: LabelTensor - """ - - dataloader = self.trainer.train_dataloader - condition_losses = [] - - condition_idx = batch["condition"] - - for condition_id in range(condition_idx.min(), condition_idx.max() + 1): - - if sys.version_info >= (3, 8): - condition_name = dataloader.condition_names[condition_id] - else: - condition_name = dataloader.loaders.condition_names[ - condition_id - ] - condition = self.problem.conditions[condition_name] - pts = batch["pts"] - - if len(batch) == 2: - samples = pts[condition_idx == condition_id] - loss = self._loss_phys(samples, condition.equation) - elif len(batch) == 3: - samples = pts[condition_idx == condition_id] - ground_truth = batch["output"][condition_idx == condition_id] - loss = self._loss_data(samples, ground_truth) - else: - raise ValueError("Batch size not supported") - - # TODO for users this us hard to remember when creating a new solver, to fix in a smarter way - loss = loss.as_subclass(torch.Tensor) - - # # add condition losses and accumulate logging for each epoch - condition_losses.append(loss * condition.data_weight) - self.log( - condition_name + "_loss", - float(loss), - prog_bar=True, - logger=True, - on_epoch=True, - on_step=False, - ) - - # clamp unknown parameters of the InverseProblem to their domain ranges (if needed) - if isinstance(self.problem, InverseProblem): - self._clamp_inverse_problem_params() - - # TODO Fix the bug, tot_loss is a label tensor without labels - # we need to pass it as a torch tensor to make everything work - total_loss = sum(condition_losses) - self.log( - "mean_loss", - float(total_loss / len(condition_losses)), - prog_bar=True, - logger=True, - on_epoch=True, - on_step=False, - ) - - return total_loss - - @property - def scheduler(self): - """ - Scheduler for the PINN training. - """ - return self._scheduler - - @property - def neural_net(self): - """ - Neural network for the PINN training. - """ - return self._neural_net - - @property - def loss(self): - """ - Loss for the PINN training. - """ - return self._loss diff --git a/pina/solvers/pinns/__init__.py b/pina/solvers/pinns/__init__.py new file mode 100644 index 00000000..c8aa904c --- /dev/null +++ b/pina/solvers/pinns/__init__.py @@ -0,0 +1,15 @@ +__all__ = [ + "PINNInterface", + "PINN", + "GPINN", + "CausalPINN", + "CompetitivePINN", + "SAPINN", +] + +from .basepinn import PINNInterface +from .pinn import PINN +from .gpinn import GPINN +from .causalpinn import CausalPINN +from .competitive_pinn import CompetitivePINN +from .sapinn import SAPINN diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py new file mode 100644 index 00000000..726cdf92 --- /dev/null +++ b/pina/solvers/pinns/basepinn.py @@ -0,0 +1,247 @@ +""" Module for PINN """ + +import sys +from abc import ABCMeta, abstractmethod +import torch + +from ...solvers.solver import SolverInterface +from pina.utils import check_consistency +from pina.loss import LossInterface +from pina.problem import InverseProblem +from torch.nn.modules.loss import _Loss + +torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 + +class PINNInterface(SolverInterface, metaclass=ABCMeta): + """ + Base PINN solver class. This class implements the Solver Interface + for Physics Informed Neural Network solvers. + + This class can be used to + define PINNs with multiple ``optimizers``, and/or ``models``. + By default it takes + an :class:`~pina.problem.abstract_problem.AbstractProblem`, so it is up + to the user to choose which problem the implemented solver inheriting from + this class is suitable for. + """ + + def __init__( + self, + models, + problem, + optimizers, + optimizers_kwargs, + extra_features, + loss, + ): + """ + :param models: Multiple torch neural network models instances. + :type models: list(torch.nn.Module) + :param problem: A problem definition instance. + :type problem: AbstractProblem + :param list(torch.optim.Optimizer) optimizer: A list of neural network + optimizers to use. + :param list(dict) optimizer_kwargs: A list of optimizer constructor + keyword args. + :param list(torch.nn.Module) extra_features: The additional input + features to use as augmented input. If ``None`` no extra features + are passed. If it is a list of :class:`torch.nn.Module`, + the extra feature list is passed to all models. If it is a list + of extra features' lists, each single list of extra feature + is passed to a model. + :param torch.nn.Module loss: The loss function used as minimizer, + default :class:`torch.nn.MSELoss`. + """ + super().__init__( + models=models, + problem=problem, + optimizers=optimizers, + optimizers_kwargs=optimizers_kwargs, + extra_features=extra_features, + ) + + # check consistency + check_consistency(loss, (LossInterface, _Loss), subclass=False) + + # assign variables + self._loss = loss + + # inverse problem handling + if isinstance(self.problem, InverseProblem): + self._params = self.problem.unknown_parameters + self._clamp_params = self._clamp_inverse_problem_params + else: + self._params = None + self._clamp_params = lambda : None + + # variable used internally to store residual losses at each epoch + # this variable save the residual at each iteration (not weighted) + self.__logged_res_losses = [] + + # variable used internally in pina for logging. This variable points to + # the current condition during the training step and returns the + # condition name. Whenever :meth:`store_log` is called the logged + # variable will be stored with name = self.__logged_metric + self.__logged_metric = None + + def training_step(self, batch, _): + """ + The Physics Informed Solver Training Step. This function takes care + of the physics informed training step, and it must not be override + if not intentionally. It handles the batching mechanism, the workload + division for the various conditions, the inverse problem clamping, + and loggers. + + :param tuple batch: The batch element in the dataloader. + :param int batch_idx: The batch index. + :return: The sum of the loss functions. + :rtype: LabelTensor + """ + + condition_losses = [] + condition_idx = batch["condition"] + + for condition_id in range(condition_idx.min(), condition_idx.max() + 1): + + condition_name = self._dataloader.condition_names[condition_id] + condition = self.problem.conditions[condition_name] + pts = batch["pts"] + # condition name is logged (if logs enabled) + self.__logged_metric = condition_name + + if len(batch) == 2: + samples = pts[condition_idx == condition_id] + loss = self.loss_phys(samples, condition.equation) + elif len(batch) == 3: + samples = pts[condition_idx == condition_id] + ground_truth = batch["output"][condition_idx == condition_id] + loss = self.loss_data(samples, ground_truth) + else: + raise ValueError("Batch size not supported") + + # add condition losses for each epoch + condition_losses.append(loss * condition.data_weight) + + # clamp unknown parameters in InverseProblem (if needed) + self._clamp_params() + + # total loss (must be a torch.Tensor) + total_loss = sum(condition_losses) + return total_loss.as_subclass(torch.Tensor) + + def loss_data(self, input_tensor, output_tensor): + """ + The data loss for the PINN solver. It computes the loss between + the network output against the true solution. This function + should not be override if not intentionally. + + :param LabelTensor input_tensor: The input to the neural networks. + :param LabelTensor output_tensor: The true solution to compare the + network solution. + :return: The residual loss averaged on the input coordinates + :rtype: torch.Tensor + """ + loss_value = self.loss(self.forward(input_tensor), output_tensor) + self.store_log(loss_value=float(loss_value)) + return self.loss(self.forward(input_tensor), output_tensor) + + @abstractmethod + def loss_phys(self, samples, equation): + """ + Computes the physics loss for the physics informed solver based on given + samples and equation. This method must be override by all inherited + classes and it is the core to define a new physics informed solver. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation + representing the physics. + :return: The physics loss calculated based on given + samples and equation. + :rtype: LabelTensor + """ + pass + + def compute_residual(self, samples, equation): + """ + Compute the residual for Physics Informed learning. This function + returns the :obj:`~pina.equation.equation.Equation` specified in the + :obj:`~pina.condition.Condition` evaluated at the ``samples`` points. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation + representing the physics. + :return: The residual of the neural network solution. + :rtype: LabelTensor + """ + try: + residual = equation.residual(samples, self.forward(samples)) + except ( + TypeError + ): # this occurs when the function has three inputs, i.e. inverse problem + residual = equation.residual( + samples, self.forward(samples), self._params + ) + return residual + + def store_log(self, loss_value): + """ + Stores the loss value in the logger. This function should be + called for all conditions. It automatically handles the storing + conditions names. It must be used + anytime a specific variable wants to be stored for a specific condition. + A simple example is to use the variable to store the residual. + + :param str name: The name of the loss. + :param torch.Tensor loss_value: The value of the loss. + """ + self.log( + self.__logged_metric+'_loss', + loss_value, + prog_bar=True, + logger=True, + on_epoch=True, + on_step=False, + ) + self.__logged_res_losses.append(loss_value) + + def on_train_epoch_end(self): + """ + At the end of each epoch we free the stored losses. This function + should not be override if not intentionally. + """ + if self.__logged_res_losses: + # storing mean loss + self.__logged_metric = 'mean' + self.store_log( + sum(self.__logged_res_losses)/len(self.__logged_res_losses) + ) + # free the logged losses + self.__logged_res_losses = [] + return super().on_train_epoch_end() + + def _clamp_inverse_problem_params(self): + """ + Clamps the parameters of the inverse problem + solver to the specified ranges. + """ + for v in self._params: + self._params[v].data.clamp_( + self.problem.unknown_parameter_domain.range_[v][0], + self.problem.unknown_parameter_domain.range_[v][1], + ) + + @property + def loss(self): + """ + Loss used for training. + """ + return self._loss + + @property + def current_condition_name(self): + """ + Returns the condition name. This function can be used inside the + :meth:`loss_phys` to extract the condition at which the loss is + computed. + """ + return self.__logged_metric \ No newline at end of file diff --git a/pina/solvers/pinns/causalpinn.py b/pina/solvers/pinns/causalpinn.py new file mode 100644 index 00000000..fea0fe47 --- /dev/null +++ b/pina/solvers/pinns/causalpinn.py @@ -0,0 +1,221 @@ +""" Module for CausalPINN """ + +import torch + + +from torch.optim.lr_scheduler import ConstantLR + +from .pinn import PINN +from pina.problem import TimeDependentProblem +from pina.utils import check_consistency + + +class CausalPINN(PINN): + r""" + Causal Physics Informed Neural Network (PINN) solver class. + This class implements Causal Physics Informed Neural + Network solvers, using a user specified ``model`` to solve a specific + ``problem``. It can be used for solving both forward and inverse problems. + + The Causal Physics Informed Network aims to find + the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` + of the differential problem: + + .. math:: + + \begin{cases} + \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ + \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, + \mathbf{x}\in\partial\Omega + \end{cases} + + minimizing the loss function + + .. math:: + \mathcal{L}_{\rm{problem}} = \frac{1}{N_t}\sum_{i=1}^{N_t} + \omega_{i}\mathcal{L}_r(t_i), + + where: + + .. math:: + \mathcal{L}_r(t) = \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i, t)) + + \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i, t)) + + and, + + .. math:: + \omega_i = \exp\left(\epsilon \sum_{k=1}^{i-1}\mathcal{L}_r(t_k)\right). + + :math:`\epsilon` is an hyperparameter, default set to :math:`100`, while + :math:`\mathcal{L}` is a specific loss function, + default Mean Square Error: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + + .. seealso:: + + **Original reference**: Wang, Sifan, Shyam Sankaran, and Paris + Perdikaris. "Respecting causality for training physics-informed + neural networks." Computer Methods in Applied Mechanics + and Engineering 421 (2024): 116813. + DOI `10.1016 `_. + + .. note:: + This class can only work for problems inheriting + from at least + :class:`~pina.problem.timedep_problem.TimeDependentProblem` class. + """ + + def __init__( + self, + problem, + model, + extra_features=None, + loss=torch.nn.MSELoss(), + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": 0.001}, + scheduler=ConstantLR, + scheduler_kwargs={"factor": 1, "total_iters": 0}, + eps=100, + ): + """ + :param AbstractProblem problem: The formulation of the problem. + :param torch.nn.Module model: The neural network model to use. + :param torch.nn.Module loss: The loss function used as minimizer, + default :class:`torch.nn.MSELoss`. + :param torch.nn.Module extra_features: The additional input + features to use as augmented input. + :param torch.optim.Optimizer optimizer: The neural network optimizer to + use; default is :class:`torch.optim.Adam`. + :param dict optimizer_kwargs: Optimizer constructor keyword args. + :param torch.optim.LRScheduler scheduler: Learning + rate scheduler. + :param dict scheduler_kwargs: LR scheduler constructor keyword args. + :param int | float eps: The exponential decay parameter. Note that this + value is kept fixed during the training, but can be changed by means + of a callback, e.g. for annealing. + """ + super().__init__( + problem=problem, + model=model, + extra_features=extra_features, + loss=loss, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + ) + + # checking consistency + check_consistency(eps, (int,float)) + self._eps = eps + if not isinstance(self.problem, TimeDependentProblem): + raise ValueError('Casual PINN works only for problems' + 'inheritig from TimeDependentProblem.') + + def loss_phys(self, samples, equation): + """ + Computes the physics loss for the Causal PINN solver based on given + samples and equation. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation + representing the physics. + :return: The physics loss calculated based on given + samples and equation. + :rtype: LabelTensor + """ + # split sequentially ordered time tensors into chunks + chunks, labels = self._split_tensor_into_chunks(samples) + # compute residuals - this correspond to ordered loss functions + # values for each time step. We apply `flatten` such that after + # concataning the residuals we obtain a tensor of shape #chunks + time_loss = [] + for chunk in chunks: + chunk.labels = labels + # classical PINN loss + residual = self.compute_residual(samples=chunk, equation=equation) + loss_val = self.loss( + torch.zeros_like(residual, requires_grad=True), residual + ) + time_loss.append(loss_val) + # store results + self.store_log(loss_value=float(sum(time_loss)/len(time_loss))) + # concatenate residuals + time_loss = torch.stack(time_loss) + # compute weights (without the gradient storing) + with torch.no_grad(): + weights = self._compute_weights(time_loss) + return (weights * time_loss).mean() + + @property + def eps(self): + """ + The exponential decay parameter. + """ + return self._eps + + @eps.setter + def eps(self, value): + """ + Setter method for the eps parameter. + + :param float value: The exponential decay parameter. + """ + check_consistency(value, float) + self._eps = value + + def _sort_label_tensor(self, tensor): + """ + Sorts the label tensor based on time variables. + + :param LabelTensor tensor: The label tensor to be sorted. + :return: The sorted label tensor based on time variables. + :rtype: LabelTensor + """ + # labels input tensors + labels = tensor.labels + # extract time tensor + time_tensor = tensor.extract(self.problem.temporal_domain.variables) + # sort the time tensors (this is very bad for GPU) + _, idx = torch.sort(time_tensor.tensor.flatten()) + tensor = tensor[idx] + tensor.labels = labels + return tensor + + def _split_tensor_into_chunks(self, tensor): + """ + Splits the label tensor into chunks based on time. + + :param LabelTensor tensor: The label tensor to be split. + :return: Tuple containing the chunks and the original labels. + :rtype: Tuple[List[LabelTensor], List] + """ + # labels input tensors + labels = tensor.labels + # labels input tensors + tensor = self._sort_label_tensor(tensor) + # extract time tensor + time_tensor = tensor.extract(self.problem.temporal_domain.variables) + # count unique tensors in time + _, idx_split = time_tensor.unique(return_counts=True) + # splitting + chunks = torch.split(tensor, tuple(idx_split)) + return chunks, labels # return chunks + + def _compute_weights(self, loss): + """ + Computes the weights for the physics loss based on the cumulative loss. + + :param LabelTensor loss: The physics loss values. + :return: The computed weights for the physics loss. + :rtype: LabelTensor + """ + # compute comulative loss and multiply by epsilos + cumulative_loss = self._eps * torch.cumsum(loss, dim=0) + # return the exponential of the weghited negative cumulative sum + return torch.exp(-cumulative_loss) diff --git a/pina/solvers/pinns/competitive_pinn.py b/pina/solvers/pinns/competitive_pinn.py new file mode 100644 index 00000000..6404c0bb --- /dev/null +++ b/pina/solvers/pinns/competitive_pinn.py @@ -0,0 +1,360 @@ +""" Module for CompetitivePINN """ + +import torch +import copy + +try: + from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 +except ImportError: + from torch.optim.lr_scheduler import ( + _LRScheduler as LRScheduler, + ) # torch < 2.0 + +from torch.optim.lr_scheduler import ConstantLR + +from .basepinn import PINNInterface +from pina.utils import check_consistency +from pina.problem import InverseProblem + + +class CompetitivePINN(PINNInterface): + r""" + Competitive Physics Informed Neural Network (PINN) solver class. + This class implements Competitive Physics Informed Neural + Network solvers, using a user specified ``model`` to solve a specific + ``problem``. It can be used for solving both forward and inverse problems. + + The Competitive Physics Informed Network aims to find + the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` + of the differential problem: + + .. math:: + + \begin{cases} + \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ + \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, + \mathbf{x}\in\partial\Omega + \end{cases} + + with a minimization (on ``model`` parameters) maximation ( + on ``discriminator`` parameters) of the loss function + + .. math:: + \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(D(\mathbf{x}_i)\mathcal{A}[\mathbf{u}](\mathbf{x}_i))+ + \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) + + where :math:`D` is the discriminator network, which tries to find the points + where the network performs worst, and :math:`\mathcal{L}` is a specific loss + function, default Mean Square Error: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + .. seealso:: + + **Original reference**: Zeng, Qi, et al. + "Competitive physics informed networks." International Conference on + Learning Representations, ICLR 2022 + `OpenReview Preprint `_. + + .. warning:: + This solver does not currently support the possibility to pass + ``extra_feature``. + """ + + def __init__( + self, + problem, + model, + discriminator=None, + loss=torch.nn.MSELoss(), + optimizer_model=torch.optim.Adam, + optimizer_model_kwargs={"lr": 0.001}, + optimizer_discriminator=torch.optim.Adam, + optimizer_discriminator_kwargs={"lr": 0.001}, + scheduler_model=ConstantLR, + scheduler_model_kwargs={"factor": 1, "total_iters": 0}, + scheduler_discriminator=ConstantLR, + scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0}, + ): + """ + :param AbstractProblem problem: The formualation of the problem. + :param torch.nn.Module model: The neural network model to use + for the model. + :param torch.nn.Module discriminator: The neural network model to use + for the discriminator. If ``None``, the discriminator network will + have the same architecture as the model network. + :param torch.nn.Module loss: The loss function used as minimizer, + default :class:`torch.nn.MSELoss`. + :param torch.optim.Optimizer optimizer_model: The neural + network optimizer to use for the model network + , default is `torch.optim.Adam`. + :param dict optimizer_model_kwargs: Optimizer constructor keyword + args. for the model. + :param torch.optim.Optimizer optimizer_discriminator: The neural + network optimizer to use for the discriminator network + , default is `torch.optim.Adam`. + :param dict optimizer_discriminator_kwargs: Optimizer constructor + keyword args. for the discriminator. + :param torch.optim.LRScheduler scheduler_model: Learning + rate scheduler for the model. + :param dict scheduler_model_kwargs: LR scheduler constructor + keyword args. + :param torch.optim.LRScheduler scheduler_discriminator: Learning + rate scheduler for the discriminator. + """ + if discriminator is None: + discriminator = copy.deepcopy(model) + + super().__init__( + models=[model, discriminator], + problem=problem, + optimizers=[optimizer_model, optimizer_discriminator], + optimizers_kwargs=[ + optimizer_model_kwargs, + optimizer_discriminator_kwargs, + ], + extra_features=None, # CompetitivePINN doesn't take extra features + loss=loss + ) + + # set automatic optimization for GANs + self.automatic_optimization = False + + # check consistency + check_consistency(scheduler_model, LRScheduler, subclass=True) + check_consistency(scheduler_model_kwargs, dict) + check_consistency(scheduler_discriminator, LRScheduler, subclass=True) + check_consistency(scheduler_discriminator_kwargs, dict) + + # assign schedulers + self._schedulers = [ + scheduler_model( + self.optimizers[0], **scheduler_model_kwargs + ), + scheduler_discriminator( + self.optimizers[1], **scheduler_discriminator_kwargs + ), + ] + + self._model = self.models[0] + self._discriminator = self.models[1] + + def forward(self, x): + r""" + Forward pass implementation for the PINN solver. It returns the function + evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points + :math:`\mathbf{x}`. + + :param LabelTensor x: Input tensor for the PINN solver. It expects + a tensor :math:`N \times D`, where :math:`N` the number of points + in the mesh, :math:`D` the dimension of the problem, + :return: PINN solution evaluated at contro points. + :rtype: LabelTensor + """ + return self.neural_net(x) + + def loss_phys(self, samples, equation): + """ + Computes the physics loss for the Competitive PINN solver based on given + samples and equation. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation + representing the physics. + :return: The physics loss calculated based on given + samples and equation. + :rtype: LabelTensor + """ + # train one step of the model + with torch.no_grad(): + discriminator_bets = self.discriminator(samples) + loss_val = self._train_model(samples, equation, discriminator_bets) + self.store_log(loss_value=float(loss_val)) + # detaching samples from the computational graph to erase it and setting + # the gradient to true to create a new computational graph. + # In alternative set `retain_graph=True`. + samples = samples.detach() + samples.requires_grad = True + # train one step of discriminator + discriminator_bets = self.discriminator(samples) + self._train_discriminator(samples, equation, discriminator_bets) + return loss_val + + def loss_data(self, input_tensor, output_tensor): + """ + The data loss for the PINN solver. It computes the loss between the + network output against the true solution. + + :param LabelTensor input_tensor: The input to the neural networks. + :param LabelTensor output_tensor: The true solution to compare the + network solution. + :return: The computed data loss. + :rtype: torch.Tensor + """ + self.optimizer_model.zero_grad() + loss_val = super().loss_data( + input_tensor, output_tensor).as_subclass(torch.Tensor) + loss_val.backward() + self.optimizer_model.step() + return loss_val + + def configure_optimizers(self): + """ + Optimizer configuration for the Competitive PINN solver. + + :return: The optimizers and the schedulers + :rtype: tuple(list, list) + """ + # if the problem is an InverseProblem, add the unknown parameters + # to the parameters that the optimizer needs to optimize + if isinstance(self.problem, InverseProblem): + self.optimizer_model.add_param_group( + { + "params": [ + self._params[var] + for var in self.problem.unknown_variables + ] + } + ) + return self.optimizers, self._schedulers + + def on_train_batch_end(self,outputs, batch, batch_idx): + """ + This method is called at the end of each training batch, and ovverides + the PytorchLightining implementation for logging the checkpoints. + + :param torch.Tensor outputs: The output from the model for the + current batch. + :param tuple batch: The current batch of data. + :param int batch_idx: The index of the current batch. + :return: Whatever is returned by the parent + method ``on_train_batch_end``. + :rtype: Any + """ + # increase by one the counter of optimization to save loggers + self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1 + return super().on_train_batch_end(outputs, batch, batch_idx) + + def _train_discriminator(self, samples, equation, discriminator_bets): + """ + Trains the discriminator network of the Competitive PINN. + + :param LabelTensor samples: Input samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation representing + the physics. + :param Tensor discriminator_bets: Predictions made by the discriminator + network. + """ + # manual optimization + self.optimizer_discriminator.zero_grad() + # compute residual, we detach because the weights of the generator + # model are fixed + residual = self.compute_residual(samples=samples, + equation=equation).detach() + # compute competitive residual, the minus is because we maximise + competitive_residual = residual * discriminator_bets + loss_val = - self.loss( + torch.zeros_like(competitive_residual, requires_grad=True), + competitive_residual + ).as_subclass(torch.Tensor) + # backprop + self.manual_backward(loss_val) + self.optimizer_discriminator.step() + return + + def _train_model(self, samples, equation, discriminator_bets): + """ + Trains the model network of the Competitive PINN. + + :param LabelTensor samples: Input samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation representing + the physics. + :param Tensor discriminator_bets: Predictions made by the discriminator. + network. + :return: The computed data loss. + :rtype: torch.Tensor + """ + # manual optimization + self.optimizer_model.zero_grad() + # compute residual (detached for discriminator) and log + residual = self.compute_residual(samples=samples, equation=equation) + # store logging + with torch.no_grad(): + loss_residual = self.loss( + torch.zeros_like(residual), + residual + ) + # compute competitive residual, discriminator_bets are detached becase + # we optimize only the generator model + competitive_residual = residual * discriminator_bets.detach() + loss_val = self.loss( + torch.zeros_like(competitive_residual, requires_grad=True), + competitive_residual + ).as_subclass(torch.Tensor) + # backprop + self.manual_backward(loss_val) + self.optimizer_model.step() + return loss_residual + + @property + def neural_net(self): + """ + Returns the neural network model. + + :return: The neural network model. + :rtype: torch.nn.Module + """ + return self._model + + @property + def discriminator(self): + """ + Returns the discriminator model (if applicable). + + :return: The discriminator model. + :rtype: torch.nn.Module + """ + return self._discriminator + + @property + def optimizer_model(self): + """ + Returns the optimizer associated with the neural network model. + + :return: The optimizer for the neural network model. + :rtype: torch.optim.Optimizer + """ + return self.optimizers[0] + + @property + def optimizer_discriminator(self): + """ + Returns the optimizer associated with the discriminator (if applicable). + + :return: The optimizer for the discriminator. + :rtype: torch.optim.Optimizer + """ + return self.optimizers[1] + + @property + def scheduler_model(self): + """ + Returns the scheduler associated with the neural network model. + + :return: The scheduler for the neural network model. + :rtype: torch.optim.lr_scheduler._LRScheduler + """ + return self._schedulers[0] + + @property + def scheduler_discriminator(self): + """ + Returns the scheduler associated with the discriminator (if applicable). + + :return: The scheduler for the discriminator. + :rtype: torch.optim.lr_scheduler._LRScheduler + """ + return self._schedulers[1] \ No newline at end of file diff --git a/pina/solvers/pinns/gpinn.py b/pina/solvers/pinns/gpinn.py new file mode 100644 index 00000000..6eca1eac --- /dev/null +++ b/pina/solvers/pinns/gpinn.py @@ -0,0 +1,134 @@ +""" Module for GPINN """ + +import torch + + +from torch.optim.lr_scheduler import ConstantLR + +from .pinn import PINN +from pina.operators import grad +from pina.problem import SpatialProblem + + +class GPINN(PINN): + r""" + Gradient Physics Informed Neural Network (GPINN) solver class. + This class implements Gradient Physics Informed Neural + Network solvers, using a user specified ``model`` to solve a specific + ``problem``. It can be used for solving both forward and inverse problems. + + The Gradient Physics Informed Network aims to find + the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` + of the differential problem: + + .. math:: + + \begin{cases} + \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ + \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, + \mathbf{x}\in\partial\Omega + \end{cases} + + minimizing the loss function + + .. math:: + \mathcal{L}_{\rm{problem}} =& \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) + + \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) + \\ + &\frac{1}{N}\sum_{i=1}^N + \nabla_{\mathbf{x}}\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) + + \frac{1}{N}\sum_{i=1}^N + \nabla_{\mathbf{x}}\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) + + + where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + .. seealso:: + + **Original reference**: Yu, Jeremy, et al. "Gradient-enhanced + physics-informed neural networks for forward and inverse + PDE problems." Computer Methods in Applied Mechanics + and Engineering 393 (2022): 114823. + DOI: `10.1016 `_. + + .. note:: + This class can only work for problems inheriting + from at least :class:`~pina.problem.spatial_problem.SpatialProblem` + class. + """ + + def __init__( + self, + problem, + model, + extra_features=None, + loss=torch.nn.MSELoss(), + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": 0.001}, + scheduler=ConstantLR, + scheduler_kwargs={"factor": 1, "total_iters": 0}, + ): + """ + :param AbstractProblem problem: The formulation of the problem. It must + inherit from at least + :class:`~pina.problem.spatial_problem.SpatialProblem` in order to + compute the gradient of the loss. + :param torch.nn.Module model: The neural network model to use. + :param torch.nn.Module loss: The loss function used as minimizer, + default :class:`torch.nn.MSELoss`. + :param torch.nn.Module extra_features: The additional input + features to use as augmented input. + :param torch.optim.Optimizer optimizer: The neural network optimizer to + use; default is :class:`torch.optim.Adam`. + :param dict optimizer_kwargs: Optimizer constructor keyword args. + :param torch.optim.LRScheduler scheduler: Learning + rate scheduler. + :param dict scheduler_kwargs: LR scheduler constructor keyword args. + """ + super().__init__( + problem=problem, + model=model, + extra_features=extra_features, + loss=loss, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + ) + if not isinstance(self.problem, SpatialProblem): + raise ValueError('Gradient PINN computes the gradient of the ' + 'PINN loss with respect to the spatial ' + 'coordinates, thus the PINA problem must be ' + 'a SpatialProblem.') + + + def loss_phys(self, samples, equation): + """ + Computes the physics loss for the GPINN solver based on given + samples and equation. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation + representing the physics. + :return: The physics loss calculated based on given + samples and equation. + :rtype: LabelTensor + """ + # classical PINN loss + residual = self.compute_residual(samples=samples, equation=equation) + loss_value = self.loss( + torch.zeros_like(residual, requires_grad=True), residual + ) + self.store_log(loss_value=float(loss_value)) + # gradient PINN loss + loss_value = loss_value.reshape(-1, 1) + loss_value.labels = ['__LOSS'] + loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables) + g_loss_phys = self.loss( + torch.zeros_like(loss_grad, requires_grad=True), loss_grad + ) + return loss_value + g_loss_phys \ No newline at end of file diff --git a/pina/solvers/pinns/pinn.py b/pina/solvers/pinns/pinn.py new file mode 100644 index 00000000..318283a3 --- /dev/null +++ b/pina/solvers/pinns/pinn.py @@ -0,0 +1,170 @@ +""" Module for Physics Informed Neural Network. """ + +import torch + +try: + from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 +except ImportError: + from torch.optim.lr_scheduler import ( + _LRScheduler as LRScheduler, + ) # torch < 2.0 + +from torch.optim.lr_scheduler import ConstantLR + +from .basepinn import PINNInterface +from pina.utils import check_consistency +from pina.problem import InverseProblem + + +class PINN(PINNInterface): + r""" + Physics Informed Neural Network (PINN) solver class. + This class implements Physics Informed Neural + Network solvers, using a user specified ``model`` to solve a specific + ``problem``. It can be used for solving both forward and inverse problems. + + The Physics Informed Network aims to find + the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` + of the differential problem: + + .. math:: + + \begin{cases} + \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ + \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, + \mathbf{x}\in\partial\Omega + \end{cases} + + minimizing the loss function + + .. math:: + \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) + + \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) + + where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + .. seealso:: + + **Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L., + Perdikaris, P., Wang, S., & Yang, L. (2021). + Physics-informed machine learning. Nature Reviews Physics, 3, 422-440. + DOI: `10.1038 `_. + """ + + def __init__( + self, + problem, + model, + extra_features=None, + loss=torch.nn.MSELoss(), + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": 0.001}, + scheduler=ConstantLR, + scheduler_kwargs={"factor": 1, "total_iters": 0}, + ): + """ + :param AbstractProblem problem: The formulation of the problem. + :param torch.nn.Module model: The neural network model to use. + :param torch.nn.Module loss: The loss function used as minimizer, + default :class:`torch.nn.MSELoss`. + :param torch.nn.Module extra_features: The additional input + features to use as augmented input. + :param torch.optim.Optimizer optimizer: The neural network optimizer to + use; default is :class:`torch.optim.Adam`. + :param dict optimizer_kwargs: Optimizer constructor keyword args. + :param torch.optim.LRScheduler scheduler: Learning + rate scheduler. + :param dict scheduler_kwargs: LR scheduler constructor keyword args. + """ + super().__init__( + models=[model], + problem=problem, + optimizers=[optimizer], + optimizers_kwargs=[optimizer_kwargs], + extra_features=extra_features, + loss=loss + ) + + # check consistency + check_consistency(scheduler, LRScheduler, subclass=True) + check_consistency(scheduler_kwargs, dict) + + # assign variables + self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs) + self._neural_net = self.models[0] + + def forward(self, x): + r""" + Forward pass implementation for the PINN solver. It returns the function + evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points + :math:`\mathbf{x}`. + + :param LabelTensor x: Input tensor for the PINN solver. It expects + a tensor :math:`N \times D`, where :math:`N` the number of points + in the mesh, :math:`D` the dimension of the problem, + :return: PINN solution evaluated at contro points. + :rtype: LabelTensor + """ + return self.neural_net(x) + + def loss_phys(self, samples, equation): + """ + Computes the physics loss for the PINN solver based on given + samples and equation. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation + representing the physics. + :return: The physics loss calculated based on given + samples and equation. + :rtype: LabelTensor + """ + residual = self.compute_residual(samples=samples, equation=equation) + loss_value = self.loss( + torch.zeros_like(residual, requires_grad=True), residual + ) + self.store_log(loss_value=float(loss_value)) + return loss_value + + + def configure_optimizers(self): + """ + Optimizer configuration for the PINN + solver. + + :return: The optimizers and the schedulers + :rtype: tuple(list, list) + """ + # if the problem is an InverseProblem, add the unknown parameters + # to the parameters that the optimizer needs to optimize + if isinstance(self.problem, InverseProblem): + self.optimizers[0].add_param_group( + { + "params": [ + self._params[var] + for var in self.problem.unknown_variables + ] + } + ) + return self.optimizers, [self.scheduler] + + + @property + def scheduler(self): + """ + Scheduler for the PINN training. + """ + return self._scheduler + + + @property + def neural_net(self): + """ + Neural network for the PINN training. + """ + return self._neural_net \ No newline at end of file diff --git a/pina/solvers/pinns/sapinn.py b/pina/solvers/pinns/sapinn.py new file mode 100644 index 00000000..8de2d14c --- /dev/null +++ b/pina/solvers/pinns/sapinn.py @@ -0,0 +1,494 @@ +import torch +from copy import deepcopy + +try: + from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 +except ImportError: + from torch.optim.lr_scheduler import ( + _LRScheduler as LRScheduler, + ) # torch < 2.0 + +from .basepinn import PINNInterface +from pina.utils import check_consistency +from pina.problem import InverseProblem + +from torch.optim.lr_scheduler import ConstantLR + +class Weights(torch.nn.Module): + """ + This class aims to implements the mask model for + self adaptive weights of the Self-Adaptive + PINN solver. + """ + + def __init__(self, func): + """ + :param torch.nn.Module func: the mask module of SAPINN + """ + super().__init__() + check_consistency(func, torch.nn.Module) + self.sa_weights = torch.nn.Parameter( + torch.Tensor() + ) + self.func = func + + def forward(self): + """ + Forward pass implementation for the mask module. + It returns the function on the weights + evaluation. + + :return: evaluation of self adaptive weights through the mask. + :rtype: torch.Tensor + """ + return self.func(self.sa_weights) + +class SAPINN(PINNInterface): + r""" + Self Adaptive Physics Informed Neural Network (SAPINN) solver class. + This class implements Self-Adaptive Physics Informed Neural + Network solvers, using a user specified ``model`` to solve a specific + ``problem``. It can be used for solving both forward and inverse problems. + + The Self Adapive Physics Informed Neural Network aims to find + the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` + of the differential problem: + + .. math:: + + \begin{cases} + \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ + \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, + \mathbf{x}\in\partial\Omega + \end{cases} + + integrating the pointwise loss evaluation through a mask :math:`m` and + self adaptive weights that permit to focus the loss function on + specific training samples. + The loss function to solve the problem is + + .. math:: + + \mathcal{L}_{\rm{problem}} = \frac{1}{N} \sum_{i=1}^{N_\Omega} m + \left( \lambda_{\Omega}^{i} \right) \mathcal{L} \left( \mathcal{A} + [\mathbf{u}](\mathbf{x}) \right) + \frac{1}{N} + \sum_{i=1}^{N_{\partial\Omega}} + m \left( \lambda_{\partial\Omega}^{i} \right) \mathcal{L} + \left( \mathcal{B}[\mathbf{u}](\mathbf{x}) + \right), + + + denoting the self adaptive weights as + :math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and + :math:`\lambda_{\partial \Omega}^1, \dots, + \lambda_{\Omega}^{N_\partial \Omega}` + for :math:`\Omega` and :math:`\partial \Omega`, respectively. + + Self Adaptive Physics Informed Neural Network identifies the solution + and appropriate self adaptive weights by solving the following problem + + .. math:: + + \min_{w} \max_{\lambda_{\Omega}^k, \lambda_{\partial \Omega}^s} + \mathcal{L} , + + where :math:`w` denotes the network parameters, and + :math:`\mathcal{L}` is a specific loss + function, default Mean Square Error: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + .. seealso:: + **Original reference**: McClenny, Levi D., and Ulisses M. Braga-Neto. + "Self-adaptive physics-informed neural networks." + Journal of Computational Physics 474 (2023): 111722. + DOI: `10.1016/ + j.jcp.2022.111722 `_. + """ + + def __init__( + self, + problem, + model, + weights_function=torch.nn.Sigmoid(), + extra_features=None, + loss=torch.nn.MSELoss(), + optimizer_model=torch.optim.Adam, + optimizer_model_kwargs={"lr" : 0.001}, + optimizer_weights=torch.optim.Adam, + optimizer_weights_kwargs={"lr" : 0.001}, + scheduler_model=ConstantLR, + scheduler_model_kwargs={"factor" : 1, "total_iters" : 0}, + scheduler_weights=ConstantLR, + scheduler_weights_kwargs={"factor" : 1, "total_iters" : 0} + ): + """ + :param AbstractProblem problem: The formualation of the problem. + :param torch.nn.Module model: The neural network model to use + for the model. + :param torch.nn.Module weights_function: The neural network model + related to the mask of SAPINN. + default :obj:`~torch.nn.Sigmoid`. + :param list(torch.nn.Module) extra_features: The additional input + features to use as augmented input. If ``None`` no extra features + are passed. If it is a list of :class:`torch.nn.Module`, + the extra feature list is passed to all models. If it is a list + of extra features' lists, each single list of extra feature + is passed to a model. + :param torch.nn.Module loss: The loss function used as minimizer, + default :class:`torch.nn.MSELoss`. + :param torch.optim.Optimizer optimizer_model: The neural + network optimizer to use for the model network + , default is `torch.optim.Adam`. + :param dict optimizer_model_kwargs: Optimizer constructor keyword + args. for the model. + :param torch.optim.Optimizer optimizer_weights: The neural + network optimizer to use for mask model model, + default is `torch.optim.Adam`. + :param dict optimizer_weights_kwargs: Optimizer constructor + keyword args. for the mask module. + :param torch.optim.LRScheduler scheduler_model: Learning + rate scheduler for the model. + :param dict scheduler_model_kwargs: LR scheduler constructor + keyword args. + :param torch.optim.LRScheduler scheduler_weights: Learning + rate scheduler for the mask model. + :param dict scheduler_model_kwargs: LR scheduler constructor + keyword args. + """ + + # check consistency weitghs_function + check_consistency(weights_function, torch.nn.Module) + + # create models for weights + weights_dict = {} + for condition_name in problem.conditions: + weights_dict[condition_name] = Weights(weights_function) + weights_dict = torch.nn.ModuleDict(weights_dict) + + + super().__init__( + models=[model, weights_dict], + problem=problem, + optimizers=[optimizer_model, optimizer_weights], + optimizers_kwargs=[ + optimizer_model_kwargs, + optimizer_weights_kwargs + ], + extra_features=extra_features, + loss=loss + ) + + # set automatic optimization + self.automatic_optimization = False + + # check consistency + check_consistency(scheduler_model, LRScheduler, subclass=True) + check_consistency(scheduler_model_kwargs, dict) + check_consistency(scheduler_weights, LRScheduler, subclass=True) + check_consistency(scheduler_weights_kwargs, dict) + + # assign schedulers + self._schedulers = [ + scheduler_model( + self.optimizers[0], **scheduler_model_kwargs + ), + scheduler_weights( + self.optimizers[1], **scheduler_weights_kwargs + ), + ] + + self._model = self.models[0] + self._weights = self.models[1] + + self._vectorial_loss = deepcopy(loss) + self._vectorial_loss.reduction = "none" + + def forward(self, x): + """ + Forward pass implementation for the PINN + solver. It returns the function + evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points + :math:`\mathbf{x}`. + + :param LabelTensor x: Input tensor for the SAPINN solver. It expects + a tensor :math:`N \\times D`, where :math:`N` the number of points + in the mesh, :math:`D` the dimension of the problem, + :return: PINN solution. + :rtype: LabelTensor + """ + return self.neural_net(x) + + def loss_phys(self, samples, equation): + """ + Computes the physics loss for the SAPINN solver based on given + samples and equation. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation + representing the physics. + :return: The physics loss calculated based on given + samples and equation. + :rtype: torch.Tensor + """ + # train weights + self.optimizer_weights.zero_grad() + weighted_loss, _ = self._loss_phys(samples, equation) + loss_value = - weighted_loss.as_subclass(torch.Tensor) + self.manual_backward(loss_value) + self.optimizer_weights.step() + + # detaching samples from the computational graph to erase it and setting + # the gradient to true to create a new computational graph. + # In alternative set `retain_graph=True`. + samples = samples.detach() + samples.requires_grad = True + + # train model + self.optimizer_model.zero_grad() + weighted_loss, loss = self._loss_phys(samples, equation) + loss_value = weighted_loss.as_subclass(torch.Tensor) + self.manual_backward(loss_value) + self.optimizer_model.step() + + # store loss without weights + self.store_log(loss_value=float(loss)) + return loss_value + + def loss_data(self, input_tensor, output_tensor): + """ + Computes the data loss for the SAPINN solver based on input and + output. It computes the loss between the + network output against the true solution. + + :param LabelTensor input_tensor: The input to the neural networks. + :param LabelTensor output_tensor: The true solution to compare the + network solution. + :return: The computed data loss. + :rtype: torch.Tensor + """ + # train weights + self.optimizer_weights.zero_grad() + weighted_loss, _ = self._loss_data(input_tensor, output_tensor) + loss_value = - weighted_loss.as_subclass(torch.Tensor) + self.manual_backward(loss_value) + self.optimizer_weights.step() + + # detaching samples from the computational graph to erase it and setting + # the gradient to true to create a new computational graph. + # In alternative set `retain_graph=True`. + input_tensor = input_tensor.detach() + input_tensor.requires_grad = True + + # train model + self.optimizer_model.zero_grad() + weighted_loss, loss = self._loss_data(input_tensor, output_tensor) + loss_value = weighted_loss.as_subclass(torch.Tensor) + self.manual_backward(loss_value) + self.optimizer_model.step() + + # store loss without weights + self.store_log(loss_value=float(loss)) + return loss_value + + def configure_optimizers(self): + """ + Optimizer configuration for the SAPINN + solver. + + :return: The optimizers and the schedulers + :rtype: tuple(list, list) + """ + # if the problem is an InverseProblem, add the unknown parameters + # to the parameters that the optimizer needs to optimize + if isinstance(self.problem, InverseProblem): + self.optimizers[0].add_param_group( + { + "params": [ + self._params[var] + for var in self.problem.unknown_variables + ] + } + ) + return self.optimizers, self._schedulers + + def on_train_batch_end(self,outputs, batch, batch_idx): + """ + This method is called at the end of each training batch, and ovverides + the PytorchLightining implementation for logging the checkpoints. + + :param torch.Tensor outputs: The output from the model for the + current batch. + :param tuple batch: The current batch of data. + :param int batch_idx: The index of the current batch. + :return: Whatever is returned by the parent + method ``on_train_batch_end``. + :rtype: Any + """ + # increase by one the counter of optimization to save loggers + self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1 + return super().on_train_batch_end(outputs, batch, batch_idx) + + def on_train_start(self): + """ + This method is called at the start of the training for setting + the self adaptive weights as parameters of the mask model. + + :return: Whatever is returned by the parent + method ``on_train_start``. + :rtype: Any + """ + device = torch.device( + self.trainer._accelerator_connector._accelerator_flag + ) + for condition_name, tensor in self.problem.input_pts.items(): + self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand( + (tensor.shape[0], 1), + device = device + ) + return super().on_train_start() + + def on_load_checkpoint(self, checkpoint): + """ + Overriding the Pytorch Lightning ``on_load_checkpoint`` to handle + checkpoints for Self Adaptive Weights. This method should not be + overridden if not intentionally. + + :param dict checkpoint: Pytorch Lightning checkpoint dict. + """ + for condition_name, tensor in self.problem.input_pts.items(): + self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand( + (tensor.shape[0], 1) + ) + return super().on_load_checkpoint(checkpoint) + + def _loss_phys(self, samples, equation): + """ + Elaboration of the physical loss for the SAPINN solver. + + :param LabelTensor samples: Input samples to evaluate the physics loss. + :param EquationInterface equation: the governing equation representing + the physics. + + :return: tuple with weighted and not weighted scalar loss + :rtype: List[LabelTensor, LabelTensor] + """ + residual = self.compute_residual(samples, equation) + return self._compute_loss(residual) + + def _loss_data(self, input_tensor, output_tensor): + """ + Elaboration of the loss related to data for the SAPINN solver. + + :param LabelTensor input_tensor: The input to the neural networks. + :param LabelTensor output_tensor: The true solution to compare the + network solution. + + :return: tuple with weighted and not weighted scalar loss + :rtype: List[LabelTensor, LabelTensor] + """ + residual = self.forward(input_tensor) - output_tensor + return self._compute_loss(residual) + + def _compute_loss(self, residual): + """ + Elaboration of the pointwise loss through the mask model and the + self adaptive weights + + :param LabelTensor residual: the matrix of residuals that have to + be weighted + + :return: tuple with weighted and not weighted loss + :rtype List[LabelTensor, LabelTensor] + """ + weights = self.weights_dict.torchmodel[ + self.current_condition_name].forward() + loss_value = self._vectorial_loss(torch.zeros_like( + residual, requires_grad=True), residual) + return ( + self._vect_to_scalar(weights * loss_value), + self._vect_to_scalar(loss_value) + ) + + def _vect_to_scalar(self, loss_value): + """ + Elaboration of the pointwise loss through the mask model and the + self adaptive weights + + :param LabelTensor loss_value: the matrix of pointwise loss + + :return: the scalar loss + :rtype LabelTensor + """ + if self.loss.reduction == "mean": + ret = torch.mean(loss_value) + elif self.loss.reduction == "sum": + ret = torch.sum(loss_value) + else: + raise RuntimeError(f"Invalid reduction, got {self.loss.reduction} " + "but expected mean or sum.") + return ret + + + @property + def neural_net(self): + """ + Returns the neural network model. + + :return: The neural network model. + :rtype: torch.nn.Module + """ + return self.models[0] + + @property + def weights_dict(self): + """ + Return the mask models associate to the application of + the mask to the self adaptive weights for each loss that + compones the global loss of the problem. + + :return: The ModuleDict for mask models. + :rtype: torch.nn.ModuleDict + """ + return self.models[1] + + @property + def scheduler_model(self): + """ + Returns the scheduler associated with the neural network model. + + :return: The scheduler for the neural network model. + :rtype: torch.optim.lr_scheduler._LRScheduler + """ + return self._scheduler[0] + + @property + def scheduler_weights(self): + """ + Returns the scheduler associated with the mask model (if applicable). + + :return: The scheduler for the mask model. + :rtype: torch.optim.lr_scheduler._LRScheduler + """ + return self._scheduler[1] + + @property + def optimizer_model(self): + """ + Returns the optimizer associated with the neural network model. + + :return: The optimizer for the neural network model. + :rtype: torch.optim.Optimizer + """ + return self.optimizers[0] + + @property + def optimizer_weights(self): + """ + Returns the optimizer associated with the mask model (if applicable). + + :return: The optimizer for the mask model. + :rtype: torch.optim.Optimizer + """ + return self.optimizers[1] \ No newline at end of file diff --git a/pina/solvers/rom.py b/pina/solvers/rom.py new file mode 100644 index 00000000..733d76f4 --- /dev/null +++ b/pina/solvers/rom.py @@ -0,0 +1,190 @@ +""" Module for ReducedOrderModelSolver """ + +import torch + +from pina.solvers import SupervisedSolver + +class ReducedOrderModelSolver(SupervisedSolver): + r""" + ReducedOrderModelSolver solver class. This class implements a + Reduced Order Model solver, using user specified ``reduction_network`` and + ``interpolation_network`` to solve a specific ``problem``. + + The Reduced Order Model approach aims to find + the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` + of the differential problem: + + .. math:: + + \begin{cases} + \mathcal{A}[\mathbf{u}(\mu)](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ + \mathcal{B}[\mathbf{u}(\mu)](\mathbf{x})=0\quad, + \mathbf{x}\in\partial\Omega + \end{cases} + + This is done by using two neural networks. The ``reduction_network``, which + contains an encoder :math:`\mathcal{E}_{\rm{net}}`, a decoder + :math:`\mathcal{D}_{\rm{net}}`; and an ``interpolation_network`` + :math:`\mathcal{I}_{\rm{net}}`. The input is assumed to be discretised in + the spatial dimensions. + + The following loss function is minimized during training + + .. math:: + \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)] - + \mathcal{I}_{\rm{net}}[\mu_i]) + + \mathcal{L}( + \mathcal{D}_{\rm{net}}[\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)]] - + \mathbf{u}(\mu_i)) + + where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + + .. seealso:: + + **Original reference**: Hesthaven, Jan S., and Stefano Ubbiali. + "Non-intrusive reduced order modeling of nonlinear problems + using neural networks." Journal of Computational + Physics 363 (2018): 55-78. + DOI `10.1016/j.jcp.2018.02.037 + `_. + + .. note:: + The specified ``reduction_network`` must contain two methods, + namely ``encode`` for input encoding and ``decode`` for decoding the + former result. The ``interpolation_network`` network ``forward`` output + represents the interpolation of the latent space obtain with + ``reduction_network.encode``. + + .. note:: + This solver uses the end-to-end training strategy, i.e. the + ``reduction_network`` and ``interpolation_network`` are trained + simultaneously. For reference on this trainig strategy look at: + Pichi, Federico, Beatriz Moya, and Jan S. Hesthaven. + "A graph convolutional autoencoder approach to model order reduction + for parametrized PDEs." Journal of + Computational Physics 501 (2024): 112762. + DOI + `10.1016/j.jcp.2024.112762 `_. + + .. warning:: + This solver works only for data-driven model. Hence in the ``problem`` + definition the codition must only contain ``input_points`` + (e.g. coefficient parameters, time parameters), and ``output_points``. + + .. warning:: + This solver does not currently support the possibility to pass + ``extra_feature``. + """ + + def __init__( + self, + problem, + reduction_network, + interpolation_network, + loss=torch.nn.MSELoss(), + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": 0.001}, + scheduler=torch.optim.lr_scheduler.ConstantLR, + scheduler_kwargs={"factor": 1, "total_iters": 0}, + ): + """ + :param AbstractProblem problem: The formualation of the problem. + :param torch.nn.Module reduction_network: The reduction network used + for reducing the input space. It must contain two methods, + namely ``encode`` for input encoding and ``decode`` for decoding the + former result. + :param torch.nn.Module interpolation_network: The interpolation network + for interpolating the control parameters to latent space obtain by + the ``reduction_network`` encoding. + :param torch.nn.Module loss: The loss function used as minimizer, + default :class:`torch.nn.MSELoss`. + :param torch.nn.Module extra_features: The additional input + features to use as augmented input. + :param torch.optim.Optimizer optimizer: The neural network optimizer to + use; default is :class:`torch.optim.Adam`. + :param dict optimizer_kwargs: Optimizer constructor keyword args. + :param float lr: The learning rate; default is 0.001. + :param torch.optim.LRScheduler scheduler: Learning + rate scheduler. + :param dict scheduler_kwargs: LR scheduler constructor keyword args. + """ + model = torch.nn.ModuleDict({ + 'reduction_network' : reduction_network, + 'interpolation_network' : interpolation_network}) + + super().__init__( + model=model, + problem=problem, + loss=loss, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs + ) + + # assert reduction object contains encode/ decode + if not hasattr(self.neural_net['reduction_network'], 'encode'): + raise SyntaxError('reduction_network must have encode method. ' + 'The encode method should return a lower ' + 'dimensional representation of the input.') + if not hasattr(self.neural_net['reduction_network'], 'decode'): + raise SyntaxError('reduction_network must have decode method. ' + 'The decode method should return a high ' + 'dimensional representation of the encoding.') + + def forward(self, x): + """ + Forward pass implementation for the solver. It finds the encoder + representation by calling ``interpolation_network.forward`` on the + input, and maps this representation to output space by calling + ``reduction_network.decode``. + + :param torch.Tensor x: Input tensor. + :return: Solver solution. + :rtype: torch.Tensor + """ + reduction_network = self.neural_net['reduction_network'] + interpolation_network = self.neural_net['interpolation_network'] + return reduction_network.decode(interpolation_network(x)) + + def loss_data(self, input_pts, output_pts): + """ + The data loss for the ReducedOrderModelSolver solver. + It computes the loss between + the network output against the true solution. This function + should not be override if not intentionally. + + :param LabelTensor input_tensor: The input to the neural networks. + :param LabelTensor output_tensor: The true solution to compare the + network solution. + :return: The residual loss averaged on the input coordinates + :rtype: torch.Tensor + """ + # extract networks + reduction_network = self.neural_net['reduction_network'] + interpolation_network = self.neural_net['interpolation_network'] + # encoded representations loss + encode_repr_inter_net = interpolation_network(input_pts) + encode_repr_reduction_network = reduction_network.encode(output_pts) + loss_encode = self.loss(encode_repr_inter_net, + encode_repr_reduction_network) + # reconstruction loss + loss_reconstruction = self.loss( + reduction_network.decode(encode_repr_reduction_network), + output_pts) + + return loss_encode + loss_reconstruction + + @property + def neural_net(self): + """ + Neural network for training. It returns a :obj:`~torch.nn.ModuleDict` + containing the ``reduction_network`` and ``interpolation_network``. + """ + return self._neural_net.torchmodel diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index 324a023d..729a9d48 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -6,6 +6,7 @@ from ..utils import check_consistency from ..problem import AbstractProblem import torch +import sys class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): @@ -141,6 +142,20 @@ def problem(self): """ The problem formulation.""" return self._pina_problem + + def on_train_start(self): + """ + On training epoch start this function is call to do global checks for + the different solvers. + """ + + # 1. Check the verison for dataloader + dataloader = self.trainer.train_dataloader + if sys.version_info < (3, 8): + dataloader = dataloader.loaders + self._dataloader = dataloader + + return super().on_train_start() # @model.setter # def model(self, new_model): diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index c6a8a35b..28a634b0 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -1,7 +1,6 @@ """ Module for SupervisedSolver """ import torch -import sys try: from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 @@ -20,9 +19,32 @@ class SupervisedSolver(SolverInterface): - """ + r""" SupervisedSolver solver class. This class implements a SupervisedSolver, using a user specified ``model`` to solve a specific ``problem``. + + The Supervised Solver class aims to find + a map between the input :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` + and the output :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input + can be discretised in space (as in :obj:`~pina.solvers.rom.ROMe2eSolver`), + or not (e.g. when training Neural Operators). + + Given a model :math:`\mathcal{M}`, the following loss function is + minimized during training: + + .. math:: + \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)) + + where :math:`\mathcal{L}` is a specific loss function, + default Mean Square Error: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + In this context :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` means that + we are seeking to approximate multiple (discretised) functions given + multiple (discretised) input functions. """ def __init__( @@ -96,18 +118,12 @@ def training_step(self, batch, batch_idx): :return: The sum of the loss functions. :rtype: LabelTensor """ - - dataloader = self.trainer.train_dataloader + condition_idx = batch["condition"] for condition_id in range(condition_idx.min(), condition_idx.max() + 1): - if sys.version_info >= (3, 8): - condition_name = dataloader.condition_names[condition_id] - else: - condition_name = dataloader.loaders.condition_names[ - condition_id - ] + condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] pts = batch["pts"] out = batch["output"] @@ -118,14 +134,14 @@ def training_step(self, batch, batch_idx): # for data driven mode if not hasattr(condition, "output_points"): raise NotImplementedError( - "Supervised solver works only in data-driven mode." + f"{type(self).__name__} works only in data-driven mode." ) output_pts = out[condition_idx == condition_id] input_pts = pts[condition_idx == condition_id] loss = ( - self.loss(self.forward(input_pts), output_pts) + self.loss_data(input_pts=input_pts, output_pts=output_pts) * condition.data_weight ) loss = loss.as_subclass(torch.Tensor) @@ -133,6 +149,20 @@ def training_step(self, batch, batch_idx): self.log("mean_loss", float(loss), prog_bar=True, logger=True) return loss + def loss_data(self, input_pts, output_pts): + """ + The data loss for the Supervised solver. It computes the loss between + the network output against the true solution. This function + should not be override if not intentionally. + + :param LabelTensor input_tensor: The input to the neural networks. + :param LabelTensor output_tensor: The true solution to compare the + network solution. + :return: The residual loss averaged on the input coordinates + :rtype: torch.Tensor + """ + return self.loss(self.forward(input_pts), output_pts) + @property def scheduler(self): """ diff --git a/pina/trainer.py b/pina/trainer.py index 0acecaaa..90779a6e 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,5 +1,6 @@ """ Trainer module. """ +import torch import pytorch_lightning from .utils import check_consistency from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset @@ -63,6 +64,12 @@ def _create_or_update_loader(self): self._loader = SamplePointLoader( dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True ) + pb = self._model.problem + if hasattr(pb, "unknown_parameters"): + for key in pb.unknown_parameters: + pb.unknown_parameters[key] = torch.nn.Parameter(pb.unknown_parameters[key].data.to(device)) + + def train(self, **kwargs): """ diff --git a/tests/test_solvers/test_causalpinn.py b/tests/test_solvers/test_causalpinn.py new file mode 100644 index 00000000..58518ae4 --- /dev/null +++ b/tests/test_solvers/test_causalpinn.py @@ -0,0 +1,266 @@ +import torch +import pytest + +from pina.problem import TimeDependentProblem, InverseProblem, SpatialProblem +from pina.operators import grad +from pina.geometry import CartesianDomain +from pina import Condition, LabelTensor +from pina.solvers import CausalPINN +from pina.trainer import Trainer +from pina.model import FeedForward +from pina.equation.equation import Equation +from pina.equation.equation_factory import FixedValue +from pina.loss import LpLoss + + + +class FooProblem(SpatialProblem): + ''' + Foo problem formulation. + ''' + output_variables = ['u'] + conditions = {} + spatial_domain = None + + +class InverseDiffusionReactionSystem(TimeDependentProblem, SpatialProblem, InverseProblem): + + def diffusionreaction(input_, output_, params_): + x = input_.extract('x') + t = input_.extract('t') + u_t = grad(output_, input_, d='t') + u_x = grad(output_, input_, d='x') + u_xx = grad(u_x, input_, d='x') + r = torch.exp(-t) * (1.5 * torch.sin(2*x) + (8/3)*torch.sin(3*x) + + (15/4)*torch.sin(4*x) + (63/8)*torch.sin(8*x)) + return u_t - params_['mu']*u_xx - r + + def _solution(self, pts): + t = pts.extract('t') + x = pts.extract('x') + return torch.exp(-t) * (torch.sin(x) + (1/2)*torch.sin(2*x) + + (1/3)*torch.sin(3*x) + (1/4)*torch.sin(4*x) + + (1/8)*torch.sin(8*x)) + + # assign output/ spatial and temporal variables + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [-torch.pi, torch.pi]}) + temporal_domain = CartesianDomain({'t': [0, 1]}) + unknown_parameter_domain = CartesianDomain({'mu': [-1, 1]}) + + # problem condition statement + conditions = { + 'D': Condition(location=CartesianDomain({'x': [-torch.pi, torch.pi], + 't': [0, 1]}), + equation=Equation(diffusionreaction)), + 'data' : Condition(input_points=LabelTensor(torch.tensor([[0., 0.]]), ['x', 't']), + output_points=LabelTensor(torch.tensor([[0.]]), ['u'])), + } + +class DiffusionReactionSystem(TimeDependentProblem, SpatialProblem): + + def diffusionreaction(input_, output_): + x = input_.extract('x') + t = input_.extract('t') + u_t = grad(output_, input_, d='t') + u_x = grad(output_, input_, d='x') + u_xx = grad(u_x, input_, d='x') + r = torch.exp(-t) * (1.5 * torch.sin(2*x) + (8/3)*torch.sin(3*x) + + (15/4)*torch.sin(4*x) + (63/8)*torch.sin(8*x)) + return u_t - u_xx - r + + def _solution(self, pts): + t = pts.extract('t') + x = pts.extract('x') + return torch.exp(-t) * (torch.sin(x) + (1/2)*torch.sin(2*x) + + (1/3)*torch.sin(3*x) + (1/4)*torch.sin(4*x) + + (1/8)*torch.sin(8*x)) + + # assign output/ spatial and temporal variables + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [-torch.pi, torch.pi]}) + temporal_domain = CartesianDomain({'t': [0, 1]}) + + # problem condition statement + conditions = { + 'D': Condition(location=CartesianDomain({'x': [-torch.pi, torch.pi], + 't': [0, 1]}), + equation=Equation(diffusionreaction)), + } + +class myFeature(torch.nn.Module): + """ + Feature: sin(x) + """ + + def __init__(self): + super(myFeature, self).__init__() + + def forward(self, x): + t = (torch.sin(x.extract(['x']) * torch.pi)) + return LabelTensor(t, ['sin(x)']) + + +# make the problem +problem = DiffusionReactionSystem() +model = FeedForward(len(problem.input_variables), + len(problem.output_variables)) +model_extra_feats = FeedForward( + len(problem.input_variables) + 1, + len(problem.output_variables)) +extra_feats = [myFeature()] + + +def test_constructor(): + CausalPINN(problem=problem, model=model, extra_features=None) + + with pytest.raises(ValueError): + CausalPINN(FooProblem(), model=model, extra_features=None) + + +def test_constructor_extra_feats(): + model_extra_feats = FeedForward( + len(problem.input_variables) + 1, + len(problem.output_variables)) + CausalPINN(problem=problem, + model=model_extra_feats, + extra_features=extra_feats) + + +def test_train_cpu(): + problem = DiffusionReactionSystem() + boundaries = ['D'] + n = 10 + problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = CausalPINN(problem = problem, + model=model, extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +def test_train_restore(): + tmpdir = "tests/tmp_restore" + problem = DiffusionReactionSystem() + boundaries = ['D'] + n = 10 + problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = CausalPINN(problem=problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=5, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu') + t = ntrainer.train( + ckpt_path=f'{tmpdir}/lightning_logs/version_0/' + 'checkpoints/epoch=4-step=5.ckpt') + import shutil + shutil.rmtree(tmpdir) + + +def test_train_load(): + tmpdir = "tests/tmp_load" + problem = DiffusionReactionSystem() + boundaries = ['D'] + n = 10 + problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = CausalPINN(problem=problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = CausalPINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt', + problem = problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 't': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +def test_train_inverse_problem_cpu(): + problem = InverseDiffusionReactionSystem() + boundaries = ['D'] + n = 100 + problem.discretise_domain(n, 'random', locations=boundaries) + pinn = CausalPINN(problem = problem, + model=model, extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +# # TODO does not currently work +# def test_train_inverse_problem_restore(): +# tmpdir = "tests/tmp_restore_inv" +# problem = InverseDiffusionReactionSystem() +# boundaries = ['D'] +# n = 100 +# problem.discretise_domain(n, 'random', locations=boundaries) +# pinn = CausalPINN(problem=problem, +# model=model, +# extra_features=None, +# loss=LpLoss()) +# trainer = Trainer(solver=pinn, +# max_epochs=5, +# accelerator='cpu', +# default_root_dir=tmpdir) +# trainer.train() +# ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') +# t = ntrainer.train( +# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt') +# import shutil +# shutil.rmtree(tmpdir) + + +def test_train_inverse_problem_load(): + tmpdir = "tests/tmp_load_inv" + problem = InverseDiffusionReactionSystem() + boundaries = ['D'] + n = 100 + problem.discretise_domain(n, 'random', locations=boundaries) + pinn = CausalPINN(problem=problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = CausalPINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 't': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + + +def test_train_extra_feats_cpu(): + problem = DiffusionReactionSystem() + boundaries = ['D'] + n = 10 + problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = CausalPINN(problem=problem, + model=model_extra_feats, + extra_features=extra_feats) + trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') + trainer.train() \ No newline at end of file diff --git a/tests/test_solvers/test_competitive_pinn.py b/tests/test_solvers/test_competitive_pinn.py new file mode 100644 index 00000000..9c6f6c8e --- /dev/null +++ b/tests/test_solvers/test_competitive_pinn.py @@ -0,0 +1,418 @@ +import torch +import pytest + +from pina.problem import SpatialProblem, InverseProblem +from pina.operators import laplacian +from pina.geometry import CartesianDomain +from pina import Condition, LabelTensor +from pina.solvers import CompetitivePINN as PINN +from pina.trainer import Trainer +from pina.model import FeedForward +from pina.equation.equation import Equation +from pina.equation.equation_factory import FixedValue +from pina.loss import LpLoss + + +def laplace_equation(input_, output_): + force_term = (torch.sin(input_.extract(['x']) * torch.pi) * + torch.sin(input_.extract(['y']) * torch.pi)) + delta_u = laplacian(output_.extract(['u']), input_) + return delta_u - force_term + + +my_laplace = Equation(laplace_equation) +in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y']) +out_ = LabelTensor(torch.tensor([[0.]]), ['u']) +in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y']) +out2_ = LabelTensor(torch.rand(60, 1), ['u']) + + +class InversePoisson(SpatialProblem, InverseProblem): + ''' + Problem definition for the Poisson equation. + ''' + output_variables = ['u'] + x_min = -2 + x_max = 2 + y_min = -2 + y_max = 2 + data_input = LabelTensor(torch.rand(10, 2), ['x', 'y']) + data_output = LabelTensor(torch.rand(10, 1), ['u']) + spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]}) + # define the ranges for the parameters + unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]}) + + def laplace_equation(input_, output_, params_): + ''' + Laplace equation with a force term. + ''' + force_term = torch.exp( + - 2*(input_.extract(['x']) - params_['mu1'])**2 + - 2*(input_.extract(['y']) - params_['mu2'])**2) + delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y']) + + return delta_u - force_term + + # define the conditions for the loss (boundary conditions, equation, data) + conditions = { + 'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max], + 'y': y_max}), + equation=FixedValue(0.0, components=['u'])), + 'gamma2': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': y_min + }), + equation=FixedValue(0.0, components=['u'])), + 'gamma3': Condition(location=CartesianDomain( + {'x': x_max, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'gamma4': Condition(location=CartesianDomain( + {'x': x_min, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'D': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': [y_min, y_max] + }), + equation=Equation(laplace_equation)), + 'data': Condition(input_points=data_input.extract(['x', 'y']), + output_points=data_output) + } + + +class Poisson(SpatialProblem): + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + + conditions = { + 'gamma1': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 1}), + equation=FixedValue(0.0)), + 'gamma2': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 0}), + equation=FixedValue(0.0)), + 'gamma3': Condition( + location=CartesianDomain({'x': 1, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'gamma4': Condition( + location=CartesianDomain({'x': 0, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'D': Condition( + input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), + equation=my_laplace), + 'data': Condition( + input_points=in_, + output_points=out_), + 'data2': Condition( + input_points=in2_, + output_points=out2_) + } + + def poisson_sol(self, pts): + return -(torch.sin(pts.extract(['x']) * torch.pi) * + torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2) + + truth_solution = poisson_sol + + +class myFeature(torch.nn.Module): + """ + Feature: sin(x) + """ + + def __init__(self): + super(myFeature, self).__init__() + + def forward(self, x): + t = (torch.sin(x.extract(['x']) * torch.pi) * + torch.sin(x.extract(['y']) * torch.pi)) + return LabelTensor(t, ['sin(x)sin(y)']) + + +# make the problem +poisson_problem = Poisson() +model = FeedForward(len(poisson_problem.input_variables), + len(poisson_problem.output_variables)) +model_extra_feats = FeedForward( + len(poisson_problem.input_variables) + 1, + len(poisson_problem.output_variables)) +extra_feats = [myFeature()] + + +def test_constructor(): + PINN(problem=poisson_problem, model=model) + PINN(problem=poisson_problem, model=model, discriminator = model) + + +def test_constructor_extra_feats(): + with pytest.raises(TypeError): + model_extra_feats = FeedForward( + len(poisson_problem.input_variables) + 1, + len(poisson_problem.output_variables)) + PINN(problem=poisson_problem, + model=model_extra_feats, + extra_features=extra_feats) + + +def test_train_cpu(): + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem = poisson_problem, model=model, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +def test_train_restore(): + tmpdir = "tests/tmp_restore" + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=5, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu') + t = ntrainer.train( + ckpt_path=f'{tmpdir}/lightning_logs/version_0/' + 'checkpoints/epoch=4-step=10.ckpt') + import shutil + shutil.rmtree(tmpdir) + + +def test_train_load(): + tmpdir = "tests/tmp_load" + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = PINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +def test_train_inverse_problem_cpu(): + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = PINN(problem = poisson_problem, model=model, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +# # TODO does not currently work +# def test_train_inverse_problem_restore(): +# tmpdir = "tests/tmp_restore_inv" +# poisson_problem = InversePoisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] +# n = 100 +# poisson_problem.discretise_domain(n, 'random', locations=boundaries) +# pinn = PINN(problem=poisson_problem, +# model=model, +# loss=LpLoss()) +# trainer = Trainer(solver=pinn, +# max_epochs=5, +# accelerator='cpu', +# default_root_dir=tmpdir) +# trainer.train() +# ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') +# t = ntrainer.train( +# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt') +# import shutil +# shutil.rmtree(tmpdir) + + +def test_train_inverse_problem_load(): + tmpdir = "tests/tmp_load_inv" + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = PINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +# # TODO fix asap. Basically sampling few variables +# # works only if both variables are in a range. +# # if one is fixed and the other not, this will +# # not work. This test also needs to be fixed and +# # insert in test problem not in test pinn. +# def test_train_cpu_sampling_few_vars(): +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['x']) +# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['y']) +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'}) +# trainer.train() + + +# TODO, fix GitHub actions to run also on GPU +# def test_train_gpu(): +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() + +# def test_train_gpu(): #TODO fix ASAP +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() + +# def test_train_2(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_extra_feats(): +# pinn = PINN(problem, model_extra_feat, [myFeature()]) +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) + + +# def test_train_2_extra_feats(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model_extra_feat, [myFeature()]) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_optimizer_kwargs(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3}) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_lr_scheduler(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN( +# problem, +# model, +# lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, +# lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} +# ) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# # def test_train_batch(): +# # pinn = PINN(problem, model, batch_size=6) +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + + +# # def test_train_batch_2(): +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # expected_keys = [[], list(range(0, 50, 3))] +# # param = [0, 3] +# # for i, truth_key in zip(param, expected_keys): +# # pinn = PINN(problem, model, batch_size=6) +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(50, save_loss=i) +# # assert list(pinn.history_loss.keys()) == truth_key + + +# if torch.cuda.is_available(): + +# # def test_gpu_train(): +# # pinn = PINN(problem, model, batch_size=20, device='cuda') +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 100 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + +# def test_gpu_train_nobatch(): +# pinn = PINN(problem, model, batch_size=None, device='cuda') +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 100 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) + diff --git a/tests/test_solvers/test_gpinn.py b/tests/test_solvers/test_gpinn.py new file mode 100644 index 00000000..d00d3b4d --- /dev/null +++ b/tests/test_solvers/test_gpinn.py @@ -0,0 +1,432 @@ +import torch + +from pina.problem import SpatialProblem, InverseProblem +from pina.operators import laplacian +from pina.geometry import CartesianDomain +from pina import Condition, LabelTensor +from pina.solvers import GPINN +from pina.trainer import Trainer +from pina.model import FeedForward +from pina.equation.equation import Equation +from pina.equation.equation_factory import FixedValue +from pina.loss import LpLoss + + +def laplace_equation(input_, output_): + force_term = (torch.sin(input_.extract(['x']) * torch.pi) * + torch.sin(input_.extract(['y']) * torch.pi)) + delta_u = laplacian(output_.extract(['u']), input_) + return delta_u - force_term + + +my_laplace = Equation(laplace_equation) +in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y']) +out_ = LabelTensor(torch.tensor([[0.]]), ['u']) +in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y']) +out2_ = LabelTensor(torch.rand(60, 1), ['u']) + + +class InversePoisson(SpatialProblem, InverseProblem): + ''' + Problem definition for the Poisson equation. + ''' + output_variables = ['u'] + x_min = -2 + x_max = 2 + y_min = -2 + y_max = 2 + data_input = LabelTensor(torch.rand(10, 2), ['x', 'y']) + data_output = LabelTensor(torch.rand(10, 1), ['u']) + spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]}) + # define the ranges for the parameters + unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]}) + + def laplace_equation(input_, output_, params_): + ''' + Laplace equation with a force term. + ''' + force_term = torch.exp( + - 2*(input_.extract(['x']) - params_['mu1'])**2 + - 2*(input_.extract(['y']) - params_['mu2'])**2) + delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y']) + + return delta_u - force_term + + # define the conditions for the loss (boundary conditions, equation, data) + conditions = { + 'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max], + 'y': y_max}), + equation=FixedValue(0.0, components=['u'])), + 'gamma2': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': y_min}), + equation=FixedValue(0.0, components=['u'])), + 'gamma3': Condition(location=CartesianDomain( + {'x': x_max, 'y': [y_min, y_max]}), + equation=FixedValue(0.0, components=['u'])), + 'gamma4': Condition(location=CartesianDomain( + {'x': x_min, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'D': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': [y_min, y_max] + }), + equation=Equation(laplace_equation)), + 'data': Condition( + input_points=data_input.extract(['x', 'y']), + output_points=data_output) + } + + +class Poisson(SpatialProblem): + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + + conditions = { + 'gamma1': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 1}), + equation=FixedValue(0.0)), + 'gamma2': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 0}), + equation=FixedValue(0.0)), + 'gamma3': Condition( + location=CartesianDomain({'x': 1, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'gamma4': Condition( + location=CartesianDomain({'x': 0, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'D': Condition( + input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), + equation=my_laplace), + 'data': Condition( + input_points=in_, + output_points=out_), + 'data2': Condition( + input_points=in2_, + output_points=out2_) + } + + def poisson_sol(self, pts): + return -(torch.sin(pts.extract(['x']) * torch.pi) * + torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2) + + truth_solution = poisson_sol + + +class myFeature(torch.nn.Module): + """ + Feature: sin(x) + """ + + def __init__(self): + super(myFeature, self).__init__() + + def forward(self, x): + t = (torch.sin(x.extract(['x']) * torch.pi) * + torch.sin(x.extract(['y']) * torch.pi)) + return LabelTensor(t, ['sin(x)sin(y)']) + + +# make the problem +poisson_problem = Poisson() +model = FeedForward(len(poisson_problem.input_variables), + len(poisson_problem.output_variables)) +model_extra_feats = FeedForward( + len(poisson_problem.input_variables) + 1, + len(poisson_problem.output_variables)) +extra_feats = [myFeature()] + + +def test_constructor(): + GPINN(problem=poisson_problem, model=model, extra_features=None) + + +def test_constructor_extra_feats(): + model_extra_feats = FeedForward( + len(poisson_problem.input_variables) + 1, + len(poisson_problem.output_variables)) + GPINN(problem=poisson_problem, + model=model_extra_feats, + extra_features=extra_feats) + + +def test_train_cpu(): + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = GPINN(problem = poisson_problem, + model=model, extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +def test_train_restore(): + tmpdir = "tests/tmp_restore" + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = GPINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=5, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu') + t = ntrainer.train( + ckpt_path=f'{tmpdir}/lightning_logs/version_0/' + 'checkpoints/epoch=4-step=10.ckpt') + import shutil + shutil.rmtree(tmpdir) + + +def test_train_load(): + tmpdir = "tests/tmp_load" + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = GPINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = GPINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +def test_train_inverse_problem_cpu(): + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = GPINN(problem = poisson_problem, + model=model, extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +# # TODO does not currently work +# def test_train_inverse_problem_restore(): +# tmpdir = "tests/tmp_restore_inv" +# poisson_problem = InversePoisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] +# n = 100 +# poisson_problem.discretise_domain(n, 'random', locations=boundaries) +# pinn = GPINN(problem=poisson_problem, +# model=model, +# extra_features=None, +# loss=LpLoss()) +# trainer = Trainer(solver=pinn, +# max_epochs=5, +# accelerator='cpu', +# default_root_dir=tmpdir) +# trainer.train() +# ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') +# t = ntrainer.train( +# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt') +# import shutil +# shutil.rmtree(tmpdir) + + +def test_train_inverse_problem_load(): + tmpdir = "tests/tmp_load_inv" + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = GPINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = GPINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +# # TODO fix asap. Basically sampling few variables +# # works only if both variables are in a range. +# # if one is fixed and the other not, this will +# # not work. This test also needs to be fixed and +# # insert in test problem not in test pinn. +# def test_train_cpu_sampling_few_vars(): +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['x']) +# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['y']) +# pinn = GPINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'}) +# trainer.train() + + +def test_train_extra_feats_cpu(): + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = GPINN(problem=poisson_problem, + model=model_extra_feats, + extra_features=extra_feats) + trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') + trainer.train() + + +# TODO, fix GitHub actions to run also on GPU +# def test_train_gpu(): +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# pinn = GPINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() + +# def test_train_gpu(): #TODO fix ASAP +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu +# pinn = GPINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() + +# def test_train_2(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = GPINN(problem, model) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_extra_feats(): +# pinn = GPINN(problem, model_extra_feat, [myFeature()]) +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) + + +# def test_train_2_extra_feats(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = GPINN(problem, model_extra_feat, [myFeature()]) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_optimizer_kwargs(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = GPINN(problem, model, optimizer_kwargs={'lr' : 0.3}) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_lr_scheduler(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = GPINN( +# problem, +# model, +# lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, +# lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} +# ) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# # def test_train_batch(): +# # pinn = GPINN(problem, model, batch_size=6) +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + + +# # def test_train_batch_2(): +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # expected_keys = [[], list(range(0, 50, 3))] +# # param = [0, 3] +# # for i, truth_key in zip(param, expected_keys): +# # pinn = GPINN(problem, model, batch_size=6) +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(50, save_loss=i) +# # assert list(pinn.history_loss.keys()) == truth_key + + +# if torch.cuda.is_available(): + +# # def test_gpu_train(): +# # pinn = GPINN(problem, model, batch_size=20, device='cuda') +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 100 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + +# def test_gpu_train_nobatch(): +# pinn = GPINN(problem, model, batch_size=None, device='cuda') +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 100 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) + diff --git a/tests/test_solvers/test_pinn.py b/tests/test_solvers/test_pinn.py index 0a42410c..ea3b077b 100644 --- a/tests/test_solvers/test_pinn.py +++ b/tests/test_solvers/test_pinn.py @@ -1,6 +1,6 @@ import torch -from pina.problem import SpatialProblem +from pina.problem import SpatialProblem, InverseProblem from pina.operators import laplacian from pina.geometry import CartesianDomain from pina import Condition, LabelTensor @@ -26,6 +26,58 @@ def laplace_equation(input_, output_): out2_ = LabelTensor(torch.rand(60, 1), ['u']) +class InversePoisson(SpatialProblem, InverseProblem): + ''' + Problem definition for the Poisson equation. + ''' + output_variables = ['u'] + x_min = -2 + x_max = 2 + y_min = -2 + y_max = 2 + data_input = LabelTensor(torch.rand(10, 2), ['x', 'y']) + data_output = LabelTensor(torch.rand(10, 1), ['u']) + spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]}) + # define the ranges for the parameters + unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]}) + + def laplace_equation(input_, output_, params_): + ''' + Laplace equation with a force term. + ''' + force_term = torch.exp( + - 2*(input_.extract(['x']) - params_['mu1'])**2 + - 2*(input_.extract(['y']) - params_['mu2'])**2) + delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y']) + + return delta_u - force_term + + # define the conditions for the loss (boundary conditions, equation, data) + conditions = { + 'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max], + 'y': y_max}), + equation=FixedValue(0.0, components=['u'])), + 'gamma2': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': y_min + }), + equation=FixedValue(0.0, components=['u'])), + 'gamma3': Condition(location=CartesianDomain( + {'x': x_max, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'gamma4': Condition(location=CartesianDomain( + {'x': x_min, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'D': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': [y_min, y_max] + }), + equation=Equation(laplace_equation)), + 'data': Condition(input_points=data_input.extract(['x', 'y']), + output_points=data_output) + } + + class Poisson(SpatialProblem): output_variables = ['u'] spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) @@ -103,8 +155,10 @@ def test_train_cpu(): boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] n = 10 poisson_problem.discretise_domain(n, 'grid', locations=boundaries) - pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) - trainer = Trainer(solver=pinn, max_epochs=1, accelerator='cpu', batch_size=20) + pinn = PINN(problem = poisson_problem, model=model, + extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) trainer.train() @@ -125,7 +179,8 @@ def test_train_restore(): trainer.train() ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu') t = ntrainer.train( - ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt') + ckpt_path=f'{tmpdir}/lightning_logs/version_0/' + 'checkpoints/epoch=4-step=10.ckpt') import shutil shutil.rmtree(tmpdir) @@ -158,6 +213,68 @@ def test_train_load(): import shutil shutil.rmtree(tmpdir) +def test_train_inverse_problem_cpu(): + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = PINN(problem = poisson_problem, model=model, + extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +# # TODO does not currently work +# def test_train_inverse_problem_restore(): +# tmpdir = "tests/tmp_restore_inv" +# poisson_problem = InversePoisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] +# n = 100 +# poisson_problem.discretise_domain(n, 'random', locations=boundaries) +# pinn = PINN(problem=poisson_problem, +# model=model, +# extra_features=None, +# loss=LpLoss()) +# trainer = Trainer(solver=pinn, +# max_epochs=5, +# accelerator='cpu', +# default_root_dir=tmpdir) +# trainer.train() +# ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') +# t = ntrainer.train( +# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt') +# import shutil +# shutil.rmtree(tmpdir) + + +def test_train_inverse_problem_load(): + tmpdir = "tests/tmp_load_inv" + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = PINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) # # TODO fix asap. Basically sampling few variables # # works only if both variables are in a range. @@ -197,85 +314,32 @@ def test_train_extra_feats_cpu(): # pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) # trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) # trainer.train() -""" -def test_train_gpu(): #TODO fix ASAP - poisson_problem = Poisson() - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - n = 10 - poisson_problem.discretise_domain(n, 'grid', locations=boundaries) - poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu - pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) - trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) - trainer.train() - -def test_train_2(): - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - n = 10 - expected_keys = [[], list(range(0, 50, 3))] - param = [0, 3] - for i, truth_key in zip(param, expected_keys): - pinn = PINN(problem, model) - pinn.discretise_domain(n, 'grid', locations=boundaries) - pinn.discretise_domain(n, 'grid', locations=['D']) - pinn.train(50, save_loss=i) - assert list(pinn.history_loss.keys()) == truth_key - - -def test_train_extra_feats(): - pinn = PINN(problem, model_extra_feat, [myFeature()]) - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - n = 10 - pinn.discretise_domain(n, 'grid', locations=boundaries) - pinn.discretise_domain(n, 'grid', locations=['D']) - pinn.train(5) - - -def test_train_2_extra_feats(): - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - n = 10 - expected_keys = [[], list(range(0, 50, 3))] - param = [0, 3] - for i, truth_key in zip(param, expected_keys): - pinn = PINN(problem, model_extra_feat, [myFeature()]) - pinn.discretise_domain(n, 'grid', locations=boundaries) - pinn.discretise_domain(n, 'grid', locations=['D']) - pinn.train(50, save_loss=i) - assert list(pinn.history_loss.keys()) == truth_key +# def test_train_gpu(): #TODO fix ASAP +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() -def test_train_with_optimizer_kwargs(): - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - n = 10 - expected_keys = [[], list(range(0, 50, 3))] - param = [0, 3] - for i, truth_key in zip(param, expected_keys): - pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3}) - pinn.discretise_domain(n, 'grid', locations=boundaries) - pinn.discretise_domain(n, 'grid', locations=['D']) - pinn.train(50, save_loss=i) - assert list(pinn.history_loss.keys()) == truth_key +# def test_train_2(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key -def test_train_with_lr_scheduler(): - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - n = 10 - expected_keys = [[], list(range(0, 50, 3))] - param = [0, 3] - for i, truth_key in zip(param, expected_keys): - pinn = PINN( - problem, - model, - lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, - lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} - ) - pinn.discretise_domain(n, 'grid', locations=boundaries) - pinn.discretise_domain(n, 'grid', locations=['D']) - pinn.train(50, save_loss=i) - assert list(pinn.history_loss.keys()) == truth_key - - -# def test_train_batch(): -# pinn = PINN(problem, model, batch_size=6) +# def test_train_extra_feats(): +# pinn = PINN(problem, model_extra_feat, [myFeature()]) # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] # n = 10 # pinn.discretise_domain(n, 'grid', locations=boundaries) @@ -283,34 +347,87 @@ def test_train_with_lr_scheduler(): # pinn.train(5) -# def test_train_batch_2(): +# def test_train_2_extra_feats(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model_extra_feat, [myFeature()]) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_optimizer_kwargs(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3}) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_lr_scheduler(): # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] # n = 10 # expected_keys = [[], list(range(0, 50, 3))] # param = [0, 3] # for i, truth_key in zip(param, expected_keys): -# pinn = PINN(problem, model, batch_size=6) +# pinn = PINN( +# problem, +# model, +# lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, +# lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} +# ) # pinn.discretise_domain(n, 'grid', locations=boundaries) # pinn.discretise_domain(n, 'grid', locations=['D']) # pinn.train(50, save_loss=i) # assert list(pinn.history_loss.keys()) == truth_key -if torch.cuda.is_available(): - - # def test_gpu_train(): - # pinn = PINN(problem, model, batch_size=20, device='cuda') - # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - # n = 100 - # pinn.discretise_domain(n, 'grid', locations=boundaries) - # pinn.discretise_domain(n, 'grid', locations=['D']) - # pinn.train(5) - - def test_gpu_train_nobatch(): - pinn = PINN(problem, model, batch_size=None, device='cuda') - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - n = 100 - pinn.discretise_domain(n, 'grid', locations=boundaries) - pinn.discretise_domain(n, 'grid', locations=['D']) - pinn.train(5) -""" +# # def test_train_batch(): +# # pinn = PINN(problem, model, batch_size=6) +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + + +# # def test_train_batch_2(): +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # expected_keys = [[], list(range(0, 50, 3))] +# # param = [0, 3] +# # for i, truth_key in zip(param, expected_keys): +# # pinn = PINN(problem, model, batch_size=6) +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(50, save_loss=i) +# # assert list(pinn.history_loss.keys()) == truth_key + + +# if torch.cuda.is_available(): + +# # def test_gpu_train(): +# # pinn = PINN(problem, model, batch_size=20, device='cuda') +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 100 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + +# def test_gpu_train_nobatch(): +# pinn = PINN(problem, model, batch_size=None, device='cuda') +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 100 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) + diff --git a/tests/test_solvers/test_rom_solver.py b/tests/test_solvers/test_rom_solver.py new file mode 100644 index 00000000..a16ffcaa --- /dev/null +++ b/tests/test_solvers/test_rom_solver.py @@ -0,0 +1,105 @@ +import torch +import pytest + +from pina.problem import AbstractProblem +from pina import Condition, LabelTensor +from pina.solvers import ReducedOrderModelSolver +from pina.trainer import Trainer +from pina.model import FeedForward +from pina.loss import LpLoss + + +class NeuralOperatorProblem(AbstractProblem): + input_variables = ['u_0', 'u_1'] + output_variables = [f'u_{i}' for i in range(100)] + conditions = {'data' : Condition(input_points= + LabelTensor(torch.rand(10, 2), + input_variables), + output_points= + LabelTensor(torch.rand(10, 100), + output_variables))} + + +# make the problem + extra feats +class AE(torch.nn.Module): + def __init__(self, input_dimensions, rank): + super().__init__() + self.encode = FeedForward(input_dimensions, rank, layers=[input_dimensions//4]) + self.decode = FeedForward(rank, input_dimensions, layers=[input_dimensions//4]) +class AE_missing_encode(torch.nn.Module): + def __init__(self, input_dimensions, rank): + super().__init__() + self.encode = FeedForward(input_dimensions, rank, layers=[input_dimensions//4]) +class AE_missing_decode(torch.nn.Module): + def __init__(self, input_dimensions, rank): + super().__init__() + self.decode = FeedForward(rank, input_dimensions, layers=[input_dimensions//4]) + +rank = 10 +problem = NeuralOperatorProblem() +interpolation_net = FeedForward(len(problem.input_variables), + rank) +reduction_net = AE(len(problem.output_variables), rank) + +def test_constructor(): + ReducedOrderModelSolver(problem=problem,reduction_network=reduction_net, + interpolation_network=interpolation_net) + with pytest.raises(SyntaxError): + ReducedOrderModelSolver(problem=problem, + reduction_network=AE_missing_encode( + len(problem.output_variables), rank), + interpolation_network=interpolation_net) + ReducedOrderModelSolver(problem=problem, + reduction_network=AE_missing_decode( + len(problem.output_variables), rank), + interpolation_network=interpolation_net) + + +def test_train_cpu(): + solver = ReducedOrderModelSolver(problem = problem,reduction_network=reduction_net, + interpolation_network=interpolation_net, loss=LpLoss()) + trainer = Trainer(solver=solver, max_epochs=3, accelerator='cpu', batch_size=20) + trainer.train() + + +def test_train_restore(): + tmpdir = "tests/tmp_restore" + solver = ReducedOrderModelSolver(problem=problem, + reduction_network=reduction_net, + interpolation_network=interpolation_net, + loss=LpLoss()) + trainer = Trainer(solver=solver, + max_epochs=5, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + ntrainer = Trainer(solver=solver, max_epochs=15, accelerator='cpu') + t = ntrainer.train( + ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt') + import shutil + shutil.rmtree(tmpdir) + + +def test_train_load(): + tmpdir = "tests/tmp_load" + solver = ReducedOrderModelSolver(problem=problem, + reduction_network=reduction_net, + interpolation_network=interpolation_net, + loss=LpLoss()) + trainer = Trainer(solver=solver, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_solver = ReducedOrderModelSolver.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt', + problem = problem,reduction_network=reduction_net, + interpolation_network=interpolation_net) + test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables) + assert new_solver.forward(test_pts).shape == (20, 100) + assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape + torch.testing.assert_close( + new_solver.forward(test_pts), + solver.forward(test_pts)) + import shutil + shutil.rmtree(tmpdir) \ No newline at end of file diff --git a/tests/test_solvers/test_sapinn.py b/tests/test_solvers/test_sapinn.py new file mode 100644 index 00000000..60c3094c --- /dev/null +++ b/tests/test_solvers/test_sapinn.py @@ -0,0 +1,437 @@ +import torch +import pytest + +from pina.problem import SpatialProblem, InverseProblem +from pina.operators import laplacian +from pina.geometry import CartesianDomain +from pina import Condition, LabelTensor +from pina.solvers import SAPINN as PINN +from pina.trainer import Trainer +from pina.model import FeedForward +from pina.equation.equation import Equation +from pina.equation.equation_factory import FixedValue +from pina.loss import LpLoss + + +def laplace_equation(input_, output_): + force_term = (torch.sin(input_.extract(['x']) * torch.pi) * + torch.sin(input_.extract(['y']) * torch.pi)) + delta_u = laplacian(output_.extract(['u']), input_) + return delta_u - force_term + + +my_laplace = Equation(laplace_equation) +in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y']) +out_ = LabelTensor(torch.tensor([[0.]]), ['u']) +in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y']) +out2_ = LabelTensor(torch.rand(60, 1), ['u']) + + +class InversePoisson(SpatialProblem, InverseProblem): + ''' + Problem definition for the Poisson equation. + ''' + output_variables = ['u'] + x_min = -2 + x_max = 2 + y_min = -2 + y_max = 2 + data_input = LabelTensor(torch.rand(10, 2), ['x', 'y']) + data_output = LabelTensor(torch.rand(10, 1), ['u']) + spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]}) + # define the ranges for the parameters + unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]}) + + def laplace_equation(input_, output_, params_): + ''' + Laplace equation with a force term. + ''' + force_term = torch.exp( + - 2*(input_.extract(['x']) - params_['mu1'])**2 + - 2*(input_.extract(['y']) - params_['mu2'])**2) + delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y']) + + return delta_u - force_term + + # define the conditions for the loss (boundary conditions, equation, data) + conditions = { + 'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max], + 'y': y_max}), + equation=FixedValue(0.0, components=['u'])), + 'gamma2': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': y_min + }), + equation=FixedValue(0.0, components=['u'])), + 'gamma3': Condition(location=CartesianDomain( + {'x': x_max, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'gamma4': Condition(location=CartesianDomain( + {'x': x_min, 'y': [y_min, y_max] + }), + equation=FixedValue(0.0, components=['u'])), + 'D': Condition(location=CartesianDomain( + {'x': [x_min, x_max], 'y': [y_min, y_max] + }), + equation=Equation(laplace_equation)), + 'data': Condition(input_points=data_input.extract(['x', 'y']), + output_points=data_output) + } + + +class Poisson(SpatialProblem): + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + + conditions = { + 'gamma1': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 1}), + equation=FixedValue(0.0)), + 'gamma2': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 0}), + equation=FixedValue(0.0)), + 'gamma3': Condition( + location=CartesianDomain({'x': 1, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'gamma4': Condition( + location=CartesianDomain({'x': 0, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'D': Condition( + input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), + equation=my_laplace), + 'data': Condition( + input_points=in_, + output_points=out_), + 'data2': Condition( + input_points=in2_, + output_points=out2_) + } + + def poisson_sol(self, pts): + return -(torch.sin(pts.extract(['x']) * torch.pi) * + torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2) + + truth_solution = poisson_sol + + +class myFeature(torch.nn.Module): + """ + Feature: sin(x) + """ + + def __init__(self): + super(myFeature, self).__init__() + + def forward(self, x): + t = (torch.sin(x.extract(['x']) * torch.pi) * + torch.sin(x.extract(['y']) * torch.pi)) + return LabelTensor(t, ['sin(x)sin(y)']) + + +# make the problem +poisson_problem = Poisson() +model = FeedForward(len(poisson_problem.input_variables), + len(poisson_problem.output_variables)) +model_extra_feats = FeedForward( + len(poisson_problem.input_variables) + 1, + len(poisson_problem.output_variables)) +extra_feats = [myFeature()] + + +def test_constructor(): + PINN(problem=poisson_problem, model=model, extra_features=None) + with pytest.raises(ValueError): + PINN(problem=poisson_problem, model=model, extra_features=None, + weights_function=1) + + +def test_constructor_extra_feats(): + model_extra_feats = FeedForward( + len(poisson_problem.input_variables) + 1, + len(poisson_problem.output_variables)) + PINN(problem=poisson_problem, + model=model_extra_feats, + extra_features=extra_feats) + + +def test_train_cpu(): + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem = poisson_problem, model=model, + extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +def test_train_restore(): + tmpdir = "tests/tmp_restore" + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=5, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu') + t = ntrainer.train( + ckpt_path=f'{tmpdir}/lightning_logs/version_0/' + 'checkpoints/epoch=4-step=10.ckpt') + import shutil + shutil.rmtree(tmpdir) + + +def test_train_load(): + tmpdir = "tests/tmp_load" + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = PINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +def test_train_inverse_problem_cpu(): + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = PINN(problem = poisson_problem, model=model, + extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, max_epochs=1, + accelerator='cpu', batch_size=20) + trainer.train() + + +# # TODO does not currently work +# def test_train_inverse_problem_restore(): +# tmpdir = "tests/tmp_restore_inv" +# poisson_problem = InversePoisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] +# n = 100 +# poisson_problem.discretise_domain(n, 'random', locations=boundaries) +# pinn = PINN(problem=poisson_problem, +# model=model, +# extra_features=None, +# loss=LpLoss()) +# trainer = Trainer(solver=pinn, +# max_epochs=5, +# accelerator='cpu', +# default_root_dir=tmpdir) +# trainer.train() +# ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') +# t = ntrainer.train( +# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt') +# import shutil +# shutil.rmtree(tmpdir) + + +def test_train_inverse_problem_load(): + tmpdir = "tests/tmp_load_inv" + poisson_problem = InversePoisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D'] + n = 100 + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model, + extra_features=None, + loss=LpLoss()) + trainer = Trainer(solver=pinn, + max_epochs=15, + accelerator='cpu', + default_root_dir=tmpdir) + trainer.train() + new_pinn = PINN.load_from_checkpoint( + f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt', + problem = poisson_problem, model=model) + test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10) + assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1) + assert new_pinn.forward(test_pts).extract( + ['u']).shape == pinn.forward(test_pts).extract(['u']).shape + torch.testing.assert_close( + new_pinn.forward(test_pts).extract(['u']), + pinn.forward(test_pts).extract(['u'])) + import shutil + shutil.rmtree(tmpdir) + +# # TODO fix asap. Basically sampling few variables +# # works only if both variables are in a range. +# # if one is fixed and the other not, this will +# # not work. This test also needs to be fixed and +# # insert in test problem not in test pinn. +# def test_train_cpu_sampling_few_vars(): +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['x']) +# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['y']) +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'}) +# trainer.train() + + +def test_train_extra_feats_cpu(): + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + pinn = PINN(problem=poisson_problem, + model=model_extra_feats, + extra_features=extra_feats) + trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu') + trainer.train() + + +# TODO, fix GitHub actions to run also on GPU +# def test_train_gpu(): +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() + +# def test_train_gpu(): #TODO fix ASAP +# poisson_problem = Poisson() +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) +# poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu +# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) +# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'}) +# trainer.train() + +# def test_train_2(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_extra_feats(): +# pinn = PINN(problem, model_extra_feat, [myFeature()]) +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) + + +# def test_train_2_extra_feats(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model_extra_feat, [myFeature()]) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_optimizer_kwargs(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3}) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# def test_train_with_lr_scheduler(): +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 10 +# expected_keys = [[], list(range(0, 50, 3))] +# param = [0, 3] +# for i, truth_key in zip(param, expected_keys): +# pinn = PINN( +# problem, +# model, +# lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, +# lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} +# ) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(50, save_loss=i) +# assert list(pinn.history_loss.keys()) == truth_key + + +# # def test_train_batch(): +# # pinn = PINN(problem, model, batch_size=6) +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + + +# # def test_train_batch_2(): +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 10 +# # expected_keys = [[], list(range(0, 50, 3))] +# # param = [0, 3] +# # for i, truth_key in zip(param, expected_keys): +# # pinn = PINN(problem, model, batch_size=6) +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(50, save_loss=i) +# # assert list(pinn.history_loss.keys()) == truth_key + + +# if torch.cuda.is_available(): + +# # def test_gpu_train(): +# # pinn = PINN(problem, model, batch_size=20, device='cuda') +# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# # n = 100 +# # pinn.discretise_domain(n, 'grid', locations=boundaries) +# # pinn.discretise_domain(n, 'grid', locations=['D']) +# # pinn.train(5) + +# def test_gpu_train_nobatch(): +# pinn = PINN(problem, model, batch_size=None, device='cuda') +# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] +# n = 100 +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) +# pinn.train(5) +