google-research / OSS_general_pattern_machines_ARC.ipynb

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "Ltj0f-flNAi9"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0"
      ],
      "metadata": {
        "id": "Ltj0f-flNAi9"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ],
      "metadata": {
        "id": "I8NlVpAzNB2u"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **LLMs as General Pattern Machines:** ARC Benchmark\n",
        "\n",
        "We observe that pretrained large language models (LLMs) are capable of autoregressively completing complex token sequences -- from arbitrary ones procedurally generated by probabilistic context-free grammars (PCFG), to richer spatial patterns found in the Abstract Reasoning Corpus (ARC), a general AI benchmark, prompted in the style of ASCII art. Surprisingly, pattern completion proficiency can be partially retained even when the sequences are expressed using tokens randomly sampled from the vocabulary. These results suggest that without any additional training, LLMs can serve as general sequence modelers, driven by in-context learning. In this work, we investigate how these zero-shot capabilities may be applied to problems in robotics -- from extrapolating sequences of numbers that represent states over time to complete simple motions, to least-to-most prompting of reward-conditioned trajectories that can discover and represent closed-loop policies (e.g., a stabilizing controller for CartPole). While difficult to deploy today for real systems due to latency, context size limitations, and compute costs, the approach of using LLMs to drive low-level control may provide an exciting glimpse into how the patterns among words could be transferred to actions.\n",
        "\n",
        "This colab runs GPT-3 (`text-davinci-003`) on the ARC benchmark with consistent tokenization (described in more detail in Sec. 4 of the main paper).\n",
        "\n",
        "### **Quick Start:**\n",
        "\n",
        "**Step 1.** Register for an [OpenAI API key](https://openai.com/blog/openai-api/) to use GPT-3 (there's a free trial) and enter it below.\n",
        "\n",
        "**Step 2.** Menu > Runtime > Run all"
      ],
      "metadata": {
        "id": "zqTADtDB6zyA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "openai_api_key = \"your-api-key-here\""
      ],
      "metadata": {
        "id": "wwJDOJSz71lk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Setup**\n",
        "\n",
        "This cell:\n",
        "* Installs Python packages and sets the OpenAI API key.\n",
        "* Downloads the Abstract Reasoning Corpus (ARC) benchmark.\n",
        "\n",
        "**Note:** this only needs a CPU (public) runtime."
      ],
      "metadata": {
        "id": "kbWMlIj7XxX8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install \"openai<1.0\" transformers  # The code below uses the pre-1.0 openai.Completion API.\n",
        "\n",
        "import json\n",
        "import os\n",
        "import pickle\n",
        "import time\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import openai\n",
        "from transformers import GPT2Tokenizer\n",
        "# import tiktoken  # Faster than GPT2Tokenizer.\n",
        "\n",
        "openai.api_key = openai_api_key\n",
        "\n",
        "# Clone the ARC dataset once.\n",
        "if not os.path.exists(\"ARC\"):\n",
        "  !git clone https://github.com/fchollet/ARC"
      ],
      "metadata": {
        "id": "uI4hX8y5XzeH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **API:** Large Language Models\n",
        "\n",
        "Define a helper function that calls the LLM through the OpenAI API, and load the tokenizer.\n",
        "\n",
        "**Note:** API usage is billed per token, so this can get expensive."
      ],
      "metadata": {
        "id": "2AoMDZ-GZxRP"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-waqt2fUb9ex",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f2d6d52e-5502-45e7-91cb-0855fde30c60"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[\"\\n\\nHello World! It's great to be here.\"]"
            ]
          },
          "metadata": {},
          "execution_count": 23
        }
      ],
      "source": [
        "model = \"text-davinci-003\"\n",
        "token_limit = 4096  # Token budget (prompt + completion) used when building prompts.\n",
        "\n",
        "def LLM(prompt, stop=None, max_tokens=256, temperature=0):\n",
        "  # `prompt` may be one prompt or a batch (list) of prompts, given as text or token IDs.\n",
        "  responses = openai.Completion.create(engine=model, prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)\n",
        "  # Order choices by `index` so completions line up with the prompts in a batch.\n",
        "  choices = sorted(responses['choices'], key=lambda choice: choice['index'])\n",
        "  text = [choice['text'] for choice in choices]\n",
        "  return text\n",
        "\n",
        "# GPT-2 BPE tokenizer, used to build prompts as lists of token IDs.\n",
        "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
        "\n",
        "LLM(\"hello world!\")"
      ]
    },
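    {
      "cell_type": "markdown",
      "source": [
        "## **Alphabet:** Token Set\n",
        "\n",
        "Build a fixed token set that maps each ARC cell value (0-9) to a single LLM token. By default this uses the digits 0-9, each encoded with a leading space; the commented-out alternative samples the tokens at random from the LLM's token vocabulary."
      ],
      "metadata": {
        "id": "vMPLptkxatkC"
      }
    },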
    {
      "cell_type": "code",
      "source": [
        "# Delimiter tokens shared by every prompt.\n",
        "item_delim = tokenizer.encode(\",\")\n",
        "row_delim = tokenizer.encode(\"\\n\")\n",
        "sample_delim = tokenizer.encode(\"---\\n\")\n",
        "\n",
        "# Handpicked (default): the digits 0-9, each with a leading space, giving comma-separated number matrices.\n",
        "alphabet = [tokenizer.encode(\" \" + str(a))[0] for a in range(10)]\n",
        "value_to_token = lambda x: alphabet[x]  # Grid value (0-9) -> token ID.\n",
        "print(\"Token Set:\", {i: value_to_token(i) for i in range(10)})\n",
        "\n",
        "# Alternative: randomly sampled tokens (np.random.randint samples with replacement, so tokens may repeat).\n",
        "# for seed_offset in range(20):\n",
        "# seed_offset = 0\n",
        "# np.random.seed(42 + seed_offset)\n",
        "# alphabet = [int(i) for i in np.random.randint(tokenizer.vocab_size, size=10)]\n",
        "# value_to_token = lambda x: {i:a for i, a in enumerate(alphabet)}[x]"
      ],
      "metadata": {
        "id": "tRIQ3AZMUmpK",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9e07e590-0986-4320-f164-0268609a9cf7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Token Set: {0: 657, 1: 352, 2: 362, 3: 513, 4: 604, 5: 642, 6: 718, 7: 767, 8: 807, 9: 860}\n"
          ]
        }
      ]
    },
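    {
      "cell_type": "markdown",
      "source": [
        "## **Load:** ARC Benchmark\n",
        "\n",
        "Load all tasks from the ARC training and evaluation sets and convert each one into tokens: every training pair becomes `input:`, its grid, `output:`, its grid, then `---`; each test prompt ends with `output:` so that the LLM completes the output grid. Tasks are then sorted by the token length of their training examples."
      ],
      "metadata": {
        "id": "SD_-rknea0WU"
      }
    },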
    {
      "cell_type": "code",
      "source": [
        "def state_to_tokens(state, value_to_token_fn):\n",
        "  # Serialize a grid row by row: values separated by commas, each row ended by a newline.\n",
        "  tokens = []\n",
        "  for row in state:\n",
        "    for i, value in enumerate(row):\n",
        "      tokens.append(value_to_token_fn(value))\n",
        "      if i < len(row) - 1:\n",
        "        tokens += item_delim\n",
        "    tokens += row_delim\n",
        "  return tokens\n",
        "\n",
        "\n",
        "def task_json_to_tokens(task_json, value_to_token_fn):\n",
        "  # Returns (train_samples, test_inputs, test_outputs), each a list of token-ID lists.\n",
        "\n",
        "  # Training examples: \"input:\" grid, \"output:\" grid, then the \"---\" separator.\n",
        "  train_samples = []\n",
        "  for sample in task_json[\"train\"]:\n",
        "    tokens = []\n",
        "    tokens += tokenizer.encode(\"input:\\n\")\n",
        "    tokens += state_to_tokens(sample[\"input\"], value_to_token_fn)\n",
        "    tokens += tokenizer.encode(\"output:\\n\")\n",
        "    tokens += state_to_tokens(sample[\"output\"], value_to_token_fn)\n",
        "    tokens += sample_delim\n",
        "    train_samples.append(tokens)\n",
        "\n",
        "  # Test examples: prompts end at \"output:\"; the expected output grids are kept as labels.\n",
        "  test_inputs = []\n",
        "  test_outputs = []\n",
        "  for sample in task_json[\"test\"]:\n",
        "    inputs, outputs = [], []\n",
        "    inputs += tokenizer.encode(\"input:\\n\")\n",
        "    inputs += state_to_tokens(sample[\"input\"], value_to_token_fn)\n",
        "    inputs += tokenizer.encode(\"output:\\n\")\n",
        "    test_inputs.append(inputs)\n",
        "    outputs += state_to_tokens(sample[\"output\"], value_to_token_fn)\n",
        "    test_outputs.append(outputs)\n",
        "  return train_samples, test_inputs, test_outputs"
      ],
      "metadata": {
        "id": "5rvACw0XFZWY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "tasks_jsons = []\n",
        "tasks_names = []\n",
        "tasks_len = []  # Number of tokens in each task's training examples.\n",
        "for task_dir in [\"ARC/data/training\", \"ARC/data/evaluation\"]:\n",
        "  for task_file in sorted(os.listdir(task_dir)):\n",
        "    with open(os.path.join(task_dir, task_file)) as fid:\n",
        "      task_json = json.load(fid)\n",
        "    tasks_jsons.append(task_json)\n",
        "    tasks_names.append(task_file)\n",
        "    tokens, _, _ = task_json_to_tokens(task_json, value_to_token)\n",
        "    tasks_len.append(np.sum([len(sample) for sample in tokens]))\n",
        "\n",
        "# Order tasks from shortest to longest.\n",
        "sorted_task_ids = np.argsort(tasks_len)\n",
        "\n",
        "print(\"Total number of tasks:\", len(sorted_task_ids))"
      ],
      "metadata": {
        "id": "zZY7OSoHbIRk",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "2fda0d4e-1014-4888-dd2a-d76fb4942ff2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Total number of tasks: 800\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Example:** ARC Problem\n",
        "\n",
        "Show the LLM prompt for the task with the shortest training examples, and visualize the input and output grids of its training and test pairs."
      ],
      "metadata": {
        "id": "2wBok5T59kDT"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# RGB color for each ARC cell value 0-9.\n",
        "colors = [(0, 0, 0),\n",
        "          (0, 116, 217),\n",
        "          (255, 65, 54),\n",
        "          (46, 204, 6),\n",
        "          (255, 220, 0),\n",
        "          (170, 170, 170),\n",
        "          (240, 18, 190),\n",
        "          (255, 133, 27),\n",
        "          (127, 219, 255),\n",
        "          (135, 12, 37)]\n",
        "\n",
        "def grid_to_img(grid):\n",
        "  # Paint each cell as a (scale - 1) x (scale - 1) block of its color, leaving 1-pixel gridlines.\n",
        "  grid = np.int32(grid)\n",
        "  scale = 10\n",
        "  img = np.zeros((grid.shape[0] * scale + 1, grid.shape[1] * scale + 1, 3), dtype=np.uint8)\n",
        "  for r in range(grid.shape[0]):\n",
        "    for c in range(grid.shape[1]):\n",
        "      img[r*scale+1:(r+1)*scale, c*scale+1:(c+1)*scale, :] = colors[grid[r, c]]\n",
        "  # Lighten the gridlines by blending them 30% toward white.\n",
        "  new_img = img.copy()\n",
        "  new_img[0::scale, :, :] = np.uint8(np.round((0.7 * np.float32(img[0::scale, :, :]) + 0.3 * 255)))\n",
        "  new_img[:, 0::scale, :] = np.uint8(np.round((0.7 * np.float32(img[:, 0::scale, :]) + 0.3 * 255)))\n",
        "  return new_img"
      ],
      "metadata": {
        "id": "qCl4heCw88MG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "example_json = tasks_jsons[sorted_task_ids[0]]  # Task with the shortest training examples.\n",
        "\n",
        "# Prompt = all training pairs followed by the first test input.\n",
        "context = []\n",
        "train_xy, test_x, test_y = task_json_to_tokens(example_json, value_to_token)\n",
        "for sample in train_xy:\n",
        "  context += sample\n",
        "context += test_x[0]\n",
        "\n",
        "print(\"PROMPT:\")\n",
        "print(tokenizer.decode(context, skip_special_tokens=True))\n",
        "print(\"SOLUTION:\")\n",
        "print(tokenizer.decode(test_y[0], skip_special_tokens=True))\n",
        "\n",
        "# Show problem: input grid (left) and output grid (right) for each pair.\n",
        "print(\"TRAIN:\")\n",
        "for ex in example_json[\"train\"]:\n",
        "  in_img = grid_to_img(ex[\"input\"])\n",
        "  out_img = grid_to_img(ex[\"output\"])\n",
        "  plt.subplot(1, 2, 1); plt.imshow(in_img)\n",
        "  plt.subplot(1, 2, 2); plt.imshow(out_img)\n",
        "  plt.show()\n",
        "print(\"TEST:\")\n",
        "for ex in example_json[\"test\"]:\n",
        "  in_img = grid_to_img(ex[\"input\"])\n",
        "  out_img = grid_to_img(ex[\"output\"])\n",
        "  plt.subplot(1, 2, 1); plt.imshow(in_img)\n",
        "  plt.subplot(1, 2, 2); plt.imshow(out_img)\n",
        "  plt.show()"
      ],
      "metadata": {
        "id": "orLb7781byY3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "a5e6948a-ccf2-4dc6-bf03-d7902c07e85a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "PROMPT:\n",
            "input:\n",
            " 3, 3, 8\n",
            " 3, 7, 0\n",
            " 5, 0, 0\n",
            "output:\n",
            " 0, 0, 5\n",
            " 0, 7, 3\n",
            " 8, 3, 3\n",
            "---\n",
            "input:\n",
            " 5, 5, 2\n",
            " 1, 0, 0\n",
            " 0, 0, 0\n",
            "output:\n",
            " 0, 0, 0\n",
            " 0, 0, 1\n",
            " 2, 5, 5\n",
            "---\n",
            "input:\n",
            " 6, 3, 5\n",
            " 6, 8, 0\n",
            " 4, 0, 0\n",
            "output:\n",
            "\n",
            "SOLUTION:\n",
            " 0, 0, 4\n",
            " 0, 8, 6\n",
            " 5, 3, 6\n",
            "\n",
            "TRAIN:\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 2 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 2 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "TEST:\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 2 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Evaluate:** ARC Benchmark\n",
        "\n",
        "Evaluate on all 800 available tasks (400 training + 400 evaluation). A task counts as solved only if the expected output grid appears in the LLM's completion for every one of its test inputs.\n",
        "\n",
        "**Note:** the LLM temperature is set to 0 (near-deterministic), but your results might still vary between runs depending on the stability of the API."
      ],
      "metadata": {
        "id": "A6Jz5lc5bQRx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "success = {}\n",
        "for task_id in sorted_task_ids:\n",
        "  task_json, task_name = tasks_jsons[task_id], tasks_names[task_id]\n",
        "\n",
        "  # Lazy load: skip tasks that already have results (takes effect only if `success` is not reset, e.g. when resuming).\n",
        "  if task_name in success:\n",
        "    continue\n",
        "\n",
        "  # Build context and expected output labels.\n",
        "  context = []\n",
        "  batch_prompts = []\n",
        "  batch_labels = []\n",
        "  train_xy, test_x, test_y = task_json_to_tokens(task_json, value_to_token)\n",
        "  test_num_tokens = np.max([len(x) + len(y) for x, y in zip(test_x, test_y)])\n",
        "  for sample in train_xy:\n",
        "    if len(context) + len(sample) + test_num_tokens > token_limit:  # Ensure both train and test examples can fit in the prompt.\n",
        "      break\n",
        "    context += sample\n",
        "\n",
        "  # A task can have multiple test examples, so send them to the LLM as one batch.\n",
        "  for x, y in zip(test_x, test_y):\n",
        "    batch_prompts.append(context + x)\n",
        "    batch_labels.append(y)\n",
        "\n",
        "  # Run LLM; generation stops at the \"---\" separator.\n",
        "  try:\n",
        "    stop_token = tokenizer.decode(sample_delim, skip_special_tokens=True)\n",
        "    max_tokens = int(np.max([len(y) for y in test_y])) + 10\n",
        "    batch_responses = LLM(batch_prompts, stop=stop_token, max_tokens=max_tokens, temperature=0)\n",
        "  except Exception as e:\n",
        "    print(task_name, f\"LLM failed. {e}\")\n",
        "    continue\n",
        "\n",
        "  # Check answers: a test case passes if the expected grid text appears in the completion.\n",
        "  success[task_name] = 0\n",
        "  for response, label in zip(batch_responses, batch_labels):\n",
        "    label_str = tokenizer.decode(label, skip_special_tokens=True)\n",
        "    is_success = label_str.strip() in response\n",
        "    success[task_name] += is_success / len(batch_labels)\n",
        "  success[task_name] = int(success[task_name] > 0.99)  # All test cases need to be correct.\n",
        "\n",
        "  # Debug prints.\n",
        "  total_success = np.sum(list(success.values()))\n",
        "  print(task_name, \"Success:\", success[task_name], \"Total:\", f\"{total_success} / {len(success)}\")\n",
        "\n",
        "  # # Save results.\n",
        "  # result_file = f\"arc-{model}-alphabet-{'-'.join(map(str, alphabet))}.pkl\"\n",
        "  # with open(result_file, 'wb') as fid:\n",
        "  #   pickle.dump(success, fid, protocol=pickle.HIGHEST_PROTOCOL)"
      ],
      "metadata": {
        "id": "GuFgcjPxsxXG"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
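The notebook's `state_to_tokens` and `task_json_to_tokens` build the prompt from token IDs. The same layout can also be reproduced as plain strings, which makes it easy to inspect without a tokenizer. The sketch below assumes the default digit alphabet, where each value is rendered as a space followed by the digit (as in the `PROMPT:` output of the example cell). `grid_to_text` and `task_to_prompt` are hypothetical helpers and are not part of the notebook:

```python
# Hypothetical helpers (not in the notebook): a string-level version of the prompt format,
# assuming the default alphabet where value v is rendered as " v".


def grid_to_text(grid):
    # One line per row; values rendered as " <digit>" and separated by ",".
    return "".join(",".join(f" {v}" for v in row) + "\n" for row in grid)


def task_to_prompt(task):
    # Training pairs: "input:", grid, "output:", grid, then the "---" separator.
    parts = []
    for pair in task["train"]:
        parts.append("input:\n" + grid_to_text(pair["input"]))
        parts.append("output:\n" + grid_to_text(pair["output"]))
        parts.append("---\n")
    # Only the first test input is appended. The prompt ends at "output:",
    # so the model's completion is its predicted output grid.
    test = task["test"][0]
    parts.append("input:\n" + grid_to_text(test["input"]) + "output:\n")
    return "".join(parts)


# Grids copied from the example cell's printed output.
example = {
    "train": [
        {"input": [[3, 3, 8], [3, 7, 0], [5, 0, 0]],
         "output": [[0, 0, 5], [0, 7, 3], [8, 3, 3]]},
        {"input": [[5, 5, 2], [1, 0, 0], [0, 0, 0]],
         "output": [[0, 0, 0], [0, 0, 1], [2, 5, 5]]},
    ],
    "test": [{"input": [[6, 3, 5], [6, 8, 0], [4, 0, 0]],
              "output": [[0, 0, 4], [0, 8, 6], [5, 3, 6]]}],
}
print(task_to_prompt(example))
```

Printing the result gives the same text as the `PROMPT:` output of the example cell. The notebook itself assembles token-ID lists rather than strings, so every grid value always maps to exactly one fixed token, including when the alphabet is sampled at random.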
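The grid pairs printed by the example cell suggest that the shortest task is a 180° rotation: each output equals its input with the row order reversed and each row reversed. This rule is inferred here from the printed grids and is not stated in the notebook. The LLM has to infer it in-context from the two training pairs alone. The check below uses plain Python, and `rotate_180` is a hypothetical helper:

```python
# Hypothetical helper (not in the notebook): rotate a grid by 180 degrees.
def rotate_180(grid):
    # Reverse the order of the rows, then reverse each row.
    return [list(reversed(row)) for row in reversed(grid)]


# (input, output) pairs copied from the example cell's printed prompt and solution.
pairs = [
    ([[3, 3, 8], [3, 7, 0], [5, 0, 0]], [[0, 0, 5], [0, 7, 3], [8, 3, 3]]),
    ([[5, 5, 2], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [2, 5, 5]]),
    ([[6, 3, 5], [6, 8, 0], [4, 0, 0]], [[0, 0, 4], [0, 8, 6], [5, 3, 6]]),  # Test pair.
]
print(all(rotate_180(inp) == out for inp, out in pairs))  # True
```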
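The evaluation loop fills the prompt greedily. It adds training examples in order until the next one would push the context, plus the longest test input and its expected output, past `token_limit`. The sketch below applies that rule to plain lists of token IDs and assumes the same `break` behavior. `pack_context` is a hypothetical name, and the toy token lists are made up for illustration:

```python
# Hypothetical helper (not in the notebook): greedy context packing under a token budget.
def pack_context(train_samples, test_pairs, token_limit):
    # Reserve room for the longest test input plus its expected output.
    test_num_tokens = max(len(x) + len(y) for x, y in test_pairs)
    context = []
    for sample in train_samples:
        # Stop (rather than skip) at the first example that does not fit,
        # mirroring the `break` in the notebook's loop.
        if len(context) + len(sample) + test_num_tokens > token_limit:
            break
        context += sample
    # One prompt per test input, all sharing the same training context.
    return [context + x for x, _ in test_pairs]


# Toy token-ID lists (made up): three training samples and two test pairs.
train = [[1] * 40, [2] * 30, [3] * 10]
tests = [([9] * 5, [8] * 5), ([7] * 3, [6] * 3)]
prompts = pack_context(train, tests, token_limit=80)
print([len(p) for p in prompts])  # [75, 73]
```

The loop uses `break`, so one long training example also excludes every example after it, even shorter ones that would still fit.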
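The scoring counts a task as solved only if every test case passes. A test case passes when the stripped expected grid text is a substring of the completion. The sketch below applies the same rule to already-decoded strings. `task_solved` is a hypothetical helper; the notebook inlines this logic in its evaluation loop:

```python
# Hypothetical helper (not in the notebook): the notebook's per-task success rule.
def task_solved(responses, labels):
    # A test case passes if the stripped expected grid appears anywhere
    # in the completion (a substring match, as in the notebook).
    hits = [label.strip() in response for response, label in zip(responses, labels)]
    # Mean hit rate > 0.99, as in the notebook; with at most 100 test cases
    # this is the same as requiring every test case to pass.
    return int(sum(hits) / len(labels) > 0.99)


label = " 0, 0, 4\n 0, 8, 6\n 5, 3, 6\n"
print(task_solved([label], [label]))  # 1
print(task_solved([label, " 1, 2\n"], [label, " 3, 4\n"]))  # 0
print(task_solved([label + " 9, 9, 9\n"], [label]))  # 1
```

The check is a substring match, so a completion that contains the expected grid followed by extra rows is still counted as correct, as the last call shows.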
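The commented-out random alphabet draws token IDs with `np.random.randint`, which samples with replacement. Two grid values can therefore occasionally share a token. The variant below draws 10 distinct IDs instead. It assumes a vocabulary size of 50257, the GPT-2 vocabulary size reported by `tokenizer.vocab_size` for the `gpt2` tokenizer. `sample_alphabet` is a hypothetical helper, and this is a swapped-in alternative, not the notebook's procedure:

```python
import random


# Hypothetical helper (not in the notebook): sample a random alphabet without replacement.
def sample_alphabet(vocab_size, seed, num_values=10):
    # A dedicated, seeded generator; sample() draws without replacement,
    # so all returned token IDs are distinct.
    rng = random.Random(seed)
    return rng.sample(range(vocab_size), num_values)


alphabet = sample_alphabet(50257, seed=42)  # 50257: assumed GPT-2 vocabulary size.
value_to_token = dict(enumerate(alphabet))  # Grid value -> token ID.
print(len(set(alphabet)))  # 10
```

Seeding a dedicated `random.Random` instance makes each alphabet reproducible for a given seed without touching NumPy's global random state.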
