Skip to content

Commit

Permalink
PINN variants addition and Solvers Update (#263)
Browse files Browse the repository at this point in the history
* gpinn/basepinn new classes, pinn restructure
* codacy fix gpinn/basepinn/pinn
* inverse problem fix
* Causal PINN (#267)
* fix GPU training in inverse problem (#283)
* Create a `compute_residual` attribute for `PINNInterface`
* Modify dataloading in solvers (#286)
* Modify PINNInterface by removing _loss_phys, _loss_data
* Adding in PINNInterface a variable to track the current condition during training
* Modify GPINN,PINN,CausalPINN to match changes in PINNInterface
* Competitive Pinn Addition (#288)
* fixing after rebase/ fix loss
* fixing final issues

---------

Co-authored-by: Dario Coscia <[email protected]>

* Modify min max formulation to max min for paper consistency
* Adding SAPINN solver (#291)
* rom solver
* fix import

---------

Co-authored-by: Dario Coscia <[email protected]>
Co-authored-by: Anna Ivagnes <[email protected]>
Co-authored-by: valc89 <[email protected]>
Co-authored-by: Monthly Tag bot <[email protected]>
Co-authored-by: Nicola Demo <[email protected]>
  • Loading branch information
6 people authored May 10, 2024
1 parent 39dc6c4 commit e0429bb
Show file tree
Hide file tree
Showing 29 changed files with 3,838 additions and 358 deletions.
6 changes: 6 additions & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ Solvers
:titlesonly:

SolverInterface <solvers/solver_interface.rst>
PINNInterface <solvers/basepinn.rst>
PINN <solvers/pinn.rst>
GPINN <solvers/gpinn.rst>
CausalPINN <solvers/causalpinn.rst>
CompetitivePINN <solvers/competitivepinn.rst>
SAPINN <solvers/sapinn.rst>
Supervised solver <solvers/supervised.rst>
ReducedOrderModelSolver <solvers/rom.rst>
GAROM <solvers/garom.rst>


Expand Down
7 changes: 7 additions & 0 deletions docs/source/_rst/solvers/basepinn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
PINNInterface
=================
.. currentmodule:: pina.solvers.pinns.basepinn

.. autoclass:: PINNInterface
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/solvers/causalpinn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CausalPINN
==============
.. currentmodule:: pina.solvers.pinns.causalpinn

.. autoclass:: CausalPINN
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/solvers/competitivepinn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CompetitivePINN
=================
.. currentmodule:: pina.solvers.pinns.competitive_pinn

.. autoclass:: CompetitivePINN
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/solvers/gpinn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
GPINN
======
.. currentmodule:: pina.solvers.pinns.gpinn

.. autoclass:: GPINN
:members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/source/_rst/solvers/pinn.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
PINN
======
.. currentmodule:: pina.solvers.pinn
.. currentmodule:: pina.solvers.pinns.pinn

.. autoclass:: PINN
:members:
Expand Down
7 changes: 7 additions & 0 deletions docs/source/_rst/solvers/rom.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
ReducedOrderModelSolver
==========================
.. currentmodule:: pina.solvers.rom

.. autoclass:: ReducedOrderModelSolver
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/solvers/sapinn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SAPINN
======
.. currentmodule:: pina.solvers.pinns.sapinn

.. autoclass:: SAPINN
:members:
:show-inheritance:
4 changes: 2 additions & 2 deletions pina/model/avno.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def forward(self, x):
"""
points_tmp = x.extract(self.coordinates_indices)
new_batch = x.extract(self.field_indices)
new_batch = concatenate((new_batch, points_tmp), dim=2)
new_batch = concatenate((new_batch, points_tmp), dim=-1)
new_batch = self._lifting_operator(new_batch)
new_batch = self._integral_kernels(new_batch)
new_batch = concatenate((new_batch, points_tmp), dim=2)
new_batch = concatenate((new_batch, points_tmp), dim=-1)
new_batch = self._projection_operator(new_batch)
return new_batch
21 changes: 17 additions & 4 deletions pina/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
__all__ = ["PINN", "GAROM", "SupervisedSolver", "SolverInterface"]
__all__ = [
"SolverInterface",
"PINNInterface",
"PINN",
"GPINN",
"CausalPINN",
"CompetitivePINN",
"SAPINN",
"SupervisedSolver",
"ReducedOrderModelSolver",
"GAROM",
]

from .garom import GAROM
from .pinn import PINN
from .supervised import SupervisedSolver
from .solver import SolverInterface
from .pinns import *
from .supervised import SupervisedSolver
from .rom import ReducedOrderModelSolver
from .garom import GAROM

9 changes: 1 addition & 8 deletions pina/solvers/garom.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,11 @@ def training_step(self, batch, batch_idx):
:rtype: LabelTensor
"""

dataloader = self.trainer.train_dataloader
condition_idx = batch["condition"]

for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

if sys.version_info >= (3, 8):
condition_name = dataloader.condition_names[condition_id]
else:
condition_name = dataloader.loaders.condition_names[
condition_id
]

condition_name = self._dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch["pts"].detach()
out = batch["output"]
Expand Down
232 changes: 0 additions & 232 deletions pina/solvers/pinn.py

This file was deleted.

15 changes: 15 additions & 0 deletions pina/solvers/pinns/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
__all__ = [
"PINNInterface",
"PINN",
"GPINN",
"CausalPINN",
"CompetitivePINN",
"SAPINN",
]

from .basepinn import PINNInterface
from .pinn import PINN
from .gpinn import GPINN
from .causalpinn import CausalPINN
from .competitive_pinn import CompetitivePINN
from .sapinn import SAPINN
Loading

0 comments on commit e0429bb

Please sign in to comment.