Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MMRROOO committed Jul 21, 2022
1 parent 5b577bd commit 0f054a3
Show file tree
Hide file tree
Showing 3 changed files with 1,824 additions and 4,115 deletions.
9 changes: 1 addition & 8 deletions ezyrb/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,7 @@ def predict(self, new_point):
new_point = self._convert_numpy_to_torch(np.array(new_point))
y_new = self.model(new_point)
return self._convert_torch_to_numpy(y_new)

def predict_tensor(self, new_point):

return self.model(new_point)


def save_state(self, filename):

Expand All @@ -175,16 +172,12 @@ def save_state(self, filename):
'model_class' : self.model.__class__
}



torch.save(checkpoint, filename)

def load_state(self, filename, points, values):

checkpoint = torch.load(filename)



self._build_model(points, values)
self.optimizer = checkpoint['optimizer_class']

Expand Down
97 changes: 18 additions & 79 deletions ezyrb/nnspod.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class NNsPOD(POD):
def __init__(self, method = "svd", path = None):
super().__init__(method)
self.path = path


def reshape2dto1d(self, x, y):
"""
Expand All @@ -24,8 +24,7 @@ def reshape2dto1d(self, x, y):
y = y.reshape(-1,1)
coords = np.concatenate((x, y), axis = 1)
coords = np.array(coords).reshape(-1,2)



return coords

def reshape1dto2d(self, snapshots):
Expand All @@ -35,7 +34,6 @@ def reshape1dto2d(self, snapshots):
"""
return snapshots.reshape(int(np.sqrt(len(snapshots))), int(np.sqrt(len(snapshots))))


def train_interpnet(self,ref_data, interp_layers, interp_function, interp_stop_training, interp_loss, retrain = False, frequency_print = 0):
"""
trains the Interpnet given 1d data:
Expand All @@ -47,7 +45,7 @@ def train_interpnet(self,ref_data, interp_layers, interp_function, interp_stop_t
:param torch.nn.Module interp_loss: loss function (MSE default)
:param boolean retrain: True if the interpNetShould be retrained, False if it should be loaded
"""

self.interp_net = ANN(interp_layers, interp_function, interp_stop_training, interp_loss)
if len(ref_data.space.shape) > 2:
space = ref_data.space.reshape(-1, 2)
Expand All @@ -64,13 +62,13 @@ def train_interpnet(self,ref_data, interp_layers, interp_function, interp_stop_t
else:
self.interp_net.fit(space, snapshots, frequency_print = frequency_print)
self.interp_net.save_state(self.path)

def shift(self, x, y, shift_quantity):
"""
shifts data by shift_quanity
"""
return(x+shift_quantity, y)

def pre_shift(self,x,y, ref_y):
"""
moves data so that the max of y and max of ref_y are at the same x coordinate
Expand All @@ -83,9 +81,9 @@ def pre_shift(self,x,y, ref_y):
for i, n in enumerate(ref_y):
if n > ref_y[maxref]:
maxref = i

return self.shift(x, y, x[maxref]-x[maxy])[0]

def make_points(self, x, params):
"""
creates points that can be used to train and predict shiftnet
Expand All @@ -102,7 +100,7 @@ def make_points(self, x, params):
points[j][0] = s
points[j][1] = params[0]
return points

def build_model(self, dim = 1):
"""
builds model based on dimension of input data
Expand All @@ -117,9 +115,9 @@ def build_model(self, dim = 1):
layers_torch.append(nn.Linear(layers[-2], layers[-1]))
self.model = nn.Sequential(*layers_torch)



def train_shiftnet(self, db, shift_layers, shift_function, shift_stop_training, ref_data, preshift = False):
def train_shiftnet(self, db, shift_layers, shift_function, shift_stop_training,
ref_data, preshift = False,
optimizer = torch.optim.Adam, learning_rate = 0.0001, frequency_print = 0):
"""
Trains and evaluates shiftnet given 1d data 'db'
Expand All @@ -146,15 +144,13 @@ def train_shiftnet(self, db, shift_layers, shift_function, shift_stop_training,
else:
self.build_model(dim = 1)
x_reshaped = x.reshape(-1,1)

values = db.snapshots.reshape(-1,1)

self.stop_training = shift_stop_training
points = self.make_points(x, db.parameters)



self.optimizer = torch.optim.Adam(self.model.parameters(), 0.0001)
self.optimizer = optimizer(self.model.parameters(), lr = learning_rate)

self.loss = torch.nn.MSELoss()
points = torch.from_numpy(points).float()
Expand All @@ -168,7 +164,6 @@ def train_shiftnet(self, db, shift_layers, shift_function, shift_stop_training,
shift)
ref_interp = self.interp_net.model(x_shift)
loss = self.loss(ref_interp, y)
print(loss.item())
loss.backward()
self.optimizer.step()
self.loss_trend.append(loss.item())
Expand All @@ -179,8 +174,11 @@ def train_shiftnet(self, db, shift_layers, shift_function, shift_stop_training,
elif isinstance(criteria, float): # stop criteria is float
if loss.item() < criteria:
flag = False
if frequency_print != 0:
if n_epoch % frequency_print == 1:
print(loss.item())
n_epoch += 1

new_point = self.make_points(x, db.parameters)
shift = self.model(torch.from_numpy(new_point).float())
x_new = self.shift(
Expand All @@ -190,62 +188,3 @@ def train_shiftnet(self, db, shift_layers, shift_function, shift_stop_training,
x_ret = x_new.detach().numpy()
return x_ret

def train_ShiftNet2d(self, db, shift_layers, shift_function, shift_stop_training, ref_data, preshift = False):
"""
Trains and evaluates shiftnet given 2d data 'db'
:param Database db: data at a certain parameter value
:param list shift_layers: ordered list with number of neurons in each layer
:param torch.nn.modeulse.activation shift_function: the activation function used by the shiftnet
:param int, float, or list stop_training:
int: number of epochs before stopping
float: desired tolarance before stopping training
list: a int and a float, stops when either desired epochs or tolerance is reached
:param Database db: data at the reference datapoint
:param boolean preshift: True if preshift is desired otherwise false.
"""
self.layers = shift_layers
self.function = shift_function
self.loss_trend = []
if preshift:
x = x_preshifted = self.pre_shift(db.space[0], db.snapshots[0], ref_data.snapshots[0])
else:
x = db.space[0]

self.stop_training = shift_stop_training
points = self.make_points(x, db.parameters)
self.build_model(dim = 2)

self.optimizer = torch.optim.Adam(self.model.parameters(), 0.00001)

self.loss = torch.nn.MSELoss()
points = torch.from_numpy(points).float()
n_epoch = 1
flag = True
while flag:
shift = self.model(points)
x_shift, y = self.shift(
torch.from_numpy(x.reshape(-1,2)).float(),
torch.from_numpy(db.snapshots.reshape(-1,1)).float(),
shift)
ref_interp = self.interp_net.model(x_shift)
loss = self.loss(ref_interp, y)
print(loss.item())
loss.backward()
self.optimizer.step()
self.loss_trend.append(loss.item())
for criteria in self.stop_training:
if isinstance(criteria, int): # stop criteria is an integer
if n_epoch == criteria:
flag = False
elif isinstance(criteria, float): # stop criteria is float
if loss.item() < criteria:
flag = False
n_epoch += 1

new_point = self.make_points(x_preshifted, db.parameters)
shift = self.model(torch.from_numpy(new_point).float())
x_new = self.shift(
torch.from_numpy(x_preshifted.reshape(-1,2)).float(),
torch.from_numpy(db.snapshots.reshape(-1,1)).float(),
shift)[0]
return x_new
Loading

0 comments on commit 0f054a3

Please sign in to comment.