import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
import argparse
import numpy as np
import gym
from gym.wrappers import TimeLimit, Monitor
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper
if __name__ == "__main__":
    # placeholder entry point; the training loop is not implemented yet
    pass
class MultiEnv(VecEnvWrapper):
    """Vec-env wrapper that converts numpy observations/rewards into torch tensors."""

    def __init__(self, venv, device):
        super().__init__(venv)
        self.device = device

    def reset(self):
        obs = self.venv.reset()
        # convert observations to a float tensor on the target device
        obs = torch.from_numpy(obs).float().to(self.device)
        return obs

    def step_async(self, actions):
        # actions arrive as a torch tensor; the underlying vec env expects numpy
        actions = actions.cpu().numpy()
        self.venv.step_async(actions)

    def step_wait(self):
        obs, reward, done, info = self.venv.step_wait()
        obs = torch.from_numpy(obs).float().to(self.device)
        # shape rewards as a column vector, one entry per sub-environment
        reward = torch.from_numpy(reward).unsqueeze(dim=1).float()
        return obs, reward, done, info
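
# Minimal usage sketch (not part of the original file): MultiEnv wrapping a
# stable-baselines3 DummyVecEnv so that reset/step speak torch tensors. The helper
# name `_demo_multienv`, the CartPole id, and the two-env setup are illustrative
# assumptions; nothing here runs at import time.
def _demo_multienv(device):
    venv = DummyVecEnv([lambda: gym.make("CartPole-v0") for _ in range(2)])
    venv = MultiEnv(venv, device)
    obs = venv.reset()                           # float tensor of shape (2, 4) on `device`
    actions = torch.zeros(2, dtype=torch.long)   # one discrete action per sub-env
    obs, reward, done, info = venv.step(actions) # reward comes back as a (2, 1) float tensor
    return obs, reward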
def make_env(gym_id, seed, idx):
    """Create a single seeded environment with episode-statistics recording."""
    env = gym.make(gym_id)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    #! add video capture
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    return env
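
# One way the "add video capture" TODO above might be filled in (a sketch, not the
# author's implementation): record video only for the first sub-environment, using
# the Monitor wrapper imported at the top (present in older gym versions; newer gym
# replaces it with RecordVideo). The `videos/` directory and the idx == 0 convention
# are assumptions.
def make_env_with_video(gym_id, seed, idx):
    env = gym.make(gym_id)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    if idx == 0:
        # Monitor writes .mp4 files to the given directory; force=True overwrites old runs
        env = Monitor(env, f"videos/{gym_id}", force=True)
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    return env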
class Agent(nn.Module):
    """Actor-critic network: a tanh MLP policy head and a tanh MLP value head."""

    def __init__(self, envs):
        super().__init__()
        # policy network: maps observations to unnormalized action logits
        self.actor = nn.Sequential(
            self.layer_init(nn.Linear(np.array(envs.observation_space.shape).prod(), 64)),
            nn.Tanh(),
            self.layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            # small init std so the initial policy is close to uniform
            self.layer_init(nn.Linear(64, envs.action_space.n), std=0.01),
        )
        # value network: maps observations to a scalar state-value estimate
        self.critic = nn.Sequential(
            self.layer_init(nn.Linear(np.array(envs.observation_space.shape).prod(), 64)),
            nn.Tanh(),
            self.layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            self.layer_init(nn.Linear(64, 1), std=1.),
        )

    def layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        # orthogonal weight init with the given gain, zero bias
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer
    def get_action(self, x, action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        # if action is not specified, we select it stochastically from the current policy
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy()

    def get_value(self, x):
        return self.critic(x)
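
# Sketch of how the Agent's outputs would typically feed the PPO clipped objective
# (an assumption about the eventual training loop, which this file does not yet contain).
# `mb_obs`, `mb_actions`, `mb_logprobs`, `mb_advantages`, `mb_returns`, and the
# coefficients are hypothetical minibatch tensors / hyperparameters.
def ppo_loss(agent, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns,
             clip_coef=0.2, vf_coef=0.5, ent_coef=0.01):
    # re-evaluate the stored actions under the current policy
    _, new_logprob, entropy = agent.get_action(mb_obs, mb_actions)
    ratio = (new_logprob - mb_logprobs).exp()  # pi_new(a|s) / pi_old(a|s)
    # clipped surrogate policy loss (sign flipped because we minimize)
    pg_loss1 = -mb_advantages * ratio
    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    pg_loss = torch.max(pg_loss1, pg_loss2).mean()
    # value loss against empirical returns
    new_value = agent.get_value(mb_obs).view(-1)
    v_loss = F.mse_loss(new_value, mb_returns)
    # entropy bonus encourages exploration
    entropy_loss = entropy.mean()
    return pg_loss + vf_coef * v_loss - ent_coef * entropy_loss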
# device selection and a single-environment instantiation for quick testing
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = gym.make("CartPole-v0")
agent = Agent(env).to(device)
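
# Quick smoke-test sketch (an assumption, not part of the original script): run the
# untrained agent for one step to check shapes. Guarded so it only executes when the
# file is run directly; assumes the older gym API where reset() returns just the observation
# and step() returns a 4-tuple, matching the env.seed() calls used above.
if __name__ == "__main__":
    obs = env.reset()
    obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        action, logprob, entropy = agent.get_action(obs_t)
        value = agent.get_value(obs_t)
    next_obs, reward, done, info = env.step(action.item())
    print(f"action={action.item()}, value={value.item():.3f}, reward={reward}")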