From 4b268db25ed514c2139584b37a3c29c2077fed13 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 30 Oct 2023 09:41:50 -0700 Subject: [PATCH] Release notebook for TRICE paper. PiperOrigin-RevId: 577872975 --- .github/workflows/pytest_and_autopublish.yml | 2 +- cascades/examples/notebooks/trice.ipynb | 1003 ++++++++++++++++++ 2 files changed, 1004 insertions(+), 1 deletion(-) create mode 100644 cascades/examples/notebooks/trice.ipynb diff --git a/.github/workflows/pytest_and_autopublish.yml b/.github/workflows/pytest_and_autopublish.yml index 35b6ecd..f6eed35 100644 --- a/.github/workflows/pytest_and_autopublish.yml +++ b/.github/workflows/pytest_and_autopublish.yml @@ -18,7 +18,7 @@ jobs: # Install deps - uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: 3.9 - run: pip --version - run: pip install -e .[dev] - run: pip freeze diff --git a/cascades/examples/notebooks/trice.ipynb b/cascades/examples/notebooks/trice.ipynb new file mode 100644 index 0000000..49bcc71 --- /dev/null +++ b/cascades/examples/notebooks/trice.ipynb @@ -0,0 +1,1003 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "riiNPfaT5qaj" + }, + "source": [ + "This colab demonstrates the main TRICE algorithm in the paper: Training Chain-of-Thought via Latent-Variable Inference ([openreview](https://openreview.net/forum?id=a147pIS2Co))." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EyzbCM02ArpA" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import requests\n", + "from tensorflow_probability.substrates import jax as tfp\n", + "\n", + "# " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z0pIKg-l6E6Y" + }, + "source": [ + "The code in this colab interfaces with an LLM backend via four functions: `sample`, `log_prob`, `grad`, and `init`. Signatures for these functions are below; to run this colab, you will need to provide callables that implement these functions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "v_Is6BkgBJOx" + }, + "outputs": [], + "source": [ + "# pylint:disable=unused-argument\n", + "def sample(params, context, num_steps, seed):\n", + " \"\"\"Draws a sample continuation.\n", + "\n", + " Args:\n", + " params: PyTree params.\n", + " context: a Python string or a list of Python strings.\n", + " num_steps: The maximum number of tokens to generate.\n", + " seed: The random seed.\n", + "\n", + " Returns:\n", + " A Python string or a list of Python strings.\n", + " \"\"\"\n", + " raise NotImplementedError\n", + "\n", + "def log_prob(params, context, continuation):\n", + " \"\"\"Computes the log-probability of generating continuation.\n", + "\n", + " Args:\n", + " params: PyTree params.\n", + " context: a Python string or a list of Python strings.\n", + " continuation: a target string or a list of target strings to get log\n", + " probability with respect to.\n", + "\n", + " Returns:\n", + " A scalar or a vector array.\n", + " \"\"\"\n", + " raise NotImplementedError\n", + "\n", + "def grad(params, context, continuation):\n", + " \"\"\"Computes the gradient of log_prob w.r.t. params.\n", + "\n", + " Args:\n", + " params: PyTree params.\n", + " context: a Python string or a list of Python strings.\n", + " continuation: a target string or a list of target strings to get gradient of\n", + " log_prob with respect to.\n", + "\n", + " Returns:\n", + " A PyTree params corresponding to the gradient of log_prob w.r.t. params. 
If\n", + " a list of string inputs are provided, the gradient will have a corresponding\n", + " batch dimension on the left.\n", + " \"\"\"\n", + " raise NotImplementedError\n", + "\n", + "def init(seed):\n", + " \"\"\"Returns an initial value for a PyTree params.\n", + "\n", + " Args:\n", + " seed: The seed to generate. This could be an optional text to use\n", + " soft-prompt embedding as the initial value for the parameter.\n", + "\n", + " Returns:\n", + " A PyTree params.\n", + " \"\"\"\n", + " raise NotImplementedError\n", + "# pylint:enable=unused-argument\n", + "\n", + "def get_dataset(seed=None):\n", + " \"\"\"Get train and validation datasets.\n", + "\n", + " Args:\n", + " seed: The random seed to shuffle the training dataset.\n", + "\n", + " Returns:\n", + " A tuple of (train_questions, train_answers, val_questions, val_answers).\n", + " \"\"\"\n", + " task = \"logical_deduction_three_objects\"\n", + " data_url = f\"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/bbh/{task}.json\"\n", + " examples = requests.get(data_url).json()[\"examples\"]\n", + " if seed is not None:\n", + " permutation = jax.random.permutation(seed, len(examples))\n", + " examples = [examples[int(i)] for i in permutation]\n", + " questions = [ex[\"input\"] for ex in examples]\n", + " answers = [ex[\"target\"] for ex in examples]\n", + " train_questions, train_answers = questions[:150], answers[:150]\n", + " val_questions, val_answers = questions[150:], answers[150:]\n", + " return train_questions, train_answers, val_questions, val_answers\n", + "\n", + "# " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uuoBZgSB7lBD" + }, + "outputs": [], + "source": [ + "RATIONALE_TEMPLATE = \"\"\"Question: {question}\n", + "Answer: Let's think step by step.\n", + "\"\"\"\n", + "GUIDE_TEMPLATE = \"\"\"Question: {question}\n", + "Answer: The answer is {answer}. 
Let's think step by step.\n", + "\"\"\"\n", + "SEED = 0\n", + "LEARNING_RATE = 1.0\n", + "TRAIN_STEPS = 100\n", + "BATCH_SIZE = 64\n", + "GRADIENT_SUBSAMPLE_SIZE = 64\n", + "LOG_EVERY = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sEITBL605dTU" + }, + "outputs": [], + "source": [ + "def is_correct(rationale, answer):\n", + " \"\"\"Checks whether the rationale is correct.\"\"\"\n", + " return rationale.endswith(\"the answer is \" + answer + \".\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "h_VvQMpRu8SB" + }, + "outputs": [], + "source": [ + "def make_cot_prompt(questions, answers, seed):\n", + " \"\"\"Generates the COT soft prompt from the questions and answers.\"\"\"\n", + " qs, rs = [], []\n", + " for i, (q, a) in enumerate(zip(questions, answers)):\n", + " seed_i = jax.random.fold_in(seed, i)\n", + " r = sample(params=None, context=RATIONALE_TEMPLATE.format(question=q), seed=seed_i)\n", + " if is_correct(r, a):\n", + " qs.append(q)\n", + " rs.append(r)\n", + " if len(qs) == 3:\n", + " break\n", + " assert len(qs) == 3, \"Not enough examples to construct cot prompt.\"\n", + " return \"\".join([RATIONALE_TEMPLATE.format(question=q) + r + \"\\n\\n\" for q, r in zip(qs, rs)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EyRqVGPMu6Sp" + }, + "outputs": [], + "source": [ + "def get_init_memory(fewshots_prompt, train_questions, train_answers, seed):\n", + " \"\"\"Obtains the initial memory.\"\"\"\n", + " examples = fewshots_prompt.strip().split(\"\\n\\n\")\n", + " guide_prompt = \"\"\n", + " for ex in examples:\n", + " q, r = ex.split(\"\\nAnswer: Let's think step by step.\\n\")\n", + " q = q.split(\"Question: \")[-1]\n", + " a = r.split(\"the answer is \")[-1][:-1]\n", + " guide_prompt += GUIDE_TEMPLATE.format(question=q, answer=a) + r + \"\\n\\n\"\n", + " print(\"GUIDE PROMPT:\", guide_prompt)\n", + " context = [guide_prompt + GUIDE_TEMPLATE.format(question=q, answer=a)\n", + " for q, a in zip(train_questions, train_answers)]\n", + " memory = sample(params=None, context=context, seed=seed)\n", + " return memory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uiHL-O6HvK-K" + }, + "outputs": [], + "source": [ + "def evaluate(params, val_questions, val_answers):\n", + " \"\"\"Evaluates the model accuracy given the current params.\"\"\"\n", + " context = [RATIONALE_TEMPLATE.format(question=q) for q in val_questions]\n", + " rationales = sample(params, context=context, temperature=0.0, seed=jax.random.PRNGKey(0))\n", + " correct_rationales = [r for r, a in zip(rationales, val_answers) if is_correct(r, a)]\n", + " print(\"Val accuracy:\", len(correct_rationales) / len(val_answers), flush=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aWSqVvQ7usNx" + }, + "outputs": [], + "source": [ + "def trice_loss(params, memory, seed, questions, answers):\n", + " \"\"\"Computes TRICE objective and its gradients.\"\"\"\n", + " subsample_seed, sample_seed = jax.random.split(seed)\n", + " context = [RATIONALE_TEMPLATE.format(question=q) for q in questions]\n", + " proposal = sample(params, context=context, seed=sample_seed)\n", + " is_proposal_correct = jnp.stack([is_correct(r, a) for r, a in zip(proposal, answers)])\n", + " is_memory_correct = jnp.stack([is_correct(m, a) for m, a in zip(memory, answers)])\n", + " new_memory = [r if accept else m for r, m, accept in zip(proposal, memory, 
is_proposal_correct)]\n", + " mask = is_proposal_correct | is_memory_correct\n", + " correlation_est = (is_proposal_correct.sum() - is_proposal_correct) / (mask.sum() - 1 + 1e-10)\n", + "\n", + " # compute weight contributions of rationales from both memory and proposal.\n", + " weights_memory = mask * (1 - correlation_est * is_proposal_correct)\n", + " weights_proposal = mask * correlation_est * (1 - is_proposal_correct)\n", + " flat_weights = jnp.concatenate([weights_memory, weights_proposal])\n", + " flat_rationales = new_memory + proposal\n", + " flat_contexts = context + context\n", + " flat_signs = jnp.concatenate([mask, -1 * mask])\n", + " flat_weights = jnp.clip(flat_weights, a_min=1e-10)\n", + " weights_mean = flat_weights.sum() / (mask.sum() + 1e-10)\n", + " # Note: to compute the loss without subsampling, we can set\n", + " # `subsampled_indices = jnp.arange(2 * len(questions))`\n", + " subsampled_indices = tfp.experimental.mcmc.resample_systematic(\n", + " jax.nn.log_softmax(flat_weights), GRADIENT_SUBSAMPLE_SIZE, (), subsample_seed)\n", + " subsampled_rationales = [flat_rationales[i] for i in subsampled_indices]\n", + " subsampled_contexts = [flat_contexts[i] for i in subsampled_indices]\n", + " subsampled_signs = jnp.stack([flat_signs[i] for i in subsampled_indices])\n", + " subsampled_log_probs = log_prob(params, subsampled_contexts, subsampled_rationales)\n", + " subsampled_grads = grad(params, subsampled_contexts, subsampled_rationales)\n", + " loss = -(subsampled_log_probs * subsampled_signs).mean() * weights_mean\n", + " params_grad = jax.tree_util.tree_map(\n", + " lambda g: -(jnp.moveaxis(g, 0, -1) * subsampled_signs).mean(-1) * weights_mean, subsampled_grads)\n", + " return (loss, new_memory, is_proposal_correct.mean(), mask.mean()), params_grad" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RH1pDpG5wXhD" + }, + "outputs": [], + "source": [ + "def fit(optimizer, params, memory, train_questions, train_answers, val_questions, val_answers, seed):\n", + " \"\"\"Runs the optimization loop.\"\"\"\n", + " opt_state = optimizer.init(params)\n", + " permute_seed, loss_seed = jax.random.split(seed)\n", + " num_batches_per_epoch = len(train_questions) // BATCH_SIZE\n", + " memory = memory.copy()\n", + " for i in range(TRAIN_STEPS):\n", + " # Get a mini-batch of questions and answers.\n", + " permute_seed_i = jax.random.fold_in(permute_seed, i // num_batches_per_epoch)\n", + " permutation = jax.random.permutation(permute_seed_i, len(train_questions))\n", + " batch_indices = permutation[i % num_batches_per_epoch::num_batches_per_epoch][:BATCH_SIZE]\n", + " batch_questions = [train_questions[int(i)] for i in batch_indices]\n", + " batch_answers = [train_answers[int(i)] for i in batch_indices]\n", + " batch_memory = [memory[int(i)] for i in batch_indices]\n", + "\n", + " loss_seed_i = jax.random.fold_in(loss_seed, i)\n", + " (loss, batch_memory, proposal_acc, memory_acc), grads = trice_loss(\n", + " params, batch_memory, loss_seed_i, batch_questions, batch_answers)\n", + " for j, idx in enumerate(batch_indices):\n", + " memory[idx] = batch_memory[j]\n", + " updates, opt_state = optimizer.update(grads, opt_state, params)\n", + " params = optax.apply_updates(params, updates)\n", + " if (i + 1) % LOG_EVERY == 0:\n", + " print(f\"Step {i + 1}: loss {loss:.4f} | proposal_acc {proposal_acc:.4f} \"\n", + " f\"| memory_acc {memory_acc:.4f}\", flush=True)\n", + " evaluate(params, val_questions, val_answers)\n", + " return params, memory" + ] + }, + { 
+ "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "height": 523 + }, + "id": "5qtGTcw7E8O_", + "outputId": "62141a0e-9e80-4161-a198-3b7040456df2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GUIDE PROMPT: Question: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: cantaloupes, oranges, and watermelons. The oranges are the most expensive. The cantaloupes are more expensive than the watermelons.\n", + "Options:\n", + "(A) The cantaloupes are the cheapest\n", + "(B) The oranges are the cheapest\n", + "(C) The watermelons are the cheapest\n", + "Answer: The answer is (C). Let's think step by step.\n", + "The oranges are the most expensive. The cantaloupes are more expensive than the watermelons. So, the fruit that is the cheapest here is the watermelon (as compared to orange and cantaloupe, the watermelon is cheapest).\n", + "Thus, the answer is (C).\n", + "\n", + "Question: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: kiwis, loquats, and cantaloupes. The kiwis are less expensive than the loquats. The cantaloupes are less expensive than the kiwis.\n", + "Options:\n", + "(A) The kiwis are the cheapest\n", + "(B) The loquats are the cheapest\n", + "(C) The cantaloupes are the cheapest\n", + "Answer: The answer is (C). Let's think step by step.\n", + "The cantaloupes are less expensive than the kiwis (which are less expensive than the loquats). The cantaloupes are the cheapest.\n", + "So, the answer is (C).\n", + "\n", + "Question: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: oranges, mangoes, and plums. The plums are less expensive than the mangoes. The oranges are less expensive than the plums.\n", + "Options:\n", + "(A) The oranges are the cheapest\n", + "(B) The mangoes are the cheapest\n", + "(C) The plums are the cheapest\n", + "Answer: The answer is (A). Let's think step by step.\n", + "The oranges are less expensive than the plums. The Plums are less expensive than the mangoes. 
The oranges are the cheapest.\n", + "So, the answer is (A).\n", + "\n", + "\n" + ] + } + ], + "source": [ + "data_seed, cot_seed, params_seed, memory_seed, fit_seed = jax.random.split(jax.random.PRNGKey(SEED), 5)\n", + "full_train_questions, full_train_answers, full_val_questions, full_val_answers = get_dataset(seed=data_seed)\n", + "cot_prompt = make_cot_prompt(full_train_questions, full_train_answers, cot_seed)\n", + "init_params = init(cot_prompt)\n", + "init_memory = get_init_memory(cot_prompt, full_train_questions, full_train_answers, memory_seed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "height": 1000 + }, + "id": "MZmKIu0MGvpX", + "outputId": "98d1432c-fc72-4e12-cc1a-e7f8b01d6232" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 1: loss 1.2668 | proposal_acc 0.4688 | memory_acc 1.0000\n", + "Val accuracy: 0.5\n", + "Step 2: loss 6.6766 | proposal_acc 0.4219 | memory_acc 1.0000\n", + "Val accuracy: 0.58\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 3: loss 2.7884 | proposal_acc 0.5625 | memory_acc 1.0000\n", + "Val accuracy: 0.64\n", + "Step 4: loss 0.6468 | proposal_acc 0.6250 | memory_acc 1.0000\n", + "Val accuracy: 0.67\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 5: loss 4.4672 | proposal_acc 0.5781 | memory_acc 1.0000\n", + "Val accuracy: 0.66\n", + "Step 6: loss 4.5081 | proposal_acc 0.5625 | memory_acc 1.0000\n", + "Val accuracy: 0.65\n", + "Step 7: loss 2.0394 | proposal_acc 0.6094 | memory_acc 1.0000\n", + "Val accuracy: 0.66\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 8: loss 1.4886 | proposal_acc 0.5625 | memory_acc 1.0000\n", + "Val accuracy: 0.69\n", + "Step 9: loss 1.9745 | proposal_acc 0.6094 | memory_acc 1.0000\n", + "Val accuracy: 0.71\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 10: loss 0.5645 | proposal_acc 0.6406 | memory_acc 1.0000\n", + "Val accuracy: 0.7\n", + "Step 11: loss 3.4187 | proposal_acc 0.5312 | memory_acc 1.0000\n", + "Val accuracy: 0.69\n", + "Step 12: loss 2.1898 | proposal_acc 0.6250 | memory_acc 1.0000\n", + "Val accuracy: 0.71\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 13: loss 3.2672 | proposal_acc 0.6719 | memory_acc 1.0000\n", + "Val accuracy: 0.73\n", + "Step 14: loss 3.6095 | proposal_acc 0.5938 | memory_acc 1.0000\n", + "Val accuracy: 0.69\n", + "Step 15: loss 1.6829 | proposal_acc 0.7031 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.69\n", + "Step 16: loss 2.8023 | proposal_acc 0.5938 | memory_acc 1.0000\n", + "Val accuracy: 0.67\n", + "Step 17: loss 2.5703 | proposal_acc 0.6875 | memory_acc 1.0000\n", + "Val accuracy: 0.65\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 18: loss 1.9247 | proposal_acc 0.6406 | memory_acc 1.0000\n", + "Val accuracy: 0.64\n", + "Step 19: loss 3.4522 | proposal_acc 0.5781 | memory_acc 1.0000\n", + "Val accuracy: 0.64\n", + "Step 20: loss 5.3423 | proposal_acc 0.6562 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.67\n", + "Step 21: loss 0.7300 | proposal_acc 0.7500 | memory_acc 1.0000\n", + "Val accuracy: 0.69\n", + "Step 22: loss 0.6918 | proposal_acc 0.6875 | memory_acc 1.0000\n", + "Val accuracy: 0.71\n" + ] + }, + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Step 23: loss 1.0869 | proposal_acc 0.6875 | memory_acc 1.0000\n", + "Val accuracy: 0.72\n", + "Step 24: loss 3.3732 | proposal_acc 0.7344 | memory_acc 1.0000\n", + "Val accuracy: 0.71\n", + "Step 25: loss 1.8571 | proposal_acc 0.7500 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.69\n", + "Step 26: loss 1.8002 | proposal_acc 0.7500 | memory_acc 1.0000\n", + "Val accuracy: 0.69\n", + "Step 27: loss 0.3136 | proposal_acc 0.7656 | memory_acc 1.0000\n", + "Val accuracy: 0.7\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 28: loss -0.4348 | proposal_acc 0.7344 | memory_acc 1.0000\n", + "Val accuracy: 0.73\n", + "Step 29: loss 0.6007 | proposal_acc 0.7656 | memory_acc 1.0000\n", + "Val accuracy: 0.75\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 30: loss -0.2545 | proposal_acc 0.7500 | memory_acc 1.0000\n", + "Val accuracy: 0.75\n", + "Step 31: loss -0.3139 | proposal_acc 0.8750 | memory_acc 1.0000\n", + "Val accuracy: 0.76\n", + "Step 32: loss 0.3814 | proposal_acc 0.8438 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.76\n", + "Step 33: loss 1.5325 | proposal_acc 0.7188 | memory_acc 1.0000\n", + "Val accuracy: 0.75\n", + "Step 34: loss 0.3954 | proposal_acc 0.7969 | memory_acc 1.0000\n", + "Val accuracy: 0.76\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 35: loss 0.2625 | proposal_acc 0.7344 | memory_acc 1.0000\n", + "Val accuracy: 0.78\n", + "Step 36: loss 0.1536 | proposal_acc 0.8438 | memory_acc 1.0000\n", + "Val accuracy: 0.77\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 37: loss -0.0219 | proposal_acc 0.7969 | memory_acc 1.0000\n", + "Val accuracy: 0.78\n", + "Step 38: loss 0.3998 | proposal_acc 0.8125 | memory_acc 1.0000\n", + "Val accuracy: 0.78\n", + "Step 39: loss 1.6071 | proposal_acc 0.7969 | memory_acc 1.0000\n", + "Val accuracy: 0.76\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 40: loss 0.6734 | proposal_acc 0.7656 | memory_acc 1.0000\n", + "Val accuracy: 0.74\n", + "Step 41: loss -0.4060 | proposal_acc 0.8750 | memory_acc 1.0000\n", + "Val accuracy: 0.76\n", + "Step 42: loss 0.3061 | proposal_acc 0.7969 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.76\n", + "Step 43: loss 0.2787 | proposal_acc 0.8125 | memory_acc 1.0000\n", + "Val accuracy: 0.76\n", + "Step 44: loss 0.9522 | proposal_acc 0.8281 | memory_acc 1.0000\n", + "Val accuracy: 0.78\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 45: loss 0.2043 | proposal_acc 0.8594 | memory_acc 1.0000\n", + "Val accuracy: 0.8\n", + "Step 46: loss 0.3309 | proposal_acc 0.8750 | memory_acc 1.0000\n", + "Val accuracy: 0.8\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 47: loss 0.4043 | proposal_acc 0.8906 | memory_acc 1.0000\n", + "Val accuracy: 0.8\n", + "Step 48: loss 0.2161 | proposal_acc 0.8594 | memory_acc 1.0000\n", + "Val accuracy: 0.8\n", + "Step 49: loss 0.6909 | proposal_acc 0.7344 | memory_acc 1.0000\n", + "Val accuracy: 0.79\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 50: loss -0.0147 | proposal_acc 0.7656 | memory_acc 1.0000\n", + "Val accuracy: 0.79\n", + "Step 51: loss 0.0737 | proposal_acc 0.9219 | memory_acc 
1.0000\n", + "Val accuracy: 0.78\n", + "Step 52: loss 0.4943 | proposal_acc 0.8125 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.81\n", + "Step 53: loss 0.8976 | proposal_acc 0.8438 | memory_acc 1.0000\n", + "Val accuracy: 0.8\n", + "Step 54: loss 0.2485 | proposal_acc 0.8906 | memory_acc 1.0000\n", + "Val accuracy: 0.81\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 55: loss 0.2710 | proposal_acc 0.8750 | memory_acc 1.0000\n", + "Val accuracy: 0.82\n", + "Step 56: loss 0.2302 | proposal_acc 0.8906 | memory_acc 1.0000\n", + "Val accuracy: 0.82\n", + "Step 57: loss 0.0981 | proposal_acc 0.8906 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.84\n", + "Step 58: loss 0.8216 | proposal_acc 0.8438 | memory_acc 1.0000\n", + "Val accuracy: 0.84\n", + "Step 59: loss 0.0988 | proposal_acc 0.8906 | memory_acc 1.0000\n", + "Val accuracy: 0.83\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 60: loss 1.1516 | proposal_acc 0.8438 | memory_acc 1.0000\n", + "Val accuracy: 0.83\n", + "Step 61: loss -0.1295 | proposal_acc 0.8125 | memory_acc 1.0000\n", + "Val accuracy: 0.85\n", + "Step 62: loss 0.5059 | proposal_acc 0.8906 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.84\n", + "Step 63: loss -0.0203 | proposal_acc 0.9375 | memory_acc 1.0000\n", + "Val accuracy: 0.83\n", + "Step 64: loss 0.3703 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.86\n", + "Step 65: loss 0.4292 | proposal_acc 0.8906 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.85\n", + "Step 66: loss -0.3071 | proposal_acc 0.9062 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n", + "Step 67: loss 0.0239 | proposal_acc 0.9375 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 68: loss 0.0374 | proposal_acc 0.8750 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n", + "Step 69: loss 0.2615 | proposal_acc 0.9375 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n", + "Step 70: loss 0.0776 | proposal_acc 0.9375 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.87\n", + "Step 71: loss -0.1293 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.85\n", + "Step 72: loss -0.1358 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 73: loss -0.0691 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 74: loss 0.0730 | proposal_acc 0.9688 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 75: loss 0.0903 | proposal_acc 0.9062 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.88\n", + "Step 76: loss -0.0191 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n", + "Step 77: loss -0.1886 | proposal_acc 0.8906 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 78: loss 0.1393 | proposal_acc 0.9062 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 79: loss 0.1810 | proposal_acc 0.9375 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n", + "Step 80: loss -0.0811 | proposal_acc 
0.9531 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 81: loss 0.1608 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 82: loss -0.3453 | proposal_acc 0.9531 | memory_acc 1.0000\n", + "Val accuracy: 0.89\n", + "Step 83: loss 0.4232 | proposal_acc 0.8750 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.89\n", + "Step 84: loss 0.1070 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 85: loss 0.1431 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 86: loss 0.5932 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 87: loss 0.0786 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 88: loss -0.0000 | proposal_acc 1.0000 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 89: loss -0.0293 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n", + "Step 90: loss -0.3033 | proposal_acc 0.9531 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 91: loss -0.1331 | proposal_acc 0.9062 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.88\n", + "Step 92: loss 0.0357 | proposal_acc 0.9844 | memory_acc 1.0000\n", + "Val accuracy: 0.86\n", + "Step 93: loss 0.3205 | proposal_acc 0.9219 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 94: loss 0.2064 | proposal_acc 0.9531 | memory_acc 1.0000\n", + "Val accuracy: 0.86\n", + "Step 95: loss 0.0000 | proposal_acc 1.0000 | memory_acc 1.0000\n", + "Val accuracy: 0.88\n", + "Step 96: loss 0.1271 | proposal_acc 0.9375 | memory_acc 1.0000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val accuracy: 0.87\n", + "Step 97: loss 0.2188 | proposal_acc 0.9688 | memory_acc 1.0000\n", + "Val accuracy: 0.86\n", + "Step 98: loss 0.6120 | proposal_acc 0.9062 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 99: loss -0.0757 | proposal_acc 0.9062 | memory_acc 1.0000\n", + "Val accuracy: 0.87\n", + "Step 100: loss 0.1217 | proposal_acc 0.9531 | memory_acc 1.0000\n", + "Val accuracy: 0.86\n" + ] + } + ], + "source": [ + "optax_optimizer = optax.adam(optax.cosine_decay_schedule(LEARNING_RATE, TRAIN_STEPS, 0.1))\n", + "last_params, last_memory = fit(optax_optimizer, init_params, init_memory, full_train_questions,\n", + " full_train_answers, full_val_questions, full_val_answers, fit_seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oQm-UDAGxx3T" + }, + "source": [ + "Let's inspect the last memory and the final params." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oW-aU-HwxRtC", + "outputId": "8fcb57cc-339a-432b-da3c-d3520063b537" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question 0: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. In a golf tournament, there were three golfers: Amy, Eli, and Eve. Eve finished above Amy. 
Eli finished below Amy.\n", + "Options:\n", + "(A) Amy finished second\n", + "(B) Eli finished second\n", + "(C) Eve finished second\n", + "Answer: (A)\n", + "Init memory: Amy finishing above Eve implies Amy finished second.\n", + "So, the answer is (A).\n", + "Last memory: Eve finished above Amy. Eli finished below Amy. Amy finished second in the golf tournament.\n", + "So, the answer is (A).\n", + "=====\n", + "Question 1: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: mangoes, watermelons, and kiwis. The watermelons are less expensive than the kiwis. The kiwis are the second-most expensive.\n", + "Options:\n", + "(A) The mangoes are the most expensive\n", + "(B) The watermelons are the most expensive\n", + "(C) The kiwis are the most expensive\n", + "Answer: (A)\n", + "Init memory: The watermelons are less expensive than the kiwis (which can be either most or second most expensive). The kiwis are the second-most expensive. The mangoes are the most expensive.\n", + "So, the answer is (A).\n", + "Last memory: The watermelons are the least expensive (i.e. cheapest) sothe mangoes are the most expensive.\n", + "So, the answer is (A).\n", + "=====\n", + "Question 2: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a shelf, there are three books: a blue book, an orange book, and a red book. The blue book is the rightmost. The orange book is the leftmost.\n", + "Options:\n", + "(A) The blue book is the leftmost\n", + "(B) The orange book is the leftmost\n", + "(C) The red book is the leftmost\n", + "Answer: (B)\n", + "Init memory: The orange book is the leftmost. The blue book is the rightmost. The red book is between them. So, the answer is (B).\n", + "Last memory: The blue book is the rightmost. The orangre book is the leftmost. The orange book is the leftmost.\n", + "So, the answer is (B).\n", + "=====\n", + "Question 3: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a branch, there are three birds: a robin, an owl, and a hawk. The robin is to the right of the hawk. The owl is the leftmost.\n", + "Options:\n", + "(A) The robin is the second from the left\n", + "(B) The owl is the second from the left\n", + "(C) The hawk is the second from the left\n", + "Answer: (C)\n", + "Init memory: The hawk is the second from the left because the robin is to the right of the hawk but the owl is the leftmost.\n", + "So, the answer is (C).\n", + "Last memory: The robin is to the right of the hawk. The owl is the leftmost. The hawk is the second from the left.\n", + "So, the answer is (C).\n", + "=====\n", + "Question 4: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. In a golf tournament, there were three golfers: Amy, Ana, and Joe. Amy finished below Ana. Ana finished below Joe.\n", + "Options:\n", + "(A) Amy finished second\n", + "(B) Ana finished second\n", + "(C) Joe finished second\n", + "Answer: (B)\n", + "Init memory: Amy finished below Ana, so Ana was better than Amy. Ana finished below Joe, so Joe was better than Ana. 
So, the final order is Joe 1st, Ana 2nd, Amy 3rd.\n", + "So, the answer is (B).\n", + "Last memory: Amy finished below Ana (who finished below Joe). So, Ana finished second.\n", + "So, the answer is (B).\n", + "=====\n" + ] + } + ], + "source": [ + "for k, q_, a_, init_m, last_m in zip(range(5), full_train_questions, full_train_answers, init_memory, last_memory):\n", + " print(f\"Question {k}: {q_}\\nAnswer: {a_}\\nInit memory: {init_m}\\nLast memory: {last_m}\\n=====\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PGPBpItWxs7C", + "outputId": "a7445277-25f3-474e-e198-ef7d19a0cf0f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question 0: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a shelf, there are three books: a green book, a brown book, and an orange book. The brown book is to the left of the orange book. The green book is to the left of the brown book.\n", + "Options:\n", + "(A) The green book is the leftmost\n", + "(B) The brown book is the leftmost\n", + "(C) The orange book is the leftmost\n", + "Answer: (A)\n", + "Rationale (init params): The green book is to the left of the brown book. The brown book is to the left of the orange book. The green book is the leftmost.\n", + "So, the answer is (A).\n", + "Rationale (last params): The green book is to the left of the brown book (which is to the left of the orange book). The green book is the leftmost.\n", + "So, the answer is (A).\n", + "=====\n", + "Question 1: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a branch, there are three birds: a robin, a blue jay, and a crow. The robin is to the right of the crow. The robin is to the left of the blue jay.\n", + "Options:\n", + "(A) The robin is the leftmost\n", + "(B) The blue jay is the leftmost\n", + "(C) The crow is the leftmost\n", + "Answer: (C)\n", + "Rationale (init params): The robin is to the left of the blue jay. The robin is to the right of the crow. The blue jay is to the left of the crow.\n", + "So, the answer is (B).\n", + "Rationale (last params): The robin is to the right of the crow. The robin is to the left of the blue jay. The crow is the leftmost.\n", + "So, the answer is (C).\n", + "=====\n", + "Question 2: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. In an antique car show, there are three vehicles: a hatchback, a convertible, and a tractor. The convertible is older than the tractor. The hatchback is the second-newest.\n", + "Options:\n", + "(A) The hatchback is the oldest\n", + "(B) The convertible is the oldest\n", + "(C) The tractor is the oldest\n", + "Answer: (B)\n", + "Rationale (init params): The hatchback is the second-newest. The convertible is older than the tractor. The tractor is the oldest.\n", + "So, the answer is (C).\n", + "Rationale (last params): The convertible is older than the tractor. The hatchback is the second-newest. The convertible is the oldest.\n", + "So, the answer is (B).\n", + "=====\n", + "Question 3: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: apples, peaches, and mangoes. 
The mangoes are less expensive than the peaches. The apples are less expensive than the mangoes.\n", + "Options:\n", + "(A) The apples are the cheapest\n", + "(B) The peaches are the cheapest\n", + "(C) The mangoes are the cheapest\n", + "Answer: (A)\n", + "Rationale (init params): The apples are less expensive than the mangoes. The mangoes are less expensive than the peaches. The apples are the cheapest.\n", + "So, the answer is (A).\n", + "Rationale (last params): The mangoes are less expensive than the peaches. The apples are less expensive than the mangoes. The apples are the cheapest.\n", + "So, the answer is (A).\n", + "=====\n", + "Question 4: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: oranges, mangoes, and plums. The plums are less expensive than the mangoes. The oranges are less expensive than the plums.\n", + "Options:\n", + "(A) The oranges are the second-most expensive\n", + "(B) The mangoes are the second-most expensive\n", + "(C) The plums are the second-most expensive\n", + "Answer: (C)\n", + "Rationale (init params): The oranges are less expensive than the plums. The plums are less expensive than the mangoes. The oranges are the second-most expensive.\n", + "So, the answer is (A).\n", + "Rationale (last params): The plums are less expensive than the mangoes. The oranges are less expensive than the plums. The mangoes are the second-most expensive.\n", + "So, the answer is (B).\n", + "=====\n" + ] + } + ], + "source": [ + "example_contexts = [RATIONALE_TEMPLATE.format(question=q) for q in full_val_questions[:5]]\n", + "init_rationales = sample(init_params, context=example_contexts, temperature=0.0, seed=jax.random.PRNGKey(0))\n", + "last_rationales = sample(last_params, context=example_contexts, temperature=0.0, seed=jax.random.PRNGKey(0))\n", + "for k, q_, a_, init_r, last_r in zip(range(5), full_val_questions, full_val_answers, init_rationales, last_rationales):\n", + " print(f\"Question {k}: {q_}\\nAnswer: {a_}\\nRationale (init params): {init_r}\\nRationale (last params): {last_r}\\n=====\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7lR5inNggHp6" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}
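The notebook above deliberately leaves `sample`, `log_prob`, `grad`, and `init` as stubs ("to run this colab, you will need to provide callables that implement these functions"). Below is a minimal, hypothetical toy backend — not a language model and not the paper's setup — whose only purpose is to match the call signatures the notebook uses (`params` as a single embedding vector, string or list-of-string contexts, per-example gradients with a leading batch dimension) so the training loop can be smoke-tested. Every name here (`EMBED_DIM`, `_featurize`, the bag-of-characters scoring, the canned sampler) is an assumption made purely for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy backend (assumption: NOT an LLM). It only mirrors the
# sample/log_prob/grad/init signatures the notebook expects.

EMBED_DIM = 16  # size of the toy "soft prompt" parameter vector


def _featurize(text):
  """Deterministic bag-of-characters features for a string."""
  feats = jnp.zeros(EMBED_DIM)
  for ch in text:
    feats = feats.at[ord(ch) % EMBED_DIM].add(1.0)
  return feats / (len(text) + 1.0)


def init(seed):
  """Initial params; `seed` may be a prompt string, as in the notebook."""
  return _featurize(seed) if isinstance(seed, str) else jnp.zeros(EMBED_DIM)


def sample(params, context, num_steps=128, temperature=1.0, seed=None):
  """Mock sampler: canned rationales ending in a randomly chosen option."""
  del params, num_steps, temperature  # unused in this toy backend
  single = isinstance(context, str)
  contexts = [context] if single else list(context)
  seed = jax.random.PRNGKey(0) if seed is None else seed
  picks = jax.random.randint(seed, (len(contexts),), 0, 3)
  outs = [f"Some step-by-step reasoning. So, the answer is ({'ABC'[int(k)]})."
          for k in picks]
  return outs[0] if single else outs


def log_prob(params, context, continuation):
  """Toy 'log-probability': inner product of params with text features."""
  single = isinstance(context, str)
  if single:
    context, continuation = [context], [continuation]
  feats = jnp.stack([_featurize(c + t) for c, t in zip(context, continuation)])
  out = feats @ params  # shape [batch]
  return out[0] if single else out


def grad(params, context, continuation):
  """Per-example gradient of log_prob w.r.t. params (batch dim on the left)."""
  return jax.jacobian(log_prob)(params, context, continuation)
```

With these callables defined in place of the stubs, the notebook's `init`, `get_init_memory`, and `fit` cells run end to end as a plumbing check; the resulting accuracies are of course meaningless for a mock backend.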
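As an aside on the weighting step in `trice_loss`: the snippet below re-runs just those few lines on a hand-made boolean batch to make the arithmetic concrete. The batch and the expected values in the comments are illustrative assumptions, not outputs from the run above; the interpretation (memory rationales enter with positive weight, rejected proposals act as a leave-one-out control variate weighted by `correlation_est`) is one reading of the estimator in the paper.

```python
import jax.numpy as jnp

# Synthetic batch of four examples (illustration only):
# proposals are correct for examples 0 and 2; memory rationales are correct
# for examples 0, 1, 2; example 3 has no correct rationale at all.
is_proposal_correct = jnp.array([True, False, True, False])
is_memory_correct = jnp.array([True, True, True, False])

# Same expressions as in trice_loss above.
mask = is_proposal_correct | is_memory_correct                    # [T, T, T, F]
correlation_est = (is_proposal_correct.sum() - is_proposal_correct) / (
    mask.sum() - 1 + 1e-10)                                       # ~[0.5, 1.0, 0.5, 1.0]
weights_memory = mask * (1 - correlation_est * is_proposal_correct)
weights_proposal = mask * correlation_est * (1 - is_proposal_correct)

print(weights_memory)    # ~[0.5, 1.0, 0.5, 0.0]
print(weights_proposal)  # ~[0.0, 1.0, 0.0, 0.0]
```

In the full loss these weights are concatenated with signs +1 for the memory rationales and -1 for the proposals, clipped to be positive, and then systematically resampled down to `GRADIENT_SUBSAMPLE_SIZE` items before the log-probabilities and gradients are evaluated.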