haystack-tutorials / 28_Structured_Output_With_Loop.ipynb

{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AVBtOVlNJ51C"
      },
      "source": [
        "# Tutorial: Generating Structured Output with Loop-Based Auto-Correction\n",
        "\n",
        "- **Level**: Intermediate\n",
        "- **Time to complete**: 15 minutes\n",
        "- **Prerequisites**: You must have an API key from an active OpenAI account, as this tutorial uses OpenAI's `gpt-3.5-turbo` model.\n",
        "- **Components Used**: `PromptBuilder`, `OpenAIGenerator`, `OutputValidator` (Custom component)\n",
        "- **Goal**: After completing this tutorial, you will have built a system that extracts data from unstructured text, structures it according to a JSON schema, and automatically corrects errors in the JSON output of a large language model (LLM) so that it follows the specified structure.\n",
        "\n",
        "> This tutorial uses Haystack 2.0 Beta. To learn more, read the [Haystack 2.0 Beta announcement](https://haystack.deepset.ai/blog/introducing-haystack-2-beta-and-advent) or see the [Haystack 2.0 Beta Documentation](https://docs.haystack.deepset.ai/v2.0/docs).\n",
        "\n",
        "## Overview\n",
        "This tutorial demonstrates how to use the [looping pipelines](https://docs.haystack.deepset.ai/v2.0/docs/pipelines#loops) of Haystack 2.0 Beta with LLMs for more dynamic and flexible data processing. You'll learn how to extract structured data from unstructured text using an LLM, and how to validate the generated output against a predefined schema.\n",
        "\n",
        "This tutorial uses `gpt-3.5-turbo` to convert unstructured passages into JSON outputs that follow a [Pydantic](https://github.com/pydantic/pydantic) schema. It uses a custom `OutputValidator` component to validate the JSON and, if necessary, loop back so the LLM can correct it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jmiAHh1oGsKI"
      },
      "source": [
        "## Preparing the Colab Environment\n",
        "\n",
        "Enable debug logging:"
      ]
    },
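    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vor9IHuNRvEh"
      },
      "outputs": [],
      "source": [
        "import logging\n",
        "\n",
        "logging.basicConfig()\n",
        "# Early Haystack 2.0 Beta releases log pipeline runs under `canals`; later releases log under `haystack`.\n",
        "# Setting both parent loggers to DEBUG covers either case.\n",
        "logging.getLogger(\"canals\").setLevel(logging.DEBUG)\n",
        "logging.getLogger(\"haystack\").setLevel(logging.DEBUG)"
      ]
    },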
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ljbWiyJkKiPw"
      },
      "source": [
        "## Installing Dependencies\n",
        "Install Haystack 2.0 Beta and [colorama](https://pypi.org/project/colorama/) with pip:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kcc1AlLQd_jI",
        "outputId": "efc4bbab-a9fe-46ee-d8af-9d86edacaf04"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "\n",
        "pip install haystack-ai\n",
        "pip install colorama"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nTA5fdvCLMKD"
      },
      "source": [
        "### Enabling Telemetry\n",
        "\n",
        "Enable telemetry to let us know you're using this tutorial. You can always opt out by commenting out the `tutorial_running` line. For details, see [Telemetry](https://docs.haystack.deepset.ai/docs/telemetry)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Apay3QSQLKdM"
      },
      "outputs": [],
      "source": [
        "from haystack.telemetry import tutorial_running\n",
        "\n",
        "tutorial_running(28)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cmjfa8CiCeFl"
      },
      "source": [
        "## Defining a Schema to Parse the JSON Object\n",
        "\n",
        "Define a simple JSON schema for the data you want the LLM to extract from a text passage. As the first step, define two [Pydantic models](https://docs.pydantic.dev/1.10/usage/models/), `City` and `CitiesData`, with suitable fields and types."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xwKrDOOGdaAz"
      },
      "outputs": [],
      "source": [
        "from typing import List\n",
        "from pydantic import BaseModel\n",
        "\n",
        "\n",
        "class City(BaseModel):\n",
        "    name: str\n",
        "    country: str\n",
        "    population: int\n",
        "\n",
        "\n",
        "class CitiesData(BaseModel):\n",
        "    cities: List[City]"
      ]
    },
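    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next cell makes \"complies with the schema\" concrete. It is a stdlib-only sketch of the kind of structural check that validating against `CitiesData` performs. It is not Pydantic's implementation, and the helper `is_cities_data` is hypothetical, written only for illustration. One known difference: Pydantic's default (lax) mode coerces numeric strings such as `\"504718\"` to `int`, whereas this sketch rejects them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "\n",
        "# Hypothetical helper: roughly mirrors the City and CitiesData models (extra keys are ignored)\n",
        "def is_cities_data(obj) -> bool:\n",
        "    if not isinstance(obj, dict) or not isinstance(obj.get(\"cities\"), list):\n",
        "        return False\n",
        "    for city in obj[\"cities\"]:\n",
        "        if not isinstance(city, dict):\n",
        "            return False\n",
        "        if not isinstance(city.get(\"name\"), str) or not isinstance(city.get(\"country\"), str):\n",
        "            return False\n",
        "        population = city.get(\"population\")\n",
        "        # bool is a subclass of int in Python, so exclude it explicitly\n",
        "        if not isinstance(population, int) or isinstance(population, bool):\n",
        "            return False\n",
        "    return True\n",
        "\n",
        "\n",
        "good = json.loads('{\"cities\": [{\"name\": \"Berlin\", \"country\": \"Germany\", \"population\": 3850809}]}')\n",
        "bad = json.loads('[{\"name\": \"Berlin\", \"country\": \"Germany\"}]')  # a list, not a dict\n",
        "print(is_cities_data(good), is_cities_data(bad))  # True False"
      ]
    },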
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zv-6-l_PCeFl"
      },
      "source": [
        "> You can change these models according to the format you wish to extract from the text."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ouk1mAOUCeFl"
      },
      "source": [
        "Then, generate a JSON schema from the Pydantic models using `schema_json()`. You will use this schema later in the prompt to instruct the LLM. (In Pydantic v2, `schema_json()` is deprecated in favor of `model_json_schema()`.)\n",
        "\n",
        "To learn more about JSON schemas, visit [Pydantic Schema](https://docs.pydantic.dev/1.10/usage/schema/)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Lg9_72jCeFl"
      },
      "outputs": [],
      "source": [
        "json_schema = CitiesData.schema_json(indent=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KvNhg0bP7kfg"
      },
      "source": [
        "## Creating a Custom Component: OutputValidator\n",
        "\n",
        "`OutputValidator` is a custom component that checks whether the JSON object the LLM generates complies with the provided [Pydantic model](https://docs.pydantic.dev/1.10/usage/models/). If it doesn't, `OutputValidator` returns an error message along with the incorrect JSON object so that the LLM can fix it in the next loop iteration.\n",
        "\n",
        "For more details about custom components, see [Creating Custom Components](https://docs.haystack.deepset.ai/v2.0/docs/custom-components)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yr6D8RN2d7Vy"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "from typing import List, Optional, Type\n",
        "\n",
        "import pydantic\n",
        "from pydantic import ValidationError\n",
        "from colorama import Fore\n",
        "from haystack import component\n",
        "\n",
        "\n",
        "@component\n",
        "class OutputValidator:\n",
        "    # `pydantic_model` is the model class (for example, CitiesData), not an instance\n",
        "    def __init__(self, pydantic_model: Type[pydantic.BaseModel]):\n",
        "        self.pydantic_model = pydantic_model\n",
        "        self.iteration_counter = 0\n",
        "\n",
        "    # Define the component outputs: either `valid_replies`, or `invalid_replies` plus `error_message`\n",
        "    @component.output_types(valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str])\n",
        "    def run(self, replies: List[str]):\n",
        "\n",
        "        self.iteration_counter += 1\n",
        "\n",
        "        ## Try to parse the LLM's reply ##\n",
        "        # If the LLM's reply is a valid object, return `\"valid_replies\"`\n",
        "        try:\n",
        "            output_dict = json.loads(replies[0])\n",
        "            self.pydantic_model.parse_obj(output_dict)\n",
        "            print(\n",
        "                Fore.GREEN\n",
        "                + f\"OutputValidator at Iteration {self.iteration_counter}: Valid JSON from LLM - No need for looping: {replies[0]}\"\n",
        "            )\n",
        "            return {\"valid_replies\": replies}\n",
        "\n",
        "        # If the LLM's reply is corrupted or not valid, return \"invalid_replies\" and the \"error_message\" so the LLM can try again\n",
        "        except (ValueError, ValidationError) as e:\n",
        "            print(\n",
        "                Fore.RED\n",
        "                + f\"OutputValidator at Iteration {self.iteration_counter}: Invalid JSON from LLM - Let's try again.\\n\"\n",
        "                f\"Output from LLM:\\n {replies[0]} \\n\"\n",
        "                f\"Error from OutputValidator: {e}\"\n",
        "            )\n",
        "            return {\"invalid_replies\": replies, \"error_message\": str(e)}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vQ_TfSBkCeFm"
      },
      "source": [
        "Then, create an `OutputValidator` instance with the `CitiesData` model you created earlier."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bhPCLCBCCeFm"
      },
      "outputs": [],
      "source": [
        "output_validator = OutputValidator(pydantic_model=CitiesData)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xcIWKjW4k42r"
      },
      "source": [
        "## Creating the Prompt\n",
        "\n",
        "Write instructions that tell the LLM how to convert a passage into JSON. Make sure the instructions also tell the LLM how to correct its previous output when that JSON doesn't match the required schema. Once you create the prompt, initialize `PromptBuilder` with it.\n",
        "\n",
        "For more information about Jinja2 templates and `PromptBuilder`, see [PromptBuilder](https://docs.haystack.deepset.ai/v2.0/docs/promptbuilder)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ohPpNALjdVKt"
      },
      "outputs": [],
      "source": [
        "from haystack.components.builders import PromptBuilder\n",
        "\n",
        "prompt_template = \"\"\"\n",
        "Create a JSON object from the information present in this passage: {{passage}}.\n",
        "Only use information that is present in the passage. Follow this JSON schema, but only return the actual instances without any additional schema definition:\n",
        "{{schema}}\n",
        "Make sure your response is a dict and not a list.\n",
        "{% if invalid_replies and error_message %}\n",
        "  You already created the following output in a previous attempt: {{invalid_replies}}\n",
        "  However, this doesn't comply with the format requirements from above and triggered this Python exception: {{error_message}}\n",
        "  Correct the output and try again. Just return the corrected output without any extra explanations.\n",
        "{% endif %}\n",
        "\"\"\"\n",
        "prompt_builder = PromptBuilder(template=prompt_template)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KM9-Zq2FL7Nn"
      },
      "source": [
        "## Initializing the Generator\n",
        "\n",
        "[OpenAIGenerator](https://docs.haystack.deepset.ai/v2.0/docs/openaigenerator) generates\n",
        "text using OpenAI's `gpt-3.5-turbo` model by default. Set the `OPENAI_API_KEY` environment variable. To use a different model, pass its name to the Generator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z4cQteIgunUR"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from getpass import getpass\n",
        "\n",
        "from haystack.components.generators import OpenAIGenerator\n",
        "\n",
        "os.environ[\"OPENAI_API_KEY\"] = getpass(\"Enter OpenAI API key: \")\n",
        "generator = OpenAIGenerator()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zbotIOgXHkC5"
      },
      "source": [
        "## Building the Pipeline\n",
        "\n",
        "Add all components to your pipeline and connect them. Add connections from `output_validator` back to `prompt_builder` for cases where the produced JSON doesn't comply with the JSON schema. Set `max_loops_allowed` to prevent infinite loops."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eFglN9YEv-1W"
      },
      "outputs": [],
      "source": [
        "from haystack import Pipeline\n",
        "\n",
        "pipeline = Pipeline(max_loops_allowed=5)\n",
        "\n",
        "# Add components to your pipeline\n",
        "pipeline.add_component(instance=prompt_builder, name=\"prompt_builder\")\n",
        "pipeline.add_component(instance=generator, name=\"llm\")\n",
        "pipeline.add_component(instance=output_validator, name=\"output_validator\")\n",
        "\n",
        "# Now, connect the components to each other\n",
        "pipeline.connect(\"prompt_builder\", \"llm\")\n",
        "pipeline.connect(\"llm\", \"output_validator\")\n",
        "# If a component has more than one output or input, explicitly specify the connections:\n",
        "pipeline.connect(\"output_validator.invalid_replies\", \"prompt_builder.invalid_replies\")\n",
        "pipeline.connect(\"output_validator.error_message\", \"prompt_builder.error_message\")"
      ]
    },
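    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop you just wired can be sketched in plain Python. The next cell is a simplified, stdlib-only illustration of the control flow. It is not how Haystack executes pipelines internally. The names and shortcuts are assumptions made for illustration:\n",
        "\n",
        "- `fake_llm` is a hypothetical stand-in for `OpenAIGenerator` that returns a list on its first call and a valid object on its second.\n",
        "- `validate` is a simplified stand-in for `OutputValidator` that only checks for a dict with a `cities` key.\n",
        "- `max_loops` only loosely mirrors `max_loops_allowed`; Haystack's exact counting may differ.\n",
        "\n",
        "As in `OutputValidator`, each check populates either `valid_replies` or `invalid_replies` plus `error_message`, and only the invalid branch feeds back into the prompt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "from typing import Callable, List, Optional\n",
        "\n",
        "\n",
        "def validate(replies: List[str]) -> dict:\n",
        "    # Simplified stand-in for OutputValidator: exactly one output branch is populated per call\n",
        "    try:\n",
        "        data = json.loads(replies[0])\n",
        "        if not isinstance(data, dict) or \"cities\" not in data:\n",
        "            raise ValueError(\"expected a dict with a 'cities' key\")\n",
        "        return {\"valid_replies\": replies}\n",
        "    except ValueError as e:  # json.JSONDecodeError is a subclass of ValueError\n",
        "        return {\"invalid_replies\": replies, \"error_message\": str(e)}\n",
        "\n",
        "\n",
        "def make_fake_llm() -> Callable[[str], List[str]]:\n",
        "    # Hypothetical stand-in for the generator: wrong shape on the first call, valid object afterwards\n",
        "    calls = 0\n",
        "\n",
        "    def fake_llm(prompt: str) -> List[str]:\n",
        "        nonlocal calls\n",
        "        calls += 1\n",
        "        if calls == 1:\n",
        "            return ['[{\"name\": \"Berlin\", \"country\": \"Germany\", \"population\": 3850809}]']\n",
        "        return ['{\"cities\": [{\"name\": \"Berlin\", \"country\": \"Germany\", \"population\": 3850809}]}']\n",
        "\n",
        "    return fake_llm\n",
        "\n",
        "\n",
        "def run_loop(passage: str, llm: Callable[[str], List[str]], max_loops: int = 5) -> Optional[str]:\n",
        "    invalid_replies, error_message = None, None\n",
        "    for _ in range(max_loops):\n",
        "        prompt = f\"Create a JSON object from the information present in this passage: {passage}\"\n",
        "        if invalid_replies and error_message:\n",
        "            # Same idea as the {% if invalid_replies and error_message %} block in prompt_template\n",
        "            prompt += f\"\\nPrevious attempt: {invalid_replies}\\nError: {error_message}\"\n",
        "        result = validate(llm(prompt))\n",
        "        if \"valid_replies\" in result:\n",
        "            return result[\"valid_replies\"][0]\n",
        "        invalid_replies, error_message = result[\"invalid_replies\"], result[\"error_message\"]\n",
        "    return None  # limit reached; the real pipeline raises PipelineMaxLoops instead\n",
        "\n",
        "\n",
        "print(run_loop(\"Berlin is the capital of Germany.\", make_fake_llm()))"
      ]
    },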
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-UKW5wtIIT7w"
      },
      "source": [
        "### Visualize the Pipeline\n",
        "\n",
        "Draw the pipeline with the [`draw()`](https://docs.haystack.deepset.ai/v2.0/docs/drawing-pipeline-graphs) method to confirm the connections are correct. You can find the diagram in the Files section of this Colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RZJg6YHId300"
      },
      "outputs": [],
      "source": [
        "pipeline.draw(\"auto-correct-pipeline.png\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kV_kexTjImpo"
      },
      "source": [
        "## Testing the Pipeline\n",
        "\n",
        "Run the pipeline with an example passage that you want to convert into JSON, together with the `json_schema` you created for `CitiesData`. For the given example passage, the generated JSON object should look like this:\n",
        "```json\n",
        "{\n",
        "  \"cities\": [\n",
        "    {\n",
        "      \"name\": \"Berlin\",\n",
        "      \"country\": \"Germany\",\n",
        "      \"population\": 3850809\n",
        "    },\n",
        "    {\n",
        "      \"name\": \"Paris\",\n",
        "      \"country\": \"France\",\n",
        "      \"population\": 2161000\n",
        "    },\n",
        "    {\n",
        "      \"name\": \"Lisbon\",\n",
        "      \"country\": \"Portugal\",\n",
        "      \"population\": 504718\n",
        "    }\n",
        "  ]\n",
        "}\n",
        "```\n",
        "The output of the LLM should comply with the `json_schema`. If the LLM doesn't generate a valid JSON object, the pipeline loops back and the LLM tries again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yIoMedb6eKia",
        "outputId": "4a9ef924-cf26-4908-d83f-b0bc0dc03b54"
      },
      "outputs": [],
      "source": [
        "passage = \"Berlin is the capital of Germany. It has a population of 3,850,809. Paris, France's capital, has 2.161 million residents. Lisbon is the capital and the largest city of Portugal with the population of 504,718.\"\n",
        "result = pipeline.run({\"prompt_builder\": {\"passage\": passage, \"schema\": json_schema}})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WWxmPgADS_Fa"
      },
      "source": [
        "> If you encounter the `PipelineMaxLoops: Maximum loops count (5) exceeded for component 'prompt_builder'.` error, consider increasing the maximum loop count or simply rerun the pipeline."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eWPawSjgSJAM"
      },
      "source": [
        "### Print the Correct JSON\n",
        "If the pipeline ran without errors, you can now print the corrected JSON."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BVO47gXQQnDC",
        "outputId": "460a10d4-a69a-49cd-bbb2-fc4980907299"
      },
      "outputs": [],
      "source": [
        "valid_reply = result[\"output_validator\"][\"valid_replies\"][0]\n",
        "valid_json = json.loads(valid_reply)\n",
        "print(valid_json)"
      ]
    },
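    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`print(valid_json)` prints a Python dict on a single line. To read the result more easily, you can pretty-print it with the standard library's `json.dumps(..., indent=2)` and iterate over the parsed cities. To keep the next cell self-contained, it assumes a hard-coded `sample_reply` copied from the expected output shown earlier in this tutorial. In your own run, use `valid_reply` from the pipeline result instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "# Assumed sample reply, copied from the expected output shown earlier; replace it with `valid_reply`\n",
        "sample_reply = (\n",
        "    '{\"cities\": [{\"name\": \"Berlin\", \"country\": \"Germany\", \"population\": 3850809}, '\n",
        "    '{\"name\": \"Paris\", \"country\": \"France\", \"population\": 2161000}, '\n",
        "    '{\"name\": \"Lisbon\", \"country\": \"Portugal\", \"population\": 504718}]}'\n",
        ")\n",
        "data = json.loads(sample_reply)\n",
        "\n",
        "# Pretty-print with two-space indentation\n",
        "print(json.dumps(data, indent=2))\n",
        "\n",
        "# Access the individual fields of each parsed city\n",
        "for city in data[\"cities\"]:\n",
        "    print(f\"{city['name']} ({city['country']}): population {city['population']:,}\")"
      ]
    },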
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Egz_4h2vI_QL"
      },
      "source": [
        "🎉 Congratulations! You've built a system that generates structured JSON out of unstructured text passages, and auto-corrects it by using the looping functionality of Haystack pipelines."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}