From 06f15b77b63996eaec25ac547d403de2cd2e5e1f Mon Sep 17 00:00:00 2001
From: Skylark
Date: Wed, 9 Aug 2023 19:32:09 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20[RofuncRL]=20State=20encoder=20c?=
 =?UTF-8?q?an=20freeze=20params?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../RofuncRL/agents/mixline/amp_agent.py      |  8 +-
 .../RofuncRL/agents/mixline/ase_hrl_agent.py  |  8 +-
 .../RofuncRL/agents/online/a2c_agent.py       |  8 +-
 .../RofuncRL/agents/online/ppo_agent.py       |  8 +-
 .../RofuncRL/agents/online/sac_agent.py       |  8 +-
 .../RofuncRL/agents/online/td3_agent.py       |  8 +-
 .../RofuncRL/state_encoders/base_encoders.py  | 45 +++++++++-
 .../state_encoders/visual_encoders.py         | 83 ++++++++++++++-----
 8 files changed, 124 insertions(+), 52 deletions(-)

diff --git a/rofunc/learning/RofuncRL/agents/mixline/amp_agent.py b/rofunc/learning/RofuncRL/agents/mixline/amp_agent.py
index 9d820399..dafc7e09 100644
--- a/rofunc/learning/RofuncRL/agents/mixline/amp_agent.py
+++ b/rofunc/learning/RofuncRL/agents/mixline/amp_agent.py
@@ -72,11 +72,9 @@ def __init__(self,
         self.collect_reference_motions = collect_reference_motions
 
         '''Define models for AMP'''
-        if hasattr(cfg.Model, "state_encoder"):
-            se_type = cfg.Model.state_encoder.encoder_type
-            self.se = encoder_map[se_type](cfg.Model).to(self.device)
-        else:
-            self.se = EmptyEncoder()
+        se_type = cfg.Model.state_encoder.encoder_type if hasattr(cfg.Model, "state_encoder") else None
+        self.se = encoder_map[se_type](cfg.Model) if se_type is not None else EmptyEncoder()
+        self.se.to(self.device)
 
         self.policy = ActorAMP(cfg.Model, observation_space, action_space, self.se).to(self.device)
         self.value = Critic(cfg.Model, observation_space, action_space, self.se).to(self.device)
diff --git a/rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py b/rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py
index da46ad20..60f6c42b 100644
--- a/rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py
+++ b/rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py
@@ -77,11 +77,9 @@ def __init__(self,
         super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)
 
         '''Define models for ASE HRL agent'''
-        if hasattr(cfg.Model, "state_encoder"):
-            se_type = cfg.Model.state_encoder.encoder_type
-            self.se = encoder_map[se_type](cfg.Model).to(self.device)
-        else:
-            self.se = EmptyEncoder()
+        se_type = cfg.Model.state_encoder.encoder_type if hasattr(cfg.Model, "state_encoder") else None
+        self.se = encoder_map[se_type](cfg.Model) if se_type is not None else EmptyEncoder()
+        self.se.to(self.device)
 
         if self.cfg.Model.actor.type == "Beta":
             self.policy = ActorPPO_Beta(cfg.Model, observation_space, self._ase_latent_dim, self.se).to(self.device)
diff --git a/rofunc/learning/RofuncRL/agents/online/a2c_agent.py b/rofunc/learning/RofuncRL/agents/online/a2c_agent.py
index 16320ac8..6d465e91 100644
--- a/rofunc/learning/RofuncRL/agents/online/a2c_agent.py
+++ b/rofunc/learning/RofuncRL/agents/online/a2c_agent.py
@@ -58,11 +58,9 @@ def __init__(self,
         super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)
 
         '''Define models for A2C'''
-        if hasattr(cfg.Model, "state_encoder"):
-            se_type = cfg.Model.state_encoder.encoder_type
-            self.se = encoder_map[se_type](cfg.Model).to(self.device)
-        else:
-            self.se = EmptyEncoder()
+        se_type = cfg.Model.state_encoder.encoder_type if hasattr(cfg.Model, "state_encoder") else None
+        self.se = encoder_map[se_type](cfg.Model) if se_type is not None else EmptyEncoder()
+        self.se.to(self.device)
 
         self.policy = ActorPPO_Gaussian(cfg.Model, observation_space, action_space, self.se).to(self.device)
         self.value = Critic(cfg.Model, observation_space, action_space, self.se).to(self.device)
diff --git a/rofunc/learning/RofuncRL/agents/online/ppo_agent.py b/rofunc/learning/RofuncRL/agents/online/ppo_agent.py
index 8fff9b45..3e658cb1 100644
--- a/rofunc/learning/RofuncRL/agents/online/ppo_agent.py
+++ b/rofunc/learning/RofuncRL/agents/online/ppo_agent.py
@@ -56,11 +56,9 @@ def __init__(self,
         super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)
 
         '''Define models for PPO'''
-        if hasattr(cfg.Model, "state_encoder"):
-            se_type = cfg.Model.state_encoder.encoder_type
-            self.se = encoder_map[se_type](cfg.Model).to(self.device)
-        else:
-            self.se = EmptyEncoder()
+        se_type = cfg.Model.state_encoder.encoder_type if hasattr(cfg.Model, "state_encoder") else None
+        self.se = encoder_map[se_type](cfg.Model) if se_type is not None else EmptyEncoder()
+        self.se.to(self.device)
 
         if self.cfg.Model.actor.type == "Beta":
             self.policy = ActorPPO_Beta(cfg.Model, observation_space, action_space, self.se).to(self.device)
diff --git a/rofunc/learning/RofuncRL/agents/online/sac_agent.py b/rofunc/learning/RofuncRL/agents/online/sac_agent.py
index e2f7c9d8..ba9ef3d7 100644
--- a/rofunc/learning/RofuncRL/agents/online/sac_agent.py
+++ b/rofunc/learning/RofuncRL/agents/online/sac_agent.py
@@ -59,11 +59,9 @@ def __init__(self,
         super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)
 
         '''Define models for SAC'''
-        if hasattr(cfg.Model, "state_encoder"):
-            se_type = cfg.Model.state_encoder.encoder_type
-            self.se = encoder_map[se_type](cfg.Model).to(self.device)
-        else:
-            self.se = EmptyEncoder()
+        se_type = cfg.Model.state_encoder.encoder_type if hasattr(cfg.Model, "state_encoder") else None
+        self.se = encoder_map[se_type](cfg.Model) if se_type is not None else EmptyEncoder()
+        self.se.to(self.device)
 
         concat_space = [observation_space, action_space]
         self.actor = ActorSAC(cfg.Model, observation_space, action_space, self.se).to(self.device)
diff --git a/rofunc/learning/RofuncRL/agents/online/td3_agent.py b/rofunc/learning/RofuncRL/agents/online/td3_agent.py
index c4b6491a..22c40357 100644
--- a/rofunc/learning/RofuncRL/agents/online/td3_agent.py
+++ b/rofunc/learning/RofuncRL/agents/online/td3_agent.py
@@ -59,11 +59,9 @@ def __init__(self,
         super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)
 
         '''Define models for TD3'''
-        if hasattr(cfg.Model, "state_encoder"):
-            se_type = cfg.Model.state_encoder.encoder_type
-            self.se = encoder_map[se_type](cfg.Model).to(self.device)
-        else:
-            self.se = EmptyEncoder()
+        se_type = cfg.Model.state_encoder.encoder_type if hasattr(cfg.Model, "state_encoder") else None
+        self.se = encoder_map[se_type](cfg.Model) if se_type is not None else EmptyEncoder()
+        self.se.to(self.device)
 
         concat_space = [observation_space, action_space]
         self.actor = ActorTD3(cfg.Model, observation_space, action_space, self.se).to(self.device)
diff --git a/rofunc/learning/RofuncRL/state_encoders/base_encoders.py b/rofunc/learning/RofuncRL/state_encoders/base_encoders.py
index e590e3dc..7aaeb669 100644
--- a/rofunc/learning/RofuncRL/state_encoders/base_encoders.py
+++ b/rofunc/learning/RofuncRL/state_encoders/base_encoders.py
@@ -13,7 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
""" - +import torch import torch.nn as nn from omegaconf import DictConfig @@ -40,6 +40,49 @@ def __init__(self, cfg: DictConfig, cfg_name: str = 'state_encoder'): self.input_dim = self.cfg_dict[cfg_name]['inp_channels'] self.output_dim = self.cfg_dict[cfg_name]['out_channels'] + self.use_pretrained = self.cfg_dict[cfg_name]['use_pretrained'] + self.freeze = self.cfg_dict[cfg_name]['freeze'] + self.model_ckpt = self.cfg_dict[cfg_name]['model_ckpt'] + + def set_up(self): + if self.freeze: + self.freeze_network() + if self.use_pretrained: + self.pre_trained_mode() + + def freeze_network(self): + for net in self.freeze_net_list: + for param in net.parameters(): + param.requires_grad = False + + def pre_trained_mode(self): + if self.use_pretrained is True and self.model_ckpt is None: + raise ValueError("Cannot freeze the encoder without a checkpoint") + if self.use_pretrained: + self.load_ckpt(self.model_ckpt) + + def _get_internal_value(self, module): + return module.state_dict() if hasattr(module, "state_dict") else module + + def save_ckpt(self, path: str): + modules = {} + for name, module in self.checkpoint_modules.items(): + modules[name] = self._get_internal_value(module) + torch.save(modules, path) + + def load_ckpt(self, path: str): + modules = torch.load(path) + if type(modules) is dict: + for name, data in modules.items(): + module = self.checkpoint_modules.get(name, None) + if module is not None: + if hasattr(module, "load_state_dict"): + module.load_state_dict(data) + if hasattr(module, "eval"): + module.eval() + else: + raise NotImplementedError + class MLPEncoder(BaseMLP): def __init__(self, cfg, cfg_name): diff --git a/rofunc/learning/RofuncRL/state_encoders/visual_encoders.py b/rofunc/learning/RofuncRL/state_encoders/visual_encoders.py index 7bdf823d..2c5f0045 100644 --- a/rofunc/learning/RofuncRL/state_encoders/visual_encoders.py +++ b/rofunc/learning/RofuncRL/state_encoders/visual_encoders.py @@ -47,10 +47,15 @@ def __init__(self, cfg): self.output_net = build_mlp(dims=[self.mlp_inp_dims, *self.mlp_hidden_dims, self.output_dim], hidden_activation=self.mlp_activation) + self.checkpoint_modules = {'backbone_net': self.backbone_net, 'output_net': self.output_net} + if self.cfg.use_init: init_layers(self.backbone_net, gain=1.0, init_type='kaiming_uniform') init_layers(self.output_net, gain=1.0) + self.freeze_net_list = [self.backbone_net, self.output_net] + self.set_up() + def forward(self, inputs): x = self.backbone_net(inputs) x = self.flatten(x) @@ -69,6 +74,9 @@ def __init__(self, cfg): self.backbone_net = torch.hub.load('pytorch/vision', self.sub_type, num_classes=self.output_dim, pretrained=self.use_pretrained) + self.freeze_net_list = [self.backbone_net] + self.set_up() + def forward(self, inputs): x = self.backbone_net(inputs) return x @@ -85,6 +93,9 @@ def __init__(self, cfg): self.backbone_net = torch.hub.load('pytorch/vision', self.sub_type, num_classes=self.output_dim, pretrained=self.use_pretrained) + self.freeze_net_list = [self.backbone_net] + self.set_up() + def forward(self, inputs): x = self.backbone_net(inputs) return x @@ -92,22 +103,32 @@ def forward(self, inputs): if __name__ == '__main__': from omegaconf import DictConfig - - # cfg = DictConfig({'use_init': True, 'state_encoder': {'inp_channels': 4, 'out_channels': 512, - # 'cnn_args': { - # 'cnn_structure': ['conv', 'relu', 'conv', 'relu', 'pool'], - # 'cnn_kernel_size': [8, 4], - # 'cnn_stride': 1, - # 'cnn_padding': 1, - # 'cnn_dilation': 1, - # 'cnn_hidden_dims': [32, 64], - # 'cnn_activation': 'relu', 
-    #                                                            'mlp_inp_dims': 2304,
-    #                                                            'mlp_hidden_dims': [128],
-    #                                                            'mlp_activation': 'relu',
-    #                                                        }}})
-    # model = CNNEncoder(cfg=cfg)
-    # print(model)
+    import rofunc as rf
+
+    cfg = DictConfig({'use_init': True, 'state_encoder': {'inp_channels': 4, 'out_channels': 512,
+                                                          'use_pretrained': False,
+                                                          'freeze': False,
+                                                          'model_ckpt': 'test.ckpt',
+                                                          'cnn_args': {
+                                                              'cnn_structure': ['conv', 'relu', 'conv', 'relu', 'pool'],
+                                                              'cnn_kernel_size': [8, 4],
+                                                              'cnn_stride': 1,
+                                                              'cnn_padding': 1,
+                                                              'cnn_dilation': 1,
+                                                              'cnn_hidden_dims': [32, 64],
+                                                              'cnn_activation': 'relu',
+                                                              'cnn_pooling': None,  # ['max', 'avg']
+                                                              'cnn_pooling_args': {
+                                                                  'cnn_pooling_kernel_size': 2,
+                                                                  'cnn_pooling_stride': 2,
+                                                                  'cnn_pooling_padding': 0,
+                                                                  'cnn_pooling_dilation': 1},
+                                                              'mlp_inp_dims': 215296,
+                                                              'mlp_hidden_dims': [512],
+                                                              'mlp_activation': 'relu',
+                                                          }}})
+    model = CNNEncoder(cfg=cfg).to('cuda:0')
+    print(model)
 
     # cfg = DictConfig(
     #     {'use_init': True,
@@ -116,8 +137,28 @@ def forward(self, inputs):
     # model = ResnetEncoder(cfg=cfg)
     # print(model)
 
-    cfg = DictConfig({'use_init': True,
-                      'state_encoder': {'inp_channels': 4, 'out_channels': 512, 'vit_args': {'sub_type': 'vit_b_16'},
-                                        'use_pretrained': False}})
-    model = ViTEncoder(cfg=cfg)
-    print(model)
+    # cfg = DictConfig({'use_init': True,
+    #                   'state_encoder': {'inp_channels': 4, 'out_channels': 512, 'vit_args': {'sub_type': 'vit_b_16'},
+    #                                     'use_pretrained': False}})
+    # model = ViTEncoder(cfg=cfg)
+    # print(model)
+
+    inp_latent_vector = torch.randn(32, 4, 64, 64).to('cuda:0')  # [B, C, H, W]
+    gt_latent_vector = torch.randn(32, 512).to('cuda:0')
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+    for i in range(10000):
+        out_latent_vector = model(inp_latent_vector)
+        # predicted_latent_vector = out_latent_vector.last_hidden_state[:, -1:, :]
+        # print(out_latent_vector)
+
+        loss = nn.MSELoss()(out_latent_vector, gt_latent_vector)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if i % 100 == 0:
+            model.save_ckpt('test.ckpt')
+            rf.utils.beauty_print('Save ckpt')
+
+    print(loss)
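
Usage sketch (not part of the commit above): the snippet below shows how the new
'use_pretrained' / 'freeze' / 'model_ckpt' keys fit together, reusing the cnn_args
from the __main__ example in visual_encoders.py and a test.ckpt written by its
training loop. The import path and the set_up()/load_ckpt()/freeze_network()
behavior are taken from the diff; everything else is an assumption.

    # Hedged sketch: rebuild a CNN state encoder with frozen, pretrained weights.
    # Assumes rofunc is installed and 'test.ckpt' exists (written by the loop above).
    from omegaconf import DictConfig
    from rofunc.learning.RofuncRL.state_encoders.visual_encoders import CNNEncoder

    cfg = DictConfig({'use_init': True,
                      'state_encoder': {'inp_channels': 4, 'out_channels': 512,
                                        'use_pretrained': True,   # set_up() -> load_ckpt('test.ckpt')
                                        'freeze': True,           # set_up() -> freeze_network()
                                        'model_ckpt': 'test.ckpt',
                                        'cnn_args': {'cnn_structure': ['conv', 'relu', 'conv', 'relu', 'pool'],
                                                     'cnn_kernel_size': [8, 4],
                                                     'cnn_stride': 1,
                                                     'cnn_padding': 1,
                                                     'cnn_dilation': 1,
                                                     'cnn_hidden_dims': [32, 64],
                                                     'cnn_activation': 'relu',
                                                     'cnn_pooling': None,
                                                     'cnn_pooling_args': {'cnn_pooling_kernel_size': 2,
                                                                          'cnn_pooling_stride': 2,
                                                                          'cnn_pooling_padding': 0,
                                                                          'cnn_pooling_dilation': 1},
                                                     'mlp_inp_dims': 215296,
                                                     'mlp_hidden_dims': [512],
                                                     'mlp_activation': 'relu'}}})

    model = CNNEncoder(cfg=cfg)
    # freeze_network() sets requires_grad=False on every module in freeze_net_list,
    # so the encoder's weights receive no gradient updates when it is shared with a
    # policy/value network (e.g. the self.se passed to ActorAMP/Critic above).
    assert all(not p.requires_grad for p in model.backbone_net.parameters())
    assert all(not p.requires_grad for p in model.output_net.parameters())

Note that load_ckpt() also calls eval() on each restored module, so a pretrained
encoder starts in inference mode (batch norm and dropout disabled) until train()
is called on it again.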