Skip to content

Commit

Permalink
🚀 [RofuncRL] State encoder can freeze params
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylark0924 committed Aug 9, 2023
1 parent 98e1dd5 commit 06f15b7
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 52 deletions.
8 changes: 3 additions & 5 deletions rofunc/learning/RofuncRL/agents/mixline/amp_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,9 @@ def __init__(self,
self.collect_reference_motions = collect_reference_motions

'''Define models for AMP'''
if hasattr(cfg.Model, "state_encoder"):
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model).to(self.device)
else:
self.se = EmptyEncoder()
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model) if hasattr(cfg.Model, "state_encoder") else EmptyEncoder()
self.se.to(self.device)

self.policy = ActorAMP(cfg.Model, observation_space, action_space, self.se).to(self.device)
self.value = Critic(cfg.Model, observation_space, action_space, self.se).to(self.device)
Expand Down
8 changes: 3 additions & 5 deletions rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,9 @@ def __init__(self,
super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)

'''Define models for ASE HRL agent'''
if hasattr(cfg.Model, "state_encoder"):
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model).to(self.device)
else:
self.se = EmptyEncoder()
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model) if hasattr(cfg.Model, "state_encoder") else EmptyEncoder()
self.se.to(self.device)

if self.cfg.Model.actor.type == "Beta":
self.policy = ActorPPO_Beta(cfg.Model, observation_space, self._ase_latent_dim, self.se).to(self.device)
Expand Down
8 changes: 3 additions & 5 deletions rofunc/learning/RofuncRL/agents/online/a2c_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,9 @@ def __init__(self,
super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)

'''Define models for A2C'''
if hasattr(cfg.Model, "state_encoder"):
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model).to(self.device)
else:
self.se = EmptyEncoder()
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model) if hasattr(cfg.Model, "state_encoder") else EmptyEncoder()
self.se.to(self.device)

self.policy = ActorPPO_Gaussian(cfg.Model, observation_space, action_space, self.se).to(self.device)
self.value = Critic(cfg.Model, observation_space, action_space, self.se).to(self.device)
Expand Down
8 changes: 3 additions & 5 deletions rofunc/learning/RofuncRL/agents/online/ppo_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@ def __init__(self,
super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)

'''Define models for PPO'''
if hasattr(cfg.Model, "state_encoder"):
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model).to(self.device)
else:
self.se = EmptyEncoder()
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model) if hasattr(cfg.Model, "state_encoder") else EmptyEncoder()
self.se.to(self.device)

if self.cfg.Model.actor.type == "Beta":
self.policy = ActorPPO_Beta(cfg.Model, observation_space, action_space, self.se).to(self.device)
Expand Down
8 changes: 3 additions & 5 deletions rofunc/learning/RofuncRL/agents/online/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ def __init__(self,
super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)

'''Define models for SAC'''
if hasattr(cfg.Model, "state_encoder"):
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model).to(self.device)
else:
self.se = EmptyEncoder()
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model) if hasattr(cfg.Model, "state_encoder") else EmptyEncoder()
self.se.to(self.device)

concat_space = [observation_space, action_space]
self.actor = ActorSAC(cfg.Model, observation_space, action_space, self.se).to(self.device)
Expand Down
8 changes: 3 additions & 5 deletions rofunc/learning/RofuncRL/agents/online/td3_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ def __init__(self,
super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)

'''Define models for TD3'''
if hasattr(cfg.Model, "state_encoder"):
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model).to(self.device)
else:
self.se = EmptyEncoder()
se_type = cfg.Model.state_encoder.encoder_type
self.se = encoder_map[se_type](cfg.Model) if hasattr(cfg.Model, "state_encoder") else EmptyEncoder()
self.se.to(self.device)

concat_space = [observation_space, action_space]
self.actor = ActorTD3(cfg.Model, observation_space, action_space, self.se).to(self.device)
Expand Down
45 changes: 44 additions & 1 deletion rofunc/learning/RofuncRL/state_encoders/base_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
import torch.nn as nn
from omegaconf import DictConfig

Expand All @@ -40,6 +40,49 @@ def __init__(self, cfg: DictConfig, cfg_name: str = 'state_encoder'):
self.input_dim = self.cfg_dict[cfg_name]['inp_channels']
self.output_dim = self.cfg_dict[cfg_name]['out_channels']

self.use_pretrained = self.cfg_dict[cfg_name]['use_pretrained']
self.freeze = self.cfg_dict[cfg_name]['freeze']
self.model_ckpt = self.cfg_dict[cfg_name]['model_ckpt']

def set_up(self):
if self.freeze:
self.freeze_network()
if self.use_pretrained:
self.pre_trained_mode()

def freeze_network(self):
for net in self.freeze_net_list:
for param in net.parameters():
param.requires_grad = False

def pre_trained_mode(self):
if self.use_pretrained is True and self.model_ckpt is None:
raise ValueError("Cannot freeze the encoder without a checkpoint")
if self.use_pretrained:
self.load_ckpt(self.model_ckpt)

def _get_internal_value(self, module):
return module.state_dict() if hasattr(module, "state_dict") else module

def save_ckpt(self, path: str):
modules = {}
for name, module in self.checkpoint_modules.items():
modules[name] = self._get_internal_value(module)
torch.save(modules, path)

def load_ckpt(self, path: str):
modules = torch.load(path)
if type(modules) is dict:
for name, data in modules.items():
module = self.checkpoint_modules.get(name, None)
if module is not None:
if hasattr(module, "load_state_dict"):
module.load_state_dict(data)
if hasattr(module, "eval"):
module.eval()
else:
raise NotImplementedError


class MLPEncoder(BaseMLP):
def __init__(self, cfg, cfg_name):
Expand Down
83 changes: 62 additions & 21 deletions rofunc/learning/RofuncRL/state_encoders/visual_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,15 @@ def __init__(self, cfg):
self.output_net = build_mlp(dims=[self.mlp_inp_dims, *self.mlp_hidden_dims, self.output_dim],
hidden_activation=self.mlp_activation)

self.checkpoint_modules = {'backbone_net': self.backbone_net, 'output_net': self.output_net}

if self.cfg.use_init:
init_layers(self.backbone_net, gain=1.0, init_type='kaiming_uniform')
init_layers(self.output_net, gain=1.0)

self.freeze_net_list = [self.backbone_net, self.output_net]
self.set_up()

def forward(self, inputs):
x = self.backbone_net(inputs)
x = self.flatten(x)
Expand All @@ -69,6 +74,9 @@ def __init__(self, cfg):
self.backbone_net = torch.hub.load('pytorch/vision', self.sub_type, num_classes=self.output_dim,
pretrained=self.use_pretrained)

self.freeze_net_list = [self.backbone_net]
self.set_up()

def forward(self, inputs):
x = self.backbone_net(inputs)
return x
Expand All @@ -85,29 +93,42 @@ def __init__(self, cfg):
self.backbone_net = torch.hub.load('pytorch/vision', self.sub_type, num_classes=self.output_dim,
pretrained=self.use_pretrained)

self.freeze_net_list = [self.backbone_net]
self.set_up()

def forward(self, inputs):
x = self.backbone_net(inputs)
return x


if __name__ == '__main__':
from omegaconf import DictConfig

# cfg = DictConfig({'use_init': True, 'state_encoder': {'inp_channels': 4, 'out_channels': 512,
# 'cnn_args': {
# 'cnn_structure': ['conv', 'relu', 'conv', 'relu', 'pool'],
# 'cnn_kernel_size': [8, 4],
# 'cnn_stride': 1,
# 'cnn_padding': 1,
# 'cnn_dilation': 1,
# 'cnn_hidden_dims': [32, 64],
# 'cnn_activation': 'relu',
# 'mlp_inp_dims': 2304,
# 'mlp_hidden_dims': [128],
# 'mlp_activation': 'relu',
# }}})
# model = CNNEncoder(cfg=cfg)
# print(model)
import rofunc as rf

cfg = DictConfig({'use_init': True, 'state_encoder': {'inp_channels': 4, 'out_channels': 512,
'use_pretrained': False,
'freeze': False,
'model_ckpt': 'test.ckpt',
'cnn_args': {
'cnn_structure': ['conv', 'relu', 'conv', 'relu', 'pool'],
'cnn_kernel_size': [8, 4],
'cnn_stride': 1,
'cnn_padding': 1,
'cnn_dilation': 1,
'cnn_hidden_dims': [32, 64],
'cnn_activation': 'relu',
'cnn_pooling': None, # ['max', 'avg']
'cnn_pooling_args': {
'cnn_pooling_kernel_size': 2,
'cnn_pooling_stride': 2,
'cnn_pooling_padding': 0,
'cnn_pooling_dilation': 1},
'mlp_inp_dims': 215296,
'mlp_hidden_dims': [512],
'mlp_activation': 'relu',
}}})
model = CNNEncoder(cfg=cfg).to('cuda:0')
print(model)

# cfg = DictConfig(
# {'use_init': True,
Expand All @@ -116,8 +137,28 @@ def forward(self, inputs):
# model = ResnetEncoder(cfg=cfg)
# print(model)

cfg = DictConfig({'use_init': True,
'state_encoder': {'inp_channels': 4, 'out_channels': 512, 'vit_args': {'sub_type': 'vit_b_16'},
'use_pretrained': False}})
model = ViTEncoder(cfg=cfg)
print(model)
# cfg = DictConfig({'use_init': True,
# 'state_encoder': {'inp_channels': 4, 'out_channels': 512, 'vit_args': {'sub_type': 'vit_b_16'},
# 'use_pretrained': False}})
# model = ViTEncoder(cfg=cfg)
# print(model)

inp_latent_vector = torch.randn(32, 4, 64, 64).to('cuda:0') # [B, C, H, W]
gt_latent_vector = torch.randn(32, 512).to('cuda:0')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for i in range(10000):
out_latent_vector = model(inp_latent_vector)
# predicted_latent_vector = out_latent_vector.last_hidden_state[:, -1:, :]
# print(out_latent_vector)

loss = nn.MSELoss()(out_latent_vector, gt_latent_vector)
optimizer.zero_grad()
loss.backward()
optimizer.step()

if i % 100 == 0:
model.save_ckpt('test.ckpt')
rf.utils.beauty_print('Save ckpt')

print(loss)

0 comments on commit 06f15b7

Please sign in to comment.