Skip to content

Commit

Permalink
added tutorial for NNsPOD
Browse files Browse the repository at this point in the history
  • Loading branch information
MMRROOO committed Jul 21, 2022
1 parent 37cf70a commit 1171901
Show file tree
Hide file tree
Showing 6 changed files with 2,300 additions and 172 deletions.
22 changes: 9 additions & 13 deletions ezyrb/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _build_model(self, points, values):
layers_torch.append(nn.Linear(layers[-2], layers[-1]))
self.model = nn.Sequential(*layers_torch)

def fit(self, points, values):
def fit(self, points, values, optimizer = torch.optim.Adam, learning_rate = 0.001, frequency_print = 0):
"""
Build the ANN given 'points' and 'values' and perform training.
Expand All @@ -119,14 +119,16 @@ def fit(self, points, values):
:param numpy.ndarray points: the coordinates of the given (training)
points.
:param numpy.ndarray values: the (training) values in the points.
:param torch.optimizer optimizer: the optimizer used for the neural network
:param float learning_rate: learning rate used in the optimizer
:param int frequency_print: the number of epochs between the print of each loss value
"""

self._build_model(points, values)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr = 0.01)
self.optimizer = optimizer(self.model.parameters(), lr = learning_rate)

points = self._convert_numpy_to_torch(points)
values = self._convert_numpy_to_torch(values)
print(points.shape, values.shape)
n_epoch = 1
flag = True
while flag:
Expand All @@ -143,7 +145,9 @@ def fit(self, points, values):
elif isinstance(criteria, float): # stop criteria is float
if loss.item() < criteria:
flag = False
print(loss.item())
if frequency_print != 0:
if n_epoch % frequency_print == 1:
print(loss.item())
n_epoch += 1

def predict(self, new_point):
Expand All @@ -157,10 +161,7 @@ def predict(self, new_point):
new_point = self._convert_numpy_to_torch(np.array(new_point))
y_new = self.model(new_point)
return self._convert_torch_to_numpy(y_new)

def predict_tensor(self, new_point):

return self.model(new_point)


def save_state(self, filename):

Expand All @@ -171,18 +172,13 @@ def save_state(self, filename):
'model_class' : self.model.__class__
}



torch.save(checkpoint, filename)

def load_state(self, filename, points, values):

checkpoint = torch.load(filename)



self._build_model(points, values)
print(self.model)
self.optimizer = checkpoint['optimizer_class']

self.model.load_state_dict(checkpoint['model_state'])
Expand Down
1 change: 0 additions & 1 deletion ezyrb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def add(self, parameters, snapshots, space=None):
raise RuntimeError('No Spatial Value given')

if (self._space is not None) or (space is not None):
print(space.shape, snapshots.shape)
if len(space) != len(snapshots) or len(space[0]) != len(snapshots[0]):
raise RuntimeError(
'length of space and snapshots are different.')
Expand Down
Loading

0 comments on commit 1171901

Please sign in to comment.