Skip to content

Commit

Permalink
added tutorial for NNsPOD
Browse files Browse the repository at this point in the history
  • Loading branch information
MMRROOO committed Jul 21, 2022
1 parent 37cf70a commit 5b577bd
Show file tree
Hide file tree
Showing 6 changed files with 4,529 additions and 110 deletions.
13 changes: 8 additions & 5 deletions ezyrb/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _build_model(self, points, values):
layers_torch.append(nn.Linear(layers[-2], layers[-1]))
self.model = nn.Sequential(*layers_torch)

def fit(self, points, values):
def fit(self, points, values, optimizer = torch.optim.Adam, learning_rate = 0.001, frequency_print = 0):
"""
Build the ANN given 'points' and 'values' and perform training.
Expand All @@ -119,14 +119,16 @@ def fit(self, points, values):
:param numpy.ndarray points: the coordinates of the given (training)
points.
:param numpy.ndarray values: the (training) values in the points.
:param torch.optimizer optimizer: the optimizer used for the neural network
:param float learning_rate: learning rate used in the optimizer
:param int frequency_print: the number of epochs between the print of each loss value
"""

self._build_model(points, values)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr = 0.01)
self.optimizer = optimizer(self.model.parameters(), lr = learning_rate)

points = self._convert_numpy_to_torch(points)
values = self._convert_numpy_to_torch(values)
print(points.shape, values.shape)
n_epoch = 1
flag = True
while flag:
Expand All @@ -143,7 +145,9 @@ def fit(self, points, values):
elif isinstance(criteria, float): # stop criteria is float
if loss.item() < criteria:
flag = False
print(loss.item())
if frequency_print != 0:
if n_epoch % frequency_print == 1:
print(loss.item())
n_epoch += 1

def predict(self, new_point):
Expand Down Expand Up @@ -182,7 +186,6 @@ def load_state(self, filename, points, values):


self._build_model(points, values)
print(self.model)
self.optimizer = checkpoint['optimizer_class']

self.model.load_state_dict(checkpoint['model_state'])
Expand Down
1 change: 0 additions & 1 deletion ezyrb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def add(self, parameters, snapshots, space=None):
raise RuntimeError('No Spatial Value given')

if (self._space is not None) or (space is not None):
print(space.shape, snapshots.shape)
if len(space) != len(snapshots) or len(space[0]) != len(snapshots[0]):
raise RuntimeError(
'length of space and snapshots are different.')
Expand Down
191 changes: 87 additions & 104 deletions ezyrb/nnspod.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def __init__(self, method = "svd", path = None):


def reshape2dto1d(self, x, y):
"""
reshapes two n by n arrays into one n^2 by 2 array
:param numpy.array x: x value of data
:param numpy.array y: y value of data
"""
x = x.reshape(-1,1)
y = y.reshape(-1,1)
coords = np.concatenate((x, y), axis = 1)
Expand All @@ -24,75 +29,52 @@ def reshape2dto1d(self, x, y):
return coords

def reshape1dto2d(self, snapshots):
print(len(snapshots), snapshots.shape)
"""
turns 1d list of data into 2d
:param array-like snapshots: data to be reshaped
"""
return snapshots.reshape(int(np.sqrt(len(snapshots))), int(np.sqrt(len(snapshots))))


def train_InterpNet1d(self,ref_data, interp_layers, interp_function, interp_stop_training, interp_loss, retrain = False):
def train_interpnet(self,ref_data, interp_layers, interp_function, interp_stop_training, interp_loss, retrain = False, frequency_print = 0):
"""
trains the Interpnet given 1d data:
# print("loading")
:param database ref_data: the reference data that the rest of the data will be shifted to
:param list interp_layers: list with number of neurons in each layer
:param torch.nn.modules.activation interp_function: activation function for the interpnet
:param float interp_stop_training: desired tolerance for the interp training
:param torch.nn.Module interp_loss: loss function (MSE default)
:param boolean retrain: True if the interpNetShould be retrained, False if it should be loaded
"""

self.interp_net = ANN(interp_layers, interp_function, interp_stop_training, interp_loss)
if not retrain:
try:
self.interp_net = self.interp_net.load_state(self.path, ref_data.space.reshape(-1,1), ref_data.snapshots.reshape(-1,1))
print("loaded")
except:
self.interp_net.fit(ref_data.space.reshape(-1,1), ref_data.snapshots.reshape(-1,1))
self.interp_net.save_state(self.path)
print(self.interp_net.load_state(self.path, ref_data.space.reshape(-1,1), ref_data.snapshots.reshape(-1,1)))
if len(ref_data.space.shape) > 2:
space = ref_data.space.reshape(-1, 2)
else:
self.interp_net.fit(ref_data.space.reshape(-1,1), ref_data.snapshots.reshape(-1,1))
self.interp_net.save_state(self.path)
#plt.plot(ref_data.space, ref_data.snapshots, "o")
xi = np.linspace(0,5,1000).reshape(-1,1)
yi = self.interp_net.predict(xi)
print(xi.shape, yi.shape)
#plt.plot(xi,yi, ".")
#plt.show()


def train_InterpNet2d(self,ref_data, interp_layers, interp_function, interp_stop_training, interp_loss, retrain = False):


self.interp_net = ANN(interp_layers, interp_function, interp_stop_training, interp_loss)
space = ref_data.space.reshape(-1, 2)
snapshots = ref_data.snapshots.reshape(-1, 1)

space = ref_data.space.reshape(-1,1)
snapshots = ref_data.snapshots.reshape(-1,1)
if not retrain:
try:
self.interp_net = self.interp_net.load_state(self.path, space, snapshots)
print("loaded interpnet")
except:
self.interp_net.fit(space, snapshots)
self.interp_net.fit(space, snapshots, frequency_print = frequency_print)
self.interp_net.save_state(self.path)
else:
self.interp_net.fit(space, snapshots)
self.interp_net.fit(space, snapshots, frequency_print = frequency_print)
self.interp_net.save_state(self.path)

x = np.linspace(0, 5, 256)
y = np.linspace(0, 5, 256)
gridx, gridy = np.meshgrid(x, y)

plt.pcolor(gridx,gridy,ref_data.snapshots.reshape(256, 256))
plt.show()
res = 1000
x = np.linspace(0, 5, res)
y = np.linspace(0, 5, res)
gridx, gridy = np.meshgrid(x, y)
input = self.reshape2dto1d(gridx, gridy)
output = self.interp_net.predict(input)

toshow = self.reshape1dto2d(output)
plt.pcolor(gridx,gridy,toshow)
plt.show()




def shift(self, x, y, shift_quantity):
"""
shifts data by shift_quanity
"""
return(x+shift_quantity, y)

def pre_shift(self,x,y, ref_y):
"""
moves data so that the max of y and max of ref_y are at the same x coordinate
"""
maxy = 0
for i, n, in enumerate(y):
if n > y[maxy]:
Expand All @@ -102,10 +84,12 @@ def pre_shift(self,x,y, ref_y):
if n > ref_y[maxref]:
maxref = i

print( x[maxref]-x[maxy], maxref, maxy)
return self.shift(x, y, x[maxref]-x[maxy])[0]

def make_points(self, x, params):
"""
creates points that can be used to train and predict shiftnet
"""
if len(x.shape)> 1:
points = np.zeros((len(x),3))
for j, s in enumerate(x):
Expand All @@ -120,9 +104,11 @@ def make_points(self, x, params):
return points

def build_model(self, dim = 1):
"""
builds model based on dimension of input data
"""
layers = self.layers.copy()
layers.insert(0, dim + 1)
print(layers, "!!!!")
layers.append(dim)
layers_torch = []
for i in range(len(layers) - 2):
Expand All @@ -133,23 +119,40 @@ def build_model(self, dim = 1):



def train_ShiftNet1d(self, db, shift_layers, shift_function, shift_stop_training, ref_data, preshift = False):
# TODO:
# make sure neural net works no mater distance between data
# check and implement 2d functionality
# make code look better
def train_shiftnet(self, db, shift_layers, shift_function, shift_stop_training, ref_data, preshift = False):
"""
Trains and evaluates shiftnet given 1d data 'db'
:param Database db: data at a certain parameter value
:param list shift_layers: ordered list with number of neurons in each layer
:param torch.nn.modeulse.activation shift_function: the activation function used by the shiftnet
:param int, float, or list stop_training:
int: number of epochs before stopping
float: desired tolarance before stopping training
list: a int and a float, stops when either desired epochs or tolerance is reached
:param Database db: data at the reference datapoint
:param boolean preshift: True if preshift is desired otherwise false.
"""
self.layers = shift_layers
self.function = shift_function
self.loss_trend = []
if preshift:
x = self.pre_shift(db.space[0], db.snapshots[0], ref_data.snapshots[0])
else:
x = db.space[0]
if len(db.space.shape) > 2:
x_reshaped = x.reshape(-1,2)
self.build_model(dim = 2)
else:
self.build_model(dim = 1)
x_reshaped = x.reshape(-1,1)

values = db.snapshots.reshape(-1,1)

self.stop_training = shift_stop_training
points = self.make_points(x, db.parameters)
values = db.snapshots.reshape(-1,1)
self.build_model(dim = 1)



self.optimizer = torch.optim.Adam(self.model.parameters(), 0.0001)

Expand All @@ -160,12 +163,10 @@ def train_ShiftNet1d(self, db, shift_layers, shift_function, shift_stop_training
while flag:
shift = self.model(points)
x_shift, y = self.shift(
torch.from_numpy(x.reshape(-1,1)).float(),
torch.from_numpy(db.snapshots.reshape(-1,1)).float(),
torch.from_numpy(x_reshaped).float(),
torch.from_numpy(values).float(),
shift)
#print(x_shift,y)
ref_interp = self.interp_net.predict_tensor(x_shift)
#print(ref_interp)
ref_interp = self.interp_net.model(x_shift)
loss = self.loss(ref_interp, y)
print(loss.item())
loss.backward()
Expand All @@ -179,30 +180,34 @@ def train_ShiftNet1d(self, db, shift_layers, shift_function, shift_stop_training
if loss.item() < criteria:
flag = False
n_epoch += 1

new_point = self.make_points(x, db.parameters)
shift = self.model(torch.from_numpy(new_point).float())
x_new = self.shift(
torch.from_numpy(x.reshape(-1,1)).float(),
torch.from_numpy(db.snapshots.reshape(-1,1)).float(),
torch.from_numpy(x_reshaped).float(),
torch.from_numpy(values).float(),
shift)[0]

plt.plot(db.space, db.snapshots, "go")
plt.plot(x_new.detach().numpy(), db.snapshots.reshape(-1,1), ".")
return shift
x_ret = x_new.detach().numpy()
return x_ret

def train_ShiftNet2d(self, db, shift_layers, shift_function, shift_stop_training, ref_data, preshift = False):
# TODO:
# make sure neural net works no mater distance between data
# check and implement 2d functionality
# make code look better
# work on pre_shift for 2d data (iterate through all data until max is found)
# make sure shift works for 2d data(might only shift one part)
"""
Trains and evaluates shiftnet given 2d data 'db'
:param Database db: data at a certain parameter value
:param list shift_layers: ordered list with number of neurons in each layer
:param torch.nn.modeulse.activation shift_function: the activation function used by the shiftnet
:param int, float, or list stop_training:
int: number of epochs before stopping
float: desired tolarance before stopping training
list: a int and a float, stops when either desired epochs or tolerance is reached
:param Database db: data at the reference datapoint
:param boolean preshift: True if preshift is desired otherwise false.
"""
self.layers = shift_layers
self.function = shift_function
self.loss_trend = []
if preshift:
x = self.pre_shift(db.space[0], db.snapshots[0], ref_data.snapshots[0])
x = x_preshifted = self.pre_shift(db.space[0], db.snapshots[0], ref_data.snapshots[0])
else:
x = db.space[0]

Expand All @@ -222,9 +227,7 @@ def train_ShiftNet2d(self, db, shift_layers, shift_function, shift_stop_training
torch.from_numpy(x.reshape(-1,2)).float(),
torch.from_numpy(db.snapshots.reshape(-1,1)).float(),
shift)
#print(x_shift,y)
ref_interp = self.interp_net.predict_tensor(x_shift)
#print(ref_interp)
ref_interp = self.interp_net.model(x_shift)
loss = self.loss(ref_interp, y)
print(loss.item())
loss.backward()
Expand All @@ -239,30 +242,10 @@ def train_ShiftNet2d(self, db, shift_layers, shift_function, shift_stop_training
flag = False
n_epoch += 1


x = np.linspace(0, 5, 256)
y = np.linspace(0, 5, 256)
gridx, gridy = np.meshgrid(x, y)

plt.pcolor(gridx,gridy,ref_data.snapshots.reshape(256, 256))
plt.show()
res = 256
x = np.linspace(0, 5, res)
y = np.linspace(0, 5, res)
gridx, gridy = np.meshgrid(x, y)
coords = self.reshape2dto1d(gridx, gridy)
new_point = self.make_points(coords, db.parameters)
new_point = self.make_points(x_preshifted, db.parameters)
shift = self.model(torch.from_numpy(new_point).float())
x_new = self.shift(
torch.from_numpy(coords.reshape(-1,2)).float(),
torch.from_numpy(x_preshifted.reshape(-1,2)).float(),
torch.from_numpy(db.snapshots.reshape(-1,1)).float(),
shift)[0]
print(x_new.shape)
x, y = np.hsplit(x_new.detach().numpy(), 2)
x = self.reshape1dto2d(x)
y = self.reshape1dto2d(y)
snapshots = self.reshape1dto2d(db.snapshots.reshape(-1,1))
print(x.shape, y.shape)
plt.pcolor(x,y,snapshots)
plt.show()
return shift
return x_new
Binary file added tutorials/interpnet1d.pth
Binary file not shown.
Binary file added tutorials/interpnet2d.pth
Binary file not shown.
Loading

0 comments on commit 5b577bd

Please sign in to comment.