Skip to content

Commit

Permalink
Merge pull request #16 from Oisin-M/feat/vary_activation
Browse files Browse the repository at this point in the history
Allow activation functions to be specified by strings
  • Loading branch information
fpichi authored Apr 12, 2024
2 parents 5e2e881 + ec27c5e commit 236d3ca
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions gca_rom/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from torch import nn
from gca_rom import gca, scaling
import torch.nn.functional as F


class HyperParams:
Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(self, argv, **kwargs):
self.seed = 10
self.tolerance = 1e-6
self.learning_rate = 0.001
self.act = torch.tanh
self.map_act = 'tanh'
self.layer_vec=[argv[11], self.nodes, self.nodes, self.nodes, self.nodes, self.bottleneck_dim]
self.net_run = '_' + self.scaler_name
self.weight_decay = 0.00001
Expand All @@ -62,14 +63,16 @@ def __init__(self, argv, **kwargs):
self.gamma = 0.0001
self.num_nodes = 0
self.conv = 'GMMConv'
self.ae_act = 'elu'
self.batch_size = np.inf
self.minibatch = False
self.net_dir = './' + self.net_name + '/' + self.net_run + '/' + self.variable + '_' + self.net_name + '_lmap' + str(self.lambda_map) + '_btt' + str(self.bottleneck_dim) \
+ '_seed' + str(self.seed) + '_lv' + str(len(self.layer_vec)-2) + '_hc' + str(len(self.hidden_channels)) + '_nd' + str(self.nodes) \
+ '_ffn' + str(self.ffn) + '_skip' + str(self.skip) + '_lr' + str(self.learning_rate) + '_sc' + str(self.scaling_type) + '_rate' + str(self.rate) + '_conv' + self.conv + '/'
self.cross_validation = True


def get_activation(act_str):
return getattr(F, act_str)

class Net(torch.nn.Module):
"""
Expand Down Expand Up @@ -111,10 +114,10 @@ class Net(torch.nn.Module):

def __init__(self, HyperParams):
super().__init__()
self.encoder = gca.Encoder(HyperParams.hidden_channels, HyperParams.bottleneck_dim, HyperParams.num_nodes, ffn=HyperParams.ffn, skip=HyperParams.skip, conv=HyperParams.conv)
self.decoder = gca.Decoder(HyperParams.hidden_channels, HyperParams.bottleneck_dim, HyperParams.num_nodes, ffn=HyperParams.ffn, skip=HyperParams.skip, conv=HyperParams.conv)
self.encoder = gca.Encoder(HyperParams.hidden_channels, HyperParams.bottleneck_dim, HyperParams.num_nodes, ffn=HyperParams.ffn, skip=HyperParams.skip, act=get_activation(HyperParams.ae_act), conv=HyperParams.conv)
self.decoder = gca.Decoder(HyperParams.hidden_channels, HyperParams.bottleneck_dim, HyperParams.num_nodes, ffn=HyperParams.ffn, skip=HyperParams.skip, act=get_activation(HyperParams.ae_act), conv=HyperParams.conv)

self.act_map = HyperParams.act
self.act_map = get_activation(HyperParams.map_act)
self.layer_vec = HyperParams.layer_vec
self.steps = len(self.layer_vec) - 1

Expand Down

0 comments on commit 236d3ca

Please sign in to comment.