[Week 08] Update & clean up the notebook
dniku committed Apr 2, 2023
1 parent 9137d39 commit 47c3325
Showing 1 changed file with 54 additions and 38 deletions.
92 changes: 54 additions & 38 deletions week08_pomdp/practice_pytorch.ipynb
@@ -6,19 +6,25 @@
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"if 'google.colab' in sys.modules:\n",
" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/atari_util.py\n",
" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/0ccb0673965dd650d9b284e1ec90c2bfd82c8a94/week08_pomdp/env_pool.py\n",
"\n",
" !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
" !touch .setup_complete\n",
"# If you are running on a server, launch xvfb to record game videos\n",
"# Please make sure you have xvfb installed\n",
"import os\n",
"if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
" !bash ../xvfb start\n",
" os.environ['DISPLAY'] = ':1'"
"import sys, os\n",
"if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
" # Install xvfb and our launcher script for it\n",
" !apt-get install -y xvfb\n",
" !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/xvfb -O ../xvfb\n",
"\n",
" !pip install gym[atari,accept-rom-license]\n",
"\n",
" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/atari_util.py\n",
" !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/env_pool.py\n",
"\n",
" !touch .setup_complete\n",
"\n",
"# This code creates a virtual display to draw game images on.\n",
"# It will have no effect if your machine has a monitor.\n",
"import os\n",
"if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
" !bash ../xvfb start\n",
" os.environ['DISPLAY'] = ':1'"
]
},
{
@@ -53,7 +59,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
"Observation shape: (1, 42, 42)\n",
"Num actions: 14\n",
"Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']\n"
@@ ... @@
" env = PreprocessAtari(env, height=42, width=42,\n",
" crop=lambda img: img[60:-30, 15:],\n",
" color=False, n_frames=1)\n",
" env.metadata['render_fps'] = 30\n",
" return env\n",
"\n",
"\n",
@@ -143,7 +149,7 @@
"\n",
"Let's design another agent that has a recurrent neural net memory to solve this. Here's a sketch.\n",
"\n",
"![img](img1.jpg)\n"
"![img](https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/img1.jpg)\n"
]
},
{
@@ -204,13 +210,15 @@
" return new_state, (logits, state_value)\n",
"\n",
" def get_initial_state(self, batch_size):\n",
" \"\"\"Return a list of agent memory states at game start. Each state is a np array of shape [batch_size, ...]\"\"\"\n",
" return torch.zeros((batch_size, 128)), torch.zeros((batch_size, 128))\n",
" \"\"\"Return the agent memory state at the beginning of the game. Each state is a np array of shape [batch_size, ...]\"\"\"\n",
" h0 = torch.zeros((batch_size, 128))\n",
" c0 = torch.zeros((batch_size, 128))\n",
" return h0, c0\n",
"\n",
" def sample_actions(self, agent_outputs):\n",
" \"\"\"pick actions given numeric agent outputs (np arrays)\"\"\"\n",
" logits, state_values = agent_outputs\n",
" probs = F.softmax(logits)\n",
" probs = F.softmax(logits, dim=-1)\n",
" return torch.multinomial(probs, 1)[:, 0].data.numpy()\n",
"\n",
" def step(self, prev_state, obs_t):\n",
@@ -258,11 +266,13 @@
"metadata": {},
"outputs": [],
"source": [
"import tqdm\n",
"\n",
"def evaluate(agent, env, n_games=1):\n",
" \"\"\"Plays an entire game start to end, returns session rewards.\"\"\"\n",
"\n",
" game_rewards = []\n",
" for _ in range(n_games):\n",
" for _ in tqdm.notebook.trange(n_games):\n",
" # initial observation and memory\n",
" observation = env.reset()\n",
" prev_memories = agent.get_initial_state(1)\n",
@@ -292,7 +302,7 @@
"source": [
"import gym.wrappers\n",
"\n",
"with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n",
"with gym.wrappers.RecordVideo(make_env(), video_folder=\"videos\") as env_monitor:\n",
" rewards = evaluate(agent, env_monitor, n_games=3)\n",
"\n",
"print(rewards)"
@@ -336,7 +346,7 @@
"### Training on parallel games\n",
"\n",
"We introduce a class called EnvPool - it's a tool that handles multiple environments for you. Here's how it works:\n",
"![img](img2.jpg)"
"![img](https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/img2.jpg)"
]
},
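The real implementation lives in `env_pool.py`, downloaded in the setup cell. Conceptually it steps a batch of environments in lockstep with a single recurrent agent and stacks the per-step results into `[n_envs, time]` arrays, which are exactly the rollouts introduced in the next cell. The sketch below illustrates the idea only; it is not the actual `EnvPool` API, and the agent call signature is borrowed from the class fragments shown earlier.

```python
import numpy as np
import torch

def interact_sketch(agent, envs, observations, memory, n_steps):
    """Toy stand-in for an environment pool: step several envs in lockstep with one agent."""
    history = {"obs": [], "actions": [], "rewards": [], "not_done": []}
    for _ in range(n_steps):
        obs_batch = torch.tensor(np.stack(observations), dtype=torch.float32)
        # a single forward pass serves every environment in the batch
        memory, readouts = agent(memory, obs_batch)
        actions = agent.sample_actions(readouts)

        history["obs"].append(np.stack(observations))
        history["actions"].append(actions)

        new_observations, rewards, not_done = [], [], []
        for env, action in zip(envs, actions):
            obs, reward, done, _ = env.step(action)  # old-style gym step, as elsewhere in this notebook
            if done:
                obs = env.reset()  # restart finished games so the batch never shrinks
                # (the real EnvPool also resets that environment's slice of the agent memory)
            new_observations.append(obs)
            rewards.append(reward)
            not_done.append(not done)
        observations = new_observations
        history["rewards"].append(rewards)
        history["not_done"].append(not_done)

    # stack time along axis 1 so every array has shape [n_envs, n_steps, ...]
    return memory, {k: np.swapaxes(np.array(v), 0, 1) for k, v in history.items()}
```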
{
@@ ... @@
"metadata": {},
"source": [
"We gonna train our agent on a thing called __rollouts:__\n",
"![img](img3.jpg)\n",
"![img](https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/img3.jpg)\n",
"\n",
"A rollout is just a sequence of T observations, actions and rewards that agent took consequently.\n",
"* First __s0__ is not necessarily initial state for the environment\n",
@@ -446,7 +456,7 @@
"source": [
"def to_one_hot(y, n_dims=None):\n",
" \"\"\" Take an integer tensor and convert it to 1-hot matrix. \"\"\"\n",
" y_tensor = y.to(dtype=torch.int64).view(-1, 1)\n",
" y_tensor = y.to(dtype=torch.int64).reshape(-1, 1)\n",
" n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1\n",
" y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)\n",
" return y_one_hot"
@@ ... @@
" states = torch.tensor(np.asarray(states), dtype=torch.float32)\n",
" actions = torch.tensor(np.array(actions), dtype=torch.int64) # shape: [batch_size, time]\n",
" rewards = torch.tensor(np.array(rewards), dtype=torch.float32) # shape: [batch_size, time]\n",
" is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.float32) # shape: [batch_size, time]\n",
" is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.bool) # shape: [batch_size, time]\n",
" rollout_length = rewards.shape[1] - 1\n",
"\n",
" # predict logits, probas and log-probas using an agent.\n",
@@ ... @@
" for t in range(rewards.shape[1]):\n",
" obs_t = states[:, t]\n",
"\n",
" # use agent to comute logits_t and state values_t.\n",
" # use agent to compute logits_t and state values_t.\n",
" # append them to logits and state_values array\n",
"\n",
" memory, (logits_t, values_t) = <YOUR CODE>\n",
@@ -521,9 +531,10 @@
" V_next = state_values[:, t + 1].detach() # next state values\n",
" # log-probability of a_t in s_t\n",
" logpi_a_s_t = logprobas_for_actions[:, t]\n",
" is_not_done_t = is_not_done[:, t]\n",
"\n",
" # update G_t = r_t + gamma * G_{t+1} as we did in week6 reinforce\n",
" cumulative_returns = G_t = r_t + gamma * cumulative_returns\n",
" cumulative_returns = G_t = r_t + torch.where(is_not_done_t, gamma * cumulative_returns, 0)\n",
"\n",
" # Compute temporal difference error (MSE for V(s))\n",
" value_loss += <YOUR CODE>\n",
@@ -579,7 +590,6 @@
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"from tqdm import trange\n",
"from pandas import DataFrame\n",
"moving_average = lambda x, **kw: DataFrame(\n",
" {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n",
@@ ... @@
"metadata": {},
"outputs": [],
"source": [
"for i in trange(15000):\n",
"log_every = 100\n",
"\n",
"for i in tqdm.trange(15000):\n",
" # tqdm.notebook.tqdm is not trivial to use here because clear_output(True)\n",
" # also removes the tqdm widget\n",
"\n",
" memory = list(pool.prev_memory_states)\n",
" rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(\n",
" 10)\n",
" train_on_rollout(rollout_obs, rollout_actions,\n",
" rollout_rewards, rollout_mask, memory)\n",
" rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(10)\n",
" train_on_rollout(rollout_obs, rollout_actions, rollout_rewards, rollout_mask, memory)\n",
"\n",
" if i % 100 == 0:\n",
" if i % log_every == 0:\n",
" rewards_history.append(np.mean(evaluate(agent, env, n_games=1)))\n",
" clear_output(True)\n",
" plt.plot(rewards_history, label='rewards')\n",
" plt.plot(moving_average(np.array(rewards_history),\n",
" span=10), label='rewards ewma@10')\n",
" plt.plot(\n",
" np.arange(len(rewards_history)) * log_every,\n",
" rewards_history, label='rewards')\n",
" plt.plot(\n",
" np.arange(len(rewards_history)) * log_every,\n",
" moving_average(np.array(rewards_history), span=10), label='rewards ewma@10')\n",
" plt.legend()\n",
" plt.grid()\n",
" plt.show()\n",
" if rewards_history[-1] >= 10000:\n",
" print(\"Your agent has just passed the minimum homework threshold\")\n",
@@ ... @@
"Since we use a policy-based method, we also keep track of __policy entropy__ - the same one you used as a regularizer. The only important thing about it is that your entropy shouldn't drop too low (`< 0.1`) before your agent gets the yellow belt. Or at least it can drop there, but _it shouldn't stay there for long_.\n",
"\n",
"If it does, the culprit is likely:\n",
"* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot log p(a_i) $\n",
"* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot \\log p(a_i) $\n",
"* Your agent architecture converges too fast. Increase entropy coefficient in actor loss. \n",
"* Gradient explosion - just [clip gradients](https://stackoverflow.com/a/56069467) and maybe use a smaller network\n",
"* Us. Or PyTorch developers. Or aliens. Or lizardfolk. Contact us on forums before it's too late!\n",
@@ ... @@
"source": [
"import gym.wrappers\n",
"\n",
"with gym.wrappers.Monitor(make_env(), directory=\"videos\", force=True) as env_monitor:\n",
"with gym.wrappers.RecordVideo(make_env(), video_folder=\"videos\") as env_monitor:\n",
" final_rewards = evaluate(agent, env_monitor, n_games=20)\n",
"\n",
"print(\"Final mean reward\", np.mean(final_rewards))"