{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "id": "154b2d0e-f7ce-453b-b3b7-eda0666a9795",
6
   "metadata": {},
7
   "source": [
8
    "# Building RAG-based LLM Applications for Production"
9
   ]
10
  },
11
  {
12
   "cell_type": "markdown",
13
   "id": "de569042-32c7-4bea-a1ef-f0e41e260645",
14
   "metadata": {},
15
   "source": [
16
    "- **Blog post**: https://www.anyscale.com/blog/a-comprehensive-guide-for-building-rag-based-llm-applications-part-1\n",
17
    "- **GitHub repository**: https://github.com/ray-project/llm-applications\n",
18
    "- **Anyscale Endpoints**: https://endpoints.anyscale.com/ (serve + fine-tune LLMs)\n",
19
    "- **Ray documentation**: https://docs.ray.io/"
20
   ]
21
  },
22
  {
23
   "cell_type": "markdown",
24
   "id": "68a962bf-4bb5-47ff-a667-89d493b1eeed",
25
   "metadata": {
26
    "tags": []
27
   },
28
   "source": [
    "In this guide, we will learn how to:\n",
    "\n",
    "- 💻 Develop a retrieval augmented generation (RAG) based LLM application from scratch.\n",
    "- 🚀 Scale the major workloads (load, chunk, embed, index, serve, etc.) across multiple workers with different compute resources.\n",
    "- ✅ Evaluate different configurations of our application to optimize for both per-component (e.g., retrieval_score) and overall performance (quality_score).\n",
    "- 🔀 Implement a hybrid agent routing approach between OSS and closed LLMs to create the most performant and cost-effective application.\n",
    "- 📦 Serve the application in a highly scalable and available manner.\n",
    "- 💡 Learn how methods like fine-tuning, prompt engineering, lexical search, reranking, data flywheel, etc. impact our application's performance."
37
   ]
38
  },
39
  {
40
   "cell_type": "markdown",
41
   "id": "b6b842ed-9c65-488f-a5b3-fcbced58c2c5",
42
   "metadata": {},
43
   "source": [
44
    "# Overview"
45
   ]
46
  },
47
  {
48
   "cell_type": "markdown",
49
   "id": "c09a1b69-8386-4043-846a-86b5396090af",
50
   "metadata": {},
51
   "source": [
    "Large language models (LLMs) have undoubtedly changed the way we interact with information. However, they come with their fair share of limitations as to what we can ask of them. Base LLMs (e.g., Llama-2-70b, gpt-4, etc.) are only aware of the information that they've been trained on and will fall short when we require them to know information beyond that. Retrieval augmented generation (RAG) based LLM applications address this exact issue and extend the utility of LLMs and their generative reasoning abilities to our unique datasets.\n",
53
    "\n",
54
    "In this guide, we're going to build a RAG-based LLM application where we will incorporate external data sources to augment our LLM’s capabilities. Specifically, we will be building an assistant that can answer questions about [Ray](https://github.com/ray-project/ray) — a Python framework for productionizing and scaling ML workloads. The goal here is to make it easier for developers to adopt Ray, but also, as we'll see in this guide, to help improve our Ray documentation itself and provide a foundation for other LLM applications. We’ll also share challenges we faced along the way and how we overcame them.\n",
55
    "\n",
56
    "**Note**: We have generalized this entire guide so that it can easily be extended to build RAG-based LLM applications on top of your own data.\n",
57
    "\n",
58
    "<img width=\"500\" src=\"https://images.ctfassets.net/xjan103pcp94/4PX0l1ruKqfH17YvUiMFPw/c60a7a665125cb8056bebcc146c23b76/image8.png\">"
59
   ]
60
  },
61
  {
62
   "cell_type": "markdown",
63
   "id": "dc5bc988-231a-4aa4-8d9d-2192488e1724",
64
   "metadata": {},
65
   "source": [
    "Besides just building our LLM application, we’re also going to be focused on scaling and serving it in production. Unlike traditional machine learning, or even supervised deep learning, scale is a bottleneck for LLM applications from the very beginning: large datasets, large models, compute-intensive workloads, serving requirements, etc. We’ll develop our application so that it can scale as our needs grow. We’re also going to be focused on evaluation and performance. Our application involves many moving pieces (embedding models, chunking logic, the LLM itself, etc.), so it's important that we experiment with different configurations to optimize for the best quality responses. However, it's non-trivial to evaluate and quantitatively compare different configurations for a generative task. We’re going to break down the evaluation into individual parts of our application (retrieval given query, generation given source), assess the overall performance (end-to-end generation), and share our findings towards an optimized configuration.\n",
67
    "\n",
68
    "**Note**: We'll be experimenting with different LLMs (OpenAI, Llama, etc.) in this guide. You will need [OpenAI credentials](https://platform.openai.com/account/api-keys) to access [ChatGPT models](https://platform.openai.com/docs/models/) and [Anyscale Endpoints](https://endpoints.anyscale.com/) (hosted/private endpoints available) to serve + fine-tune OSS LLMs."
69
   ]
70
  },
71
  {
72
   "cell_type": "markdown",
73
   "id": "35af14d4-478a-418b-a738-b17012188779",
74
   "metadata": {},
75
   "source": [
76
    "# Set up"
77
   ]
78
  },
79
  {
80
   "cell_type": "markdown",
81
   "id": "dde6f87c-c5c2-4374-b256-f6d17205d402",
82
   "metadata": {},
83
   "source": [
84
    "We're going to start by setting up our base imports, directories and initializing Ray with credentials. We'll be using [Ray](https://docs.ray.io/) to easily scale our workloads with minimal changes to our code."
85
   ]
86
  },
87
  {
88
   "cell_type": "code",
89
   "execution_count": null,
90
   "id": "b08502d8-e9a9-4a50-acd9-76f77b18ada6",
91
   "metadata": {
92
    "tags": []
93
   },
94
   "outputs": [],
95
   "source": [
96
    "import os\n",
97
    "import ray"
98
   ]
99
  },
100
  {
101
   "cell_type": "code",
102
   "execution_count": null,
103
   "id": "633996c3-45b4-4ac6-961d-56b0df9156c0",
104
   "metadata": {
105
    "tags": []
106
   },
107
   "outputs": [],
108
   "source": [
109
    "import sys; sys.path.append(\"..\")\n",
110
    "import warnings; warnings.filterwarnings(\"ignore\")\n",
111
    "from dotenv import load_dotenv; load_dotenv()\n",
112
    "%load_ext autoreload\n",
113
    "%autoreload 2"
114
   ]
115
  },
116
  {
117
   "cell_type": "code",
118
   "execution_count": null,
119
   "id": "591d064d-e843-4b5a-bb89-0fcfb16a5045",
120
   "metadata": {
121
    "tags": []
122
   },
123
   "outputs": [],
124
   "source": [
125
    "from rag.config import ROOT_DIR"
126
   ]
127
  },
128
  {
129
   "cell_type": "code",
130
   "execution_count": null,
131
   "id": "8f36dc38-f797-4db9-9979-2450764679aa",
132
   "metadata": {
133
    "tags": []
134
   },
135
   "outputs": [
136
    {
137
     "name": "stderr",
138
     "output_type": "stream",
139
     "text": [
140
      "2024-01-03 08:15:10,265\tINFO worker.py:1458 -- Connecting to existing Ray cluster at address: 10.0.28.113:6379...\n",
141
      "2024-01-03 08:15:10,305\tINFO worker.py:1633 -- Connected to Ray cluster. View the dashboard at \u001b[1m\u001b[32mhttps://session-5ljni527x7edt2q6px7nuaejct.i.anyscaleuserdata-staging.com \u001b[39m\u001b[22m\n",
142
      "2024-01-03 08:15:10,422\tINFO packaging.py:518 -- Creating a file package for local directory '/home/ray/ray-assistant/notebooks/..'.\n",
143
      "2024-01-03 08:15:10,600\tINFO packaging.py:346 -- Pushing file package 'gcs://_ray_pkg_d695da803a6e5581.zip' (46.04MiB) to Ray cluster...\n",
144
      "2024-01-03 08:15:10,747\tINFO packaging.py:359 -- Successfully pushed file package 'gcs://_ray_pkg_d695da803a6e5581.zip'.\n"
145
     ]
146
    },
147
    {
148
     "data": {
149
      "application/vnd.jupyter.widget-view+json": {
150
       "model_id": "bd4bb0f425d44275878d2993eba4a37b",
151
       "version_major": 2,
152
       "version_minor": 0
153
      },
154
      "text/html": [
155
       "<div class=\"lm-Widget p-Widget lm-Panel p-Panel jp-Cell-outputWrapper\">\n",
156
       "    <div style=\"margin-left: 50px;display: flex;flex-direction: row;align-items: center\">\n",
157
       "        <div class=\"jp-RenderedHTMLCommon\" style=\"display: flex; flex-direction: row;\">\n",
158
       "  <svg viewBox=\"0 0 567 224\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\" style=\"height: 3em;\">\n",
159
       "    <g clip-path=\"url(#clip0_4338_178347)\">\n",
160
       "        <path d=\"M341.29 165.561H355.29L330.13 129.051C345.63 123.991 354.21 112.051 354.21 94.2307C354.21 71.3707 338.72 58.1807 311.88 58.1807H271V165.561H283.27V131.661H311.8C314.25 131.661 316.71 131.501 319.01 131.351L341.25 165.561H341.29ZM283.29 119.851V70.0007H311.82C331.3 70.0007 342.34 78.2907 342.34 94.5507C342.34 111.271 331.34 119.861 311.82 119.861L283.29 119.851ZM451.4 138.411L463.4 165.561H476.74L428.74 58.1807H416L367.83 165.561H380.83L392.83 138.411H451.4ZM446.19 126.601H398L422 72.1407L446.24 126.601H446.19ZM526.11 128.741L566.91 58.1807H554.35L519.99 114.181L485.17 58.1807H472.44L514.01 129.181V165.541H526.13V128.741H526.11Z\" fill=\"var(--jp-ui-font-color0)\"/>\n",
161
       "        <path d=\"M82.35 104.44C84.0187 97.8827 87.8248 92.0678 93.1671 87.9146C98.5094 83.7614 105.083 81.5067 111.85 81.5067C118.617 81.5067 125.191 83.7614 130.533 87.9146C135.875 92.0678 139.681 97.8827 141.35 104.44H163.75C164.476 101.562 165.622 98.8057 167.15 96.2605L127.45 56.5605C121.071 60.3522 113.526 61.6823 106.235 60.3005C98.9443 58.9187 92.4094 54.9203 87.8602 49.0574C83.3109 43.1946 81.0609 35.8714 81.5332 28.4656C82.0056 21.0599 85.1679 14.0819 90.4252 8.8446C95.6824 3.60726 102.672 0.471508 110.08 0.0272655C117.487 -0.416977 124.802 1.86091 130.647 6.4324C136.493 11.0039 140.467 17.5539 141.821 24.8501C143.175 32.1463 141.816 39.6859 138 46.0505L177.69 85.7505C182.31 82.9877 187.58 81.4995 192.962 81.4375C198.345 81.3755 203.648 82.742 208.33 85.3976C213.012 88.0532 216.907 91.9029 219.616 96.5544C222.326 101.206 223.753 106.492 223.753 111.875C223.753 117.258 222.326 122.545 219.616 127.197C216.907 131.848 213.012 135.698 208.33 138.353C203.648 141.009 198.345 142.375 192.962 142.313C187.58 142.251 182.31 140.763 177.69 138L138 177.7C141.808 184.071 143.155 191.614 141.79 198.91C140.424 206.205 136.44 212.75 130.585 217.313C124.731 221.875 117.412 224.141 110.004 223.683C102.596 223.226 95.6103 220.077 90.3621 214.828C85.1139 209.58 81.9647 202.595 81.5072 195.187C81.0497 187.779 83.3154 180.459 87.878 174.605C92.4405 168.751 98.9853 164.766 106.281 163.401C113.576 162.035 121.119 163.383 127.49 167.19L167.19 127.49C165.664 124.941 164.518 122.182 163.79 119.3H141.39C139.721 125.858 135.915 131.673 130.573 135.826C125.231 139.98 118.657 142.234 111.89 142.234C105.123 142.234 98.5494 139.98 93.2071 135.826C87.8648 131.673 84.0587 125.858 82.39 119.3H60C58.1878 126.495 53.8086 132.78 47.6863 136.971C41.5641 141.163 34.1211 142.972 26.7579 142.059C19.3947 141.146 12.6191 137.574 7.70605 132.014C2.79302 126.454 0.0813599 119.29 0.0813599 111.87C0.0813599 104.451 2.79302 97.2871 7.70605 91.7272C12.6191 86.1673 19.3947 82.5947 26.7579 
81.6817C34.1211 80.7686 41.5641 82.5781 47.6863 86.7696C53.8086 90.9611 58.1878 97.2456 60 104.44H82.35ZM100.86 204.32C103.407 206.868 106.759 208.453 110.345 208.806C113.93 209.159 117.527 208.258 120.522 206.256C123.517 204.254 125.725 201.276 126.771 197.828C127.816 194.38 127.633 190.677 126.253 187.349C124.874 184.021 122.383 181.274 119.205 179.577C116.027 177.88 112.359 177.337 108.826 178.042C105.293 178.746 102.113 180.654 99.8291 183.44C97.5451 186.226 96.2979 189.718 96.3 193.32C96.2985 195.364 96.7006 197.388 97.4831 199.275C98.2656 201.163 99.4132 202.877 100.86 204.32ZM204.32 122.88C206.868 120.333 208.453 116.981 208.806 113.396C209.159 109.811 208.258 106.214 206.256 103.219C204.254 100.223 201.275 98.0151 197.827 96.97C194.38 95.9249 190.676 96.1077 187.348 97.4873C184.02 98.8669 181.274 101.358 179.577 104.536C177.879 107.714 177.337 111.382 178.041 114.915C178.746 118.448 180.653 121.627 183.439 123.911C186.226 126.195 189.717 127.443 193.32 127.44C195.364 127.443 197.388 127.042 199.275 126.259C201.163 125.476 202.878 124.328 204.32 122.88ZM122.88 19.4205C120.333 16.8729 116.981 15.2876 113.395 14.9347C109.81 14.5817 106.213 15.483 103.218 17.4849C100.223 19.4868 98.0146 22.4654 96.9696 25.9131C95.9245 29.3608 96.1073 33.0642 97.4869 36.3922C98.8665 39.7202 101.358 42.4668 104.535 44.1639C107.713 45.861 111.381 46.4036 114.914 45.6992C118.447 44.9949 121.627 43.0871 123.911 40.301C126.195 37.515 127.442 34.0231 127.44 30.4205C127.44 28.3772 127.038 26.3539 126.255 24.4664C125.473 22.5788 124.326 20.8642 122.88 19.4205ZM19.42 100.86C16.8725 103.408 15.2872 106.76 14.9342 110.345C14.5813 113.93 15.4826 117.527 17.4844 120.522C19.4863 123.518 22.4649 125.726 25.9127 126.771C29.3604 127.816 33.0638 127.633 36.3918 126.254C39.7198 124.874 42.4664 122.383 44.1635 119.205C45.8606 116.027 46.4032 112.359 45.6988 108.826C44.9944 105.293 43.0866 102.114 40.3006 99.8296C37.5145 97.5455 34.0227 96.2983 30.42 96.3005C26.2938 96.3018 22.337 97.9421 19.42 
100.86ZM100.86 100.86C98.3125 103.408 96.7272 106.76 96.3742 110.345C96.0213 113.93 96.9226 117.527 98.9244 120.522C100.926 123.518 103.905 125.726 107.353 126.771C110.8 127.816 114.504 127.633 117.832 126.254C121.16 124.874 123.906 122.383 125.604 119.205C127.301 116.027 127.843 112.359 127.139 108.826C126.434 105.293 124.527 102.114 121.741 99.8296C118.955 97.5455 115.463 96.2983 111.86 96.3005C109.817 96.299 107.793 96.701 105.905 97.4835C104.018 98.2661 102.303 99.4136 100.86 100.86Z\" fill=\"#00AEEF\"/>\n",
162
       "    </g>\n",
163
       "    <defs>\n",
164
       "        <clipPath id=\"clip0_4338_178347\">\n",
165
       "            <rect width=\"566.93\" height=\"223.75\" fill=\"white\"/>\n",
166
       "        </clipPath>\n",
167
       "    </defs>\n",
168
       "  </svg>\n",
169
       "</div>\n",
170
       "\n",
171
       "        <table class=\"jp-RenderedHTMLCommon\" style=\"border-collapse: collapse;color: var(--jp-ui-font-color1);font-size: var(--jp-ui-font-size1);\">\n",
172
       "    <tr>\n",
173
       "        <td style=\"text-align: left\"><b>Python version:</b></td>\n",
174
       "        <td style=\"text-align: left\"><b>3.10.8</b></td>\n",
175
       "    </tr>\n",
176
       "    <tr>\n",
177
       "        <td style=\"text-align: left\"><b>Ray version:</b></td>\n",
178
       "        <td style=\"text-align: left\"><b>2.7.0</b></td>\n",
179
       "    </tr>\n",
180
       "    <tr>\n",
181
       "    <td style=\"text-align: left\"><b>Dashboard:</b></td>\n",
182
       "    <td style=\"text-align: left\"><b><a href=\"http://session-5ljni527x7edt2q6px7nuaejct.i.anyscaleuserdata-staging.com\" target=\"_blank\">http://session-5ljni527x7edt2q6px7nuaejct.i.anyscaleuserdata-staging.com</a></b></td>\n",
183
       "</tr>\n",
184
       "\n",
185
       "</table>\n",
186
       "\n",
187
       "    </div>\n",
188
       "</div>\n"
189
      ],
190
      "text/plain": [
191
       "RayContext(dashboard_url='session-5ljni527x7edt2q6px7nuaejct.i.anyscaleuserdata-staging.com', python_version='3.10.8', ray_version='2.7.0', ray_commit='acb4a960947869e158a973c6c4bdf1aca2d66b10', protocol_version=None)"
192
      ]
193
     },
194
     "execution_count": null,
195
     "metadata": {},
196
     "output_type": "execute_result"
197
    }
198
   ],
199
   "source": [
200
    "# Credentials\n",
201
    "ray.init(runtime_env={\n",
202
    "    \"env_vars\": {\n",
203
    "        \"OPENAI_API_BASE\": os.environ[\"OPENAI_API_BASE\"],\n",
204
    "        \"OPENAI_API_KEY\": os.environ[\"OPENAI_API_KEY\"], \n",
205
    "        \"ANYSCALE_API_BASE\": os.environ[\"ANYSCALE_API_BASE\"],\n",
206
    "        \"ANYSCALE_API_KEY\": os.environ[\"ANYSCALE_API_KEY\"],\n",
207
    "        \"DB_CONNECTION_STRING\": os.environ[\"DB_CONNECTION_STRING\"],\n",
208
    "    },\n",
209
    "    \"working_dir\": str(ROOT_DIR)\n",
210
    "})"
211
   ]
212
  },
213
  {
214
   "cell_type": "code",
215
   "execution_count": null,
216
   "id": "880b8660-e613-431f-a083-d4985711e8bf",
217
   "metadata": {
218
    "tags": []
219
   },
220
   "outputs": [
221
    {
222
     "data": {
223
      "text/plain": [
224
       "{'GPU': 1.0,\n",
225
       " 'CPU': 8.0,\n",
226
       " 'node:__internal_head__': 1.0,\n",
227
       " 'node:10.0.28.113': 1.0,\n",
228
       " 'accelerator_type:A10G': 1.0,\n",
229
       " 'object_store_memory': 9536466124.0,\n",
230
       " 'memory': 34359738368.0}"
231
      ]
232
     },
233
     "execution_count": null,
234
     "metadata": {},
235
     "output_type": "execute_result"
236
    }
237
   ],
238
   "source": [
239
    "ray.cluster_resources()"
240
   ]
241
  },
242
  {
243
   "cell_type": "markdown",
244
   "id": "8e3928f1-1404-430d-8633-45bb1a1c21d3",
245
   "metadata": {},
246
   "source": [
247
    "We've also created some mappings for the different embedding and language models we'll be developing with in our application:"
248
   ]
249
  },
250
  {
251
   "cell_type": "code",
252
   "execution_count": null,
253
   "id": "ab69df3c-bf18-4750-a680-4e488888c5e8",
254
   "metadata": {
255
    "tags": []
256
   },
257
   "outputs": [],
258
   "source": [
259
    "from rag.config import EMBEDDING_DIMENSIONS, MAX_CONTEXT_LENGTHS, PRICING"
260
   ]
261
  },
262
  {
263
   "cell_type": "code",
264
   "execution_count": null,
265
   "id": "e8ba8f76-7ae8-49c1-8907-151b1fecbaf8",
266
   "metadata": {
267
    "tags": []
268
   },
269
   "outputs": [
270
    {
271
     "data": {
272
      "text/plain": [
273
       "{'thenlper/gte-base': 768,\n",
274
       " 'thenlper/gte-large': 1024,\n",
275
       " 'BAAI/bge-large-en': 1024,\n",
276
       " 'text-embedding-ada-002': 1536,\n",
277
       " 'gte-large-fine-tuned': 1024}"
278
      ]
279
     },
280
     "execution_count": null,
281
     "metadata": {},
282
     "output_type": "execute_result"
283
    }
284
   ],
285
   "source": [
286
    "# Embedding dimensions\n",
287
    "EMBEDDING_DIMENSIONS"
288
   ]
289
  },
290
  {
291
   "cell_type": "code",
292
   "execution_count": null,
293
   "id": "b6d340fa-4a87-4cc8-8d74-c09aafd9d18d",
294
   "metadata": {
295
    "tags": []
296
   },
297
   "outputs": [
298
    {
299
     "data": {
300
      "text/plain": [
301
       "{'gpt-4': 8192,\n",
302
       " 'gpt-3.5-turbo': 4096,\n",
303
       " 'gpt-3.5-turbo-16k': 16384,\n",
304
       " 'gpt-4-1106-preview': 128000,\n",
305
       " 'meta-llama/Llama-2-7b-chat-hf': 4096,\n",
306
       " 'meta-llama/Llama-2-13b-chat-hf': 4096,\n",
307
       " 'meta-llama/Llama-2-70b-chat-hf': 4096,\n",
308
       " 'codellama/CodeLlama-34b-Instruct-hf': 16384,\n",
309
       " 'mistralai/Mistral-7B-Instruct-v0.1': 65536,\n",
310
       " 'mistralai/Mixtral-8x7B-Instruct-v0.1': 32768}"
311
      ]
312
     },
313
     "execution_count": null,
314
     "metadata": {},
315
     "output_type": "execute_result"
316
    }
317
   ],
318
   "source": [
    "# LLM context lengths in tokens (1 token ≈ 3/4 word)\n",
320
    "MAX_CONTEXT_LENGTHS"
321
   ]
322
  },
323
  {
324
   "cell_type": "code",
325
   "execution_count": null,
326
   "id": "8b6c5ec8-70c6-44ea-9576-2b41a898acc0",
327
   "metadata": {
328
    "tags": []
329
   },
330
   "outputs": [
331
    {
332
     "data": {
333
      "text/plain": [
334
       "{'gpt-3.5-turbo': {'prompt': 1.5, 'sampled': 2},\n",
335
       " 'gpt-4': {'prompt': 30, 'sampled': 60},\n",
336
       " 'gpt-4-1106-preview': {'prompt': 10, 'sampled': 30},\n",
337
       " 'llama-2-7b-chat-hf': {'prompt': 0.15, 'sampled': 0.15},\n",
338
       " 'llama-2-13b-chat-hf': {'prompt': 0.25, 'sampled': 0.25},\n",
339
       " 'llama-2-70b-chat-hf': {'prompt': 1, 'sampled': 1},\n",
340
       " 'codellama-34b-instruct-hf': {'prompt': 1, 'sampled': 1},\n",
341
       " 'mistral-7b-instruct-v0.1': {'prompt': 0.15, 'sampled': 0.15},\n",
342
       " 'mixtral-8x7b-instruct-v0.1': {'prompt': 0.5, 'sampled': 0.5}}"
343
      ]
344
     },
345
     "execution_count": null,
346
     "metadata": {},
347
     "output_type": "execute_result"
348
    }
349
   ],
350
   "source": [
351
    "# Anyscale pricing\n",
352
    "PRICING"
353
   ]
354
  },
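To show how a pricing mapping like the one above can be used, here is a minimal sketch that estimates the cost of a single request. It assumes the values are US dollars per one million tokens, which the notebook does not state explicitly. The `estimate_cost` helper and its `tokens_per_unit` parameter are hypothetical names introduced for illustration, and only two entries of the mapping are copied here.

```python
# Subset of the PRICING mapping shown above.
# Assumption: values are USD per 1M tokens (not stated in the notebook).
PRICING = {
    "gpt-4": {"prompt": 30, "sampled": 60},
    "llama-2-70b-chat-hf": {"prompt": 1, "sampled": 1},
}


def estimate_cost(llm, prompt_tokens, sampled_tokens, pricing=PRICING, tokens_per_unit=1_000_000):
    """Hypothetical helper: estimated cost of one request under the stated assumption."""
    rates = pricing[llm]
    total = prompt_tokens * rates["prompt"] + sampled_tokens * rates["sampled"]
    return total / tokens_per_unit


gpt4_cost = estimate_cost("gpt-4", prompt_tokens=1000, sampled_tokens=500)
llama_cost = estimate_cost("llama-2-70b-chat-hf", prompt_tokens=1000, sampled_tokens=500)
print(f"gpt-4: ${gpt4_cost:.4f}, llama-2-70b: ${llama_cost:.4f}")
```

Keeping prompt and sampled rates separate matters because some models charge more for generated tokens than for input tokens.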
355
  {
356
   "cell_type": "markdown",
357
   "id": "bab7f73e-bdd1-4c87-93d0-92ece2344432",
358
   "metadata": {},
359
   "source": [
360
    "# Data"
361
   ]
362
  },
363
  {
364
   "cell_type": "markdown",
365
   "id": "00a1ed16-0f0d-4be1-951a-21a8cde1bfe9",
366
   "metadata": {},
367
   "source": [
368
    "Before we can start building our RAG application, we need to first create our vector DB that will contain our processed data sources.\n",
369
    "\n",
370
    "<img width=\"1000\" src=\"https://images.ctfassets.net/xjan103pcp94/3q5HUANQ4kS0V23cgEP0JF/ef3b62c5bc5c5c11b734fd3b73f6ea28/image3.png\">"
371
   ]
372
  },
373
  {
374
   "cell_type": "markdown",
375
   "id": "4c503edd-963a-4ec3-9182-39f7afc44153",
376
   "metadata": {},
377
   "source": [
378
    "## Load data"
379
   ]
380
  },
381
  {
382
   "cell_type": "code",
383
   "execution_count": null,
384
   "id": "2301b6f6-7402-4eeb-9975-7c28194b6914",
385
   "metadata": {
386
    "tags": []
387
   },
388
   "outputs": [],
389
   "source": [
390
    "from pathlib import Path\n",
391
    "from rag.config import EFS_DIR"
392
   ]
393
  },
394
  {
395
   "cell_type": "markdown",
396
   "id": "1f82562f-2b5d-4e8d-9716-b0da8670e8bf",
397
   "metadata": {},
398
   "source": [
399
    "We need to first download the [Ray documentation](https://docs.ray.io/) to a directory:\n",
400
    "```bash\n",
401
    "export EFS_DIR=$(python -c \"from rag.config import EFS_DIR; print(EFS_DIR)\")\n",
402
    "wget -e robots=off --recursive --no-clobber --page-requisites \\\n",
403
    "  --html-extension --convert-links --restrict-file-names=windows \\\n",
404
    "  --domains docs.ray.io --no-parent --accept=html --retry-on-http-error=429 \\\n",
405
    "  -P $EFS_DIR https://docs.ray.io/en/master/\n",
406
    "```"
407
   ]
408
  },
409
  {
410
   "cell_type": "markdown",
411
   "id": "77456cf7-fe2b-4884-bfc2-99a2b4ffba1a",
412
   "metadata": {},
413
   "source": [
    "We then load the contents of our docs into a [Ray Dataset](https://docs.ray.io/en/latest/data/data.html) so that we can perform operations on them at scale (e.g., embed, index, etc.). With large data sources, models and application serving needs, scale is a day-1 priority for LLM applications. We want to build our applications in such a way that they can scale as our needs grow without us having to change our code later."
415
   ]
416
  },
417
  {
418
   "cell_type": "code",
419
   "execution_count": null,
420
   "id": "bba6b43b-ea82-4c21-a885-57178cec3b44",
421
   "metadata": {
422
    "tags": []
423
   },
424
   "outputs": [
425
    {
426
     "name": "stdout",
427
     "output_type": "stream",
428
     "text": [
429
      "3282 documents\n"
430
     ]
431
    }
432
   ],
433
   "source": [
434
    "# Ray dataset\n",
435
    "DOCS_DIR = Path(EFS_DIR, \"docs.ray.io/en/master/\")\n",
436
    "ds = ray.data.from_items([{\"path\": path} for path in DOCS_DIR.rglob(\"*.html\") if not path.is_dir()])\n",
437
    "print(f\"{ds.count()} documents\")"
438
   ]
439
  },
440
  {
441
   "cell_type": "markdown",
442
   "id": "ba9edff6-6dbf-4037-9675-ae05cd3eb7a7",
443
   "metadata": {
444
    "tags": []
445
   },
446
   "source": [
447
    "## Sections"
448
   ]
449
  },
450
  {
451
   "cell_type": "markdown",
452
   "id": "4f13f0dc-7f7a-4132-93e4-dc69aab164a2",
453
   "metadata": {},
454
   "source": [
    "Now that we have a dataset of all the paths to the HTML files, we're going to develop some functions that can appropriately extract the content from these files. We want to do this in a generalized manner so that we can perform this extraction across all of our docs pages (and so you can use it for your own data sources). Our process is to first identify the sections in our HTML page and then extract the text in between them. We save all of this into a list of dictionaries that map the text within a section to a specific URL with a section anchor id.\n",
456
    "\n",
457
    "<img width=\"800\" src=\"https://images.ctfassets.net/xjan103pcp94/1eFnKmG5xqPIFtPupZ327X/f6152723e18322b90aaa8be5d2d5a6e4/image5.png\">"
458
   ]
459
  },
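The idea of mapping each section's text to a URL with an anchor id can be sketched with the standard library alone. This is an illustrative stand-in, not the implementation in `rag.data`: the `SectionParser` class, the sample HTML and the base URL are all assumptions made for the example.

```python
from html.parser import HTMLParser


class SectionParser(HTMLParser):
    """Hypothetical sketch: collect the text directly inside each <section id=...>."""

    def __init__(self, base_url):
        super().__init__()
        self.base_url = base_url
        self.stack = []      # currently open sections as (section_id, text_pieces)
        self.sections = []   # finished records: {"source": ..., "text": ...}

    def handle_starttag(self, tag, attrs):
        if tag == "section":
            self.stack.append((dict(attrs).get("id"), []))

    def handle_data(self, data):
        # Text belongs to the innermost open section only.
        if self.stack and data.strip():
            self.stack[-1][1].append(data.strip())

    def handle_endtag(self, tag):
        if tag == "section" and self.stack:
            section_id, pieces = self.stack.pop()
            if section_id and pieces:
                self.sections.append(
                    {"source": f"{self.base_url}#{section_id}", "text": "\n".join(pieces)}
                )


sample_html = (
    '<section id="environments"><h1>Environments</h1><p>RLlib works with environments.</p>'
    '<section id="gym"><h2>Gymnasium</h2><p>Details.</p></section></section>'
)
parser = SectionParser("https://example.com/rllib-env.html")  # placeholder URL
parser.feed(sample_html)
parser.close()
for record in parser.sections:
    print(record)
```

Nested sections are emitted before their parent because a record is only written when its closing tag is reached.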
460
  {
461
   "cell_type": "code",
462
   "execution_count": null,
463
   "id": "f72dedb6-63d1-414f-b973-6a1009fcaf74",
464
   "metadata": {
465
    "tags": []
466
   },
467
   "outputs": [],
468
   "source": [
469
    "import matplotlib.pyplot as plt\n",
470
    "from rag.data import extract_sections"
471
   ]
472
  },
473
  {
474
   "cell_type": "code",
475
   "execution_count": null,
476
   "id": "2a8033de-2508-4030-a9a1-d6c792d5542a",
477
   "metadata": {
478
    "tags": []
479
   },
480
   "outputs": [
481
    {
482
     "data": {
483
      "text/plain": [
484
       "{'source': 'https://docs.ray.io/en/master/rllib/rllib-env.html#environments',\n",
485
       " 'text': '\\nEnvironments#\\nRLlib works with several different types of environments, including Farama-Foundation Gymnasium, user-defined, multi-agent, and also batched environments.\\nTip\\nNot all environments work with all algorithms. Check out the algorithm overview for more information.\\n'}"
486
      ]
487
     },
488
     "execution_count": null,
489
     "metadata": {},
490
     "output_type": "execute_result"
491
    }
492
   ],
493
   "source": [
494
    "sample_html_fp = Path(EFS_DIR, \"docs.ray.io/en/master/rllib/rllib-env.html\")\n",
495
    "extract_sections({\"path\": sample_html_fp})[0]"
496
   ]
497
  },
498
  {
499
   "cell_type": "markdown",
500
   "id": "24f4ed64-2dd8-43c2-9831-cdf9d6d8004c",
501
   "metadata": {},
502
   "source": [
    "We can apply this extraction process (`extract_sections`) in parallel to all the file paths in our dataset with just one line using Ray Data's [flat_map](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.flat_map.html):"
504
   ]
505
  },
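As a rough mental model of what a flat map does (without Ray's parallelism), each input row is mapped to a list of output rows and the lists are concatenated. The sketch below uses plain Python. `flat_map` here is a local helper and `toy_extract_sections` is a hypothetical stand-in that splits text on blank lines. Neither is Ray's API nor the `extract_sections` function from `rag.data`.

```python
from itertools import chain


def flat_map(fn, rows):
    """Apply fn to every row and flatten the resulting lists into one list."""
    return list(chain.from_iterable(fn(row) for row in rows))


def toy_extract_sections(row):
    # Hypothetical stand-in: one record per blank-line-separated block of text.
    blocks = row["text"].split("\n\n")
    return [{"source": f"{row['path']}#s{i}", "text": block} for i, block in enumerate(blocks)]


rows = [
    {"path": "a.html", "text": "intro\n\ndetails"},
    {"path": "b.html", "text": "only one"},
]
sections = flat_map(toy_extract_sections, rows)
print(len(sections), "sections from", len(rows), "documents")
```

This is why the number of sections can exceed the number of documents: one row in can produce several rows out.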
506
  {
507
   "cell_type": "code",
508
   "execution_count": null,
509
   "id": "f4739d56-ddfe-42e5-9113-3d83d737999a",
510
   "metadata": {
511
    "tags": []
512
   },
513
   "outputs": [
514
    {
515
     "name": "stderr",
516
     "output_type": "stream",
517
     "text": [
518
      "2024-01-03 08:15:15,020\tINFO streaming_executor.py:93 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[FlatMap(extract_sections)]\n",
519
      "2024-01-03 08:15:15,021\tINFO streaming_executor.py:94 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
520
      "2024-01-03 08:15:15,021\tINFO streaming_executor.py:96 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
521
     ]
522
    },
523
    {
524
     "data": {
525
      "application/vnd.jupyter.widget-view+json": {
526
       "model_id": "",
527
       "version_major": 2,
528
       "version_minor": 0
529
      },
530
      "text/plain": [
531
       "Running 0:   0%|          | 0/200 [00:00<?, ?it/s]"
532
      ]
533
     },
534
     "metadata": {},
535
     "output_type": "display_data"
536
    },
537
    {
538
     "data": {
539
      "text/plain": [
540
       "5727"
541
      ]
542
     },
543
     "execution_count": null,
544
     "metadata": {},
545
     "output_type": "execute_result"
546
    }
547
   ],
548
   "source": [
549
    "# Extract sections\n",
550
    "sections_ds = ds.flat_map(extract_sections)\n",
551
    "sections_ds.count()"
552
   ]
553
  },
554
  {
555
   "cell_type": "code",
556
   "execution_count": null,
557
   "id": "39641fc6-adb2-4a6a-a8dd-f0433db814ff",
558
   "metadata": {
559
    "tags": []
560
   },
561
   "outputs": [],
562
   "source": [
563
    "section_lengths = []\n",
564
    "for section in sections_ds.take_all():\n",
565
    "    section_lengths.append(len(section[\"text\"]))"
566
   ]
567
  },
568
  {
569
   "cell_type": "markdown",
570
   "id": "2c144ef2-ffa5-49ed-9bb3-60a50223854c",
571
   "metadata": {},
572
   "source": [
573
    "## Chunk data"
574
   ]
575
  },
576
  {
577
   "cell_type": "markdown",
578
   "id": "78c103aa-9b34-4d7b-8b42-78737eb0888d",
579
   "metadata": {},
580
   "source": [
    "We now have a list of sections (with the text and source of each section), but we shouldn't directly use these as context for our RAG application just yet. Section lengths vary widely and many sections are quite large."
582
   ]
583
  },
584
  {
585
   "cell_type": "code",
586
   "execution_count": null,
587
   "id": "1b07383c-b898-4f4a-b008-837a8bf83015",
588
   "metadata": {
589
    "tags": []
590
   },
591
   "outputs": [
592
    {
593
     "data": {
594
      "image/png": "",
595
      "text/plain": [
596
       "<Figure size 1200x300 with 1 Axes>"
597
      ]
598
     },
599
     "metadata": {},
600
     "output_type": "display_data"
601
    }
602
   ],
603
   "source": [
604
    "# Plot\n",
605
    "plt.figure(figsize=(12, 3))\n",
606
    "plt.plot(section_lengths, marker='o')\n",
607
    "plt.title(\"Section lengths\")\n",
608
    "plt.ylabel(\"# chars\")\n",
609
    "plt.show()"
610
   ]
611
  },
612
  {
613
   "cell_type": "markdown",
614
   "id": "bfa7564d-fce5-4807-8bf2-de9f93c6f994",
615
   "metadata": {},
616
   "source": [
    "If we were to use these large sections, we'd be inserting a lot of noisy/unwanted context, and because all LLMs have a maximum context length, we wouldn't be able to fit much other relevant context. So instead, we're going to split the text within each section into smaller chunks. Intuitively, smaller chunks will encapsulate a single concept (or a few) and will be less noisy compared to larger chunks. We're going to choose some typical text splitting values (e.g., chunk_size=300) to create our chunks for now, but we'll be experimenting with a wider range of values later."
618
   ]
619
  },
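To make the roles of chunk size and overlap concrete, here is a simplified fixed-window splitter. It is a sketch only: unlike a recursive character splitter, it ignores separators such as newlines and spaces and cuts purely by character position. The `chunk_text` function and the sample section are hypothetical.

```python
def chunk_text(text, chunk_size=300, chunk_overlap=50):
    """Split text into windows of chunk_size characters that share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks


# Hypothetical section record with 700 characters of text.
section = {"source": "https://example.com/page.html#intro", "text": "abcdefghij" * 70}
chunks = [
    {"text": chunk, "source": section["source"]}
    for chunk in chunk_text(section["text"], chunk_size=300, chunk_overlap=50)
]
print([len(c["text"]) for c in chunks])
```

The overlap means the end of one chunk is repeated at the start of the next, so a sentence cut at a boundary still appears whole in at least one chunk when it is shorter than the overlap.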
620
  {
621
   "cell_type": "code",
622
   "execution_count": null,
623
   "id": "97ca89ce-97dd-4462-8dfb-e06154c56ea2",
624
   "metadata": {
625
    "tags": []
626
   },
627
   "outputs": [],
628
   "source": [
629
    "from functools import partial\n",
630
    "from langchain.text_splitter import RecursiveCharacterTextSplitter"
631
   ]
632
  },
633
  {
634
   "cell_type": "code",
635
   "execution_count": null,
636
   "id": "6b6337e6-07b2-459d-a666-45de3aa945c4",
637
   "metadata": {
638
    "tags": []
639
   },
640
   "outputs": [],
641
   "source": [
642
    "# Text splitter\n",
643
    "chunk_size = 300\n",
644
    "chunk_overlap = 50\n",
645
    "text_splitter = RecursiveCharacterTextSplitter(\n",
646
    "    separators=[\"\\n\\n\", \"\\n\", \" \", \"\"],\n",
647
    "    chunk_size=chunk_size,\n",
648
    "    chunk_overlap=chunk_overlap,\n",
649
    "    length_function=len)"
650
   ]
651
  },
652
  {
653
   "cell_type": "code",
654
   "execution_count": null,
655
   "id": "ca5208da-7196-49b5-9943-c0c9f0ab13b5",
656
   "metadata": {
657
    "tags": []
658
   },
659
   "outputs": [
660
    {
661
     "name": "stderr",
662
     "output_type": "stream",
663
     "text": [
664
      "2024-01-03 08:15:40,714\tINFO dataset.py:2380 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.\n",
665
      "2024-01-03 08:15:40,716\tINFO streaming_executor.py:93 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[FlatMap(extract_sections->Limit[1])] -> LimitOperator[limit=1]\n",
666
      "2024-01-03 08:15:40,717\tINFO streaming_executor.py:94 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
667
      "2024-01-03 08:15:40,718\tINFO streaming_executor.py:96 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
668
     ]
669
    },
670
    {
671
     "data": {
672
      "application/vnd.jupyter.widget-view+json": {
673
       "model_id": "",
674
       "version_major": 2,
675
       "version_minor": 0
676
      },
677
      "text/plain": [
678
       "Running 0:   0%|          | 0/1 [00:00<?, ?it/s]"
679
      ]
680
     },
681
     "metadata": {},
682
     "output_type": "display_data"
683
    },
684
    {
685
     "name": "stdout",
686
     "output_type": "stream",
687
     "text": [
688
      "page_content='Ray Dashboard#\\nRay provides a web-based dashboard for monitoring and debugging Ray applications.\\nThe visual representation of the system state, allows users to track the performance\\nof applications and troubleshoot issues.' metadata={'source': 'https://docs.ray.io/en/master/ray-observability/getting-started.html#ray-dashboard'}\n"
689
     ]
690
    }
691
   ],
692
   "source": [
    "# Chunk a sample section\n",
    "sample_section = sections_ds.take(1)[0]\n",
    "chunks = text_splitter.create_documents(\n",
    "    texts=[sample_section[\"text\"]],\n",
    "    metadatas=[{\"source\": sample_section[\"source\"]}])\n",
    "print(chunks[0])"
699
   ]
700
  },
701
  {
702
   "cell_type": "markdown",
703
   "id": "4d38775f-0a4b-4aad-866d-1a3231b32df7",
704
   "metadata": {},
705
   "source": [
    "While chunking our dataset is relatively fast, let’s wrap the chunking logic in a function so that we can apply it at scale and chunking remains just as fast as our data sources grow:\n"
707
   ]
708
  },
709
  {
710
   "cell_type": "code",
711
   "execution_count": null,
712
   "id": "5c15a38d-066b-4623-8448-fdd0ac5ddf4d",
713
   "metadata": {
714
    "tags": []
715
   },
716
   "outputs": [],
717
   "source": [
    "def chunk_section(section, chunk_size, chunk_overlap):\n",
    "    # Split one section into overlapping chunks, keeping its source URL\n",
    "    text_splitter = RecursiveCharacterTextSplitter(\n",
    "        separators=[\"\\n\\n\", \"\\n\", \" \", \"\"],\n",
    "        chunk_size=chunk_size,\n",
    "        chunk_overlap=chunk_overlap,\n",
    "        length_function=len)\n",
    "    chunks = text_splitter.create_documents(\n",
    "        texts=[section[\"text\"]],\n",
    "        metadatas=[{\"source\": section[\"source\"]}])\n",
    "    # Return plain dicts so that flat_map can emit one row per chunk\n",
    "    return [{\"text\": chunk.page_content, \"source\": chunk.metadata[\"source\"]} for chunk in chunks]"
728
   ]
729
  },
730
  {
731
   "cell_type": "code",
732
   "execution_count": null,
733
   "id": "5be23c62-3108-49be-994d-5aa033162c71",
734
   "metadata": {
735
    "tags": []
736
   },
737
   "outputs": [
738
    {
739
     "name": "stderr",
740
     "output_type": "stream",
741
     "text": [
742
      "2024-01-03 08:15:41,015\tINFO streaming_executor.py:93 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[FlatMap(extract_sections)->FlatMap(partial)]\n",
743
      "2024-01-03 08:15:41,015\tINFO streaming_executor.py:94 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
744
      "2024-01-03 08:15:41,016\tINFO streaming_executor.py:96 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
745
     ]
746
    },
747
    {
748
     "data": {
749
      "application/vnd.jupyter.widget-view+json": {
750
       "model_id": "",
751
       "version_major": 2,
752
       "version_minor": 0
753
      },
754
      "text/plain": [
755
       "Running 0:   0%|          | 0/200 [00:00<?, ?it/s]"
756
      ]
757
     },
758
     "metadata": {},
759
     "output_type": "display_data"
760
    },
761
    {
762
     "name": "stderr",
763
     "output_type": "stream",
764
     "text": [
765
      "2024-01-03 08:16:03,225\tINFO streaming_executor.py:93 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[FlatMap(extract_sections)->FlatMap(partial->Limit[1])] -> LimitOperator[limit=1]\n",
766
      "2024-01-03 08:16:03,225\tINFO streaming_executor.py:94 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
767
      "2024-01-03 08:16:03,225\tINFO streaming_executor.py:96 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
768
     ]
769
    },
770
    {
771
     "name": "stdout",
772
     "output_type": "stream",
773
     "text": [
774
      "32276 chunks\n"
775
     ]
776
    },
777
    {
778
     "data": {
779
      "application/vnd.jupyter.widget-view+json": {
780
       "model_id": "",
781
       "version_major": 2,
782
       "version_minor": 0
783
      },
784
      "text/plain": [
785
       "Running 0:   0%|          | 0/1 [00:00<?, ?it/s]"
786
      ]
787
     },
788
     "metadata": {},
789
     "output_type": "display_data"
790
    },
791
    {
792
     "name": "stdout",
793
     "output_type": "stream",
794
     "text": [
795
      "{'text': 'Reference#\\nMonitor and debug your Ray applications and clusters using the API and CLI documented in these references.\\nThe guides include:\\nState API\\nState CLI\\nSystem Metrics', 'source': 'https://docs.ray.io/en/master/ray-observability/reference/index.html#reference'}\n"
796
     ]
797
    }
798
   ],
799
   "source": [
800
    "# Scale chunking\n",
801
    "chunks_ds = sections_ds.flat_map(partial(\n",
802
    "    chunk_section, \n",
803
    "    chunk_size=chunk_size, \n",
804
    "    chunk_overlap=chunk_overlap))\n",
805
    "print(f\"{chunks_ds.count()} chunks\")\n",
806
    "chunks_ds.show(1)"
807
   ]
808
  },
809
  {
810
   "cell_type": "markdown",
811
   "id": "d9c23b31-e7b3-4078-abf7-683f448f5b19",
812
   "metadata": {},
813
   "source": [
814
    "## Embed data"
815
   ]
816
  },
817
  {
818
   "cell_type": "markdown",
819
   "id": "fb44e2b4-d48a-4f27-b9e4-99579917b18e",
820
   "metadata": {},
821
   "source": [
    "Now that we've created small chunks from our sections, we need a way to identify the most relevant ones for a given query. A very effective and quick method is to embed our data using a pretrained model and use the same model to embed the query. We can then compute the distance between each chunk embedding and our query embedding to determine the top-k chunks. There are many pretrained models to choose from to embed our data, and the most popular ones can be discovered through [HuggingFace's Massive Text Embedding Benchmark (MTEB)](https://huggingface.co/spaces/mteb/leaderboard) leaderboard. These models were pretrained on very large text corpora through tasks such as next/masked token prediction, which allowed them to learn to represent subtokens in N dimensions and capture semantic relationships. We can leverage this to represent our data and identify the most relevant contexts to use to answer a given query. We're using LangChain's Embedding wrappers ([HuggingFaceEmbeddings](https://api.python.langchain.com/en/latest/embeddings/langchain.embeddings.huggingface.HuggingFaceEmbeddings.html) and [OpenAIEmbeddings](https://api.python.langchain.com/en/latest/embeddings/langchain.embeddings.openai.OpenAIEmbeddings.html)) to easily load the models and embed our document chunks.\n",
823
    "\n",
    "**Note**: embeddings aren't the only way to determine the most relevant chunks. We could also use an LLM to decide! However, because LLMs are much larger than these embedding models and have maximum context lengths, it's better to use embeddings to retrieve the top-k chunks. We could then use an LLM on those k chunks to select the fewer (<k) chunks to use as the context to answer our query. We could also use reranking (ex. [Cohere Rerank](https://txt.cohere.com/rerank/)) to further identify the most relevant chunks to use. Finally, we could combine embeddings with traditional information retrieval methods such as keyword matching, which can be useful for matching unique tokens that may be lost when embedding subtokens."
825
   ]
826
  },
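The top-k idea described above can be sketched in a few lines of plain Python. This is an illustrative brute-force version under stated assumptions: the function names `cosine_similarity` and `top_k_chunks` and the toy 3-dimensional vectors are hypothetical, and in this guide the actual search is delegated to the vector database rather than done in Python.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_chunks(query_embedding, chunk_embeddings, k=2):
    """Return the indices of the k chunks most similar to the query."""
    scores = [cosine_similarity(query_embedding, e) for e in chunk_embeddings]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


# Toy 3-dimensional embeddings (real models use hundreds of dimensions)
toy_chunk_embeddings = [
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
toy_query_embedding = [1.0, 0.05, 0.0]
print(top_k_chunks(toy_query_embedding, toy_chunk_embeddings, k=2))
```

Scoring every chunk like this is linear in the number of chunks, which is one reason to hand the search to an indexed vector store once the corpus grows.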
827
  {
828
   "cell_type": "code",
829
   "execution_count": null,
830
   "id": "712fe08b-fd19-4cb8-94d9-a7570b2dc09d",
831
   "metadata": {
832
    "tags": []
833
   },
834
   "outputs": [],
835
   "source": [
836
    "from langchain.embeddings import OpenAIEmbeddings\n",
837
    "from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
838
    "import numpy as np\n",
839
    "from ray.data import ActorPoolStrategy"
840
   ]
841
  },
842
  {
843
   "cell_type": "code",
844
   "execution_count": null,
845
   "id": "290c066a-100b-4f09-98ba-e79a93cd4302",
846
   "metadata": {
847
    "tags": []
848
   },
849
   "outputs": [],
850
   "source": [
    "def get_embedding_model(embedding_model_name, model_kwargs, encode_kwargs):\n",
    "    if embedding_model_name == \"text-embedding-ada-002\":\n",
    "        # OpenAI's hosted model: credentials come from the environment\n",
    "        embedding_model = OpenAIEmbeddings(\n",
    "            model=embedding_model_name,\n",
    "            openai_api_base=os.environ[\"OPENAI_API_BASE\"],\n",
    "            openai_api_key=os.environ[\"OPENAI_API_KEY\"])\n",
    "    else:\n",
    "        # Any other name is loaded as a HuggingFace model\n",
    "        embedding_model = HuggingFaceEmbeddings(\n",
    "            model_name=embedding_model_name,  # also works with model_path\n",
    "            model_kwargs=model_kwargs,\n",
    "            encode_kwargs=encode_kwargs)\n",
    "    return embedding_model"
863
   ]
864
  },
865
  {
866
   "cell_type": "code",
867
   "execution_count": null,
868
   "id": "83b6a5a3-cd2d-4987-838a-be13e9553080",
869
   "metadata": {
870
    "tags": []
871
   },
872
   "outputs": [],
873
   "source": [
    "class EmbedChunks:\n",
    "    def __init__(self, model_name):\n",
    "        # Load the model once per actor and place it on the GPU\n",
    "        self.embedding_model = get_embedding_model(\n",
    "            embedding_model_name=model_name,\n",
    "            model_kwargs={\"device\": \"cuda\"},\n",
    "            encode_kwargs={\"device\": \"cuda\", \"batch_size\": 100})\n",
    "    def __call__(self, batch):\n",
    "        # Embed every chunk in the batch, keeping text and source alongside\n",
    "        embeddings = self.embedding_model.embed_documents(batch[\"text\"])\n",
    "        return {\"text\": batch[\"text\"], \"source\": batch[\"source\"], \"embeddings\": embeddings}"
883
   ]
884
  },
885
  {
886
   "cell_type": "markdown",
887
   "id": "1e39c99a-5f33-4052-ba0e-cbc69d104ec1",
888
   "metadata": {},
889
   "source": [
890
    "Here we're able to embed our chunks at scale by using [map_batches](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html). All we had to do was define the `batch_size` and the compute resources."
891
   ]
892
  },
893
  {
894
   "cell_type": "code",
895
   "execution_count": null,
896
   "id": "9715a01e-dc67-4342-a0cb-30e770852097",
897
   "metadata": {
898
    "tags": []
899
   },
900
   "outputs": [],
901
   "source": [
902
    "# Embed chunks\n",
903
    "embedding_model_name = \"thenlper/gte-base\"\n",
904
    "embedded_chunks = chunks_ds.map_batches(\n",
905
    "    EmbedChunks,\n",
906
    "    fn_constructor_kwargs={\"model_name\": embedding_model_name},\n",
907
    "    batch_size=100, \n",
908
    "    num_gpus=1,\n",
909
    "    compute=ActorPoolStrategy(size=1))"
910
   ]
911
  },
912
  {
913
   "cell_type": "code",
914
   "execution_count": null,
915
   "id": "67dffa1f-19a3-4411-af3f-b161a47ee164",
916
   "metadata": {
917
    "tags": []
918
   },
919
   "outputs": [
920
    {
921
     "name": "stderr",
922
     "output_type": "stream",
923
     "text": [
924
      "2023-12-27 18:56:56,384\tINFO streaming_executor.py:93 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[FlatMap(extract_sections)->FlatMap(partial)] -> ActorPoolMapOperator[MapBatches(EmbedChunks->Limit[1])] -> LimitOperator[limit=1]\n",
925
      "2023-12-27 18:56:56,385\tINFO streaming_executor.py:94 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
926
      "2023-12-27 18:56:56,385\tINFO streaming_executor.py:96 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n",
927
      "2023-12-27 18:56:56,401\tINFO actor_pool_map_operator.py:106 -- MapBatches(EmbedChunks->Limit[1]): Waiting for 1 pool actors to start...\n"
928
     ]
929
    },
930
    {
931
     "data": {
932
      "application/vnd.jupyter.widget-view+json": {
933
       "model_id": "",
934
       "version_major": 2,
935
       "version_minor": 0
936
      },
937
      "text/plain": [
938
       "Running 0:   0%|          | 0/1 [00:00<?, ?it/s]"
939
      ]
940
     },
941
     "metadata": {},
942
     "output_type": "display_data"
943
    },
944
    {
945
     "name": "stdout",
946
     "output_type": "stream",
947
     "text": [
948
      "embedding size: 1024\n",
949
      "Reference#\n",
950
      "Monitor and debug your Ray applications and clusters using the API and CLI documented in these references.\n",
951
      "The guides include:\n",
952
      "State API\n",
953
      "State CLI\n",
954
      "System Metrics\n"
955
     ]
956
    }
957
   ],
958
   "source": [
    "# Sample\n",
    "sample = embedded_chunks.take(1)\n",
    "print(\"embedding size:\", len(sample[0][\"embeddings\"]))\n",
    "print(sample[0][\"text\"])"
963
   ]
964
  },
965
  {
966
   "cell_type": "markdown",
967
   "id": "09187588-e1dc-44c5-b88b-cc8ebe3f9c48",
968
   "metadata": {},
969
   "source": [
970
    "## Index data"
971
   ]
972
  },
973
  {
974
   "cell_type": "markdown",
975
   "id": "eb9f53f1-2720-4a47-8b2f-1b206c38c425",
976
   "metadata": {},
977
   "source": [
    "Now that we have our embedded chunks, we need to index (store) them somewhere so that we can retrieve them quickly for inference. While there are many popular vector database options, we're going to use [Postgres with pgvector](https://github.com/pgvector/pgvector) for its simplicity and performance. We'll create a table (`document`) and write the (`text`, `source`, `embedding`) triplets for each embedded chunk we have.\n",
979
    "\n",
980
    "<img width=\"700\" src=\"https://images.ctfassets.net/xjan103pcp94/3z1ryYkOtUjj6N1IuavJPf/ae60dc4a10c94e2cc928c38701befb51/image2.png\">"
981
   ]
982
  },
983
  {
984
   "cell_type": "code",
985
   "execution_count": null,
986
   "id": "bb6bf17c-da04-4bd8-af2f-dc79e4e27913",
987
   "metadata": {
988
    "tags": []
989
   },
990
   "outputs": [],
991
   "source": [
992
    "import psycopg\n",
993
    "from pgvector.psycopg import register_vector\n",
994
    "os.environ[\"MIGRATION_FP\"] = f\"../migrations/vector-{EMBEDDING_DIMENSIONS[embedding_model_name]}.sql\"\n",
995
    "os.environ[\"SQL_DUMP_FP\"] = f\"{EFS_DIR}/sql_dumps/{embedding_model_name.split('/')[-1]}_{chunk_size}_{chunk_overlap}.sql\""
996
   ]
997
  },
998
  {
999
   "cell_type": "code",
1000
   "execution_count": null,
1001
   "id": "5b64894f-4de1-42b8-b56f-37800b5ce591",
1002
   "metadata": {
1003
    "tags": []
1004
   },
1005
   "outputs": [
1006
    {
1007
     "name": "stderr",
1008
     "output_type": "stream",
1009
     "text": [
1010
      "NOTICE:  table \"document\" does not exist, skipping\n"
1011
     ]
1012
    },
1013
    {
1014
     "name": "stdout",
1015
     "output_type": "stream",
1016
     "text": [
1017
      "DROP TABLE\n",
1018
      "../migrations/vector-768.sql\n",
1019
      "CREATE TABLE\n",
1020
      "/efs/shared_storage/goku/sql_dumps/gte-base_300_50.sql\n"
1021
     ]
1022
    }
1023
   ],
1024
   "source": [
1025
    "%%bash\n",
1026
    "# Set up\n",
1027
    "psql \"$DB_CONNECTION_STRING\" -c \"DROP TABLE IF EXISTS document;\"\n",
1028
    "echo $MIGRATION_FP\n",
1029
    "sudo -u postgres psql -f $MIGRATION_FP\n",
1030
    "echo $SQL_DUMP_FP"
1031
   ]
1032
  },
1033
  {
1034
   "cell_type": "markdown",
1035
   "id": "633f3c88-0c88-48b5-a6f9-4b08e3a0dc43",
1036
   "metadata": {},
1037
   "source": [
1038
    "**Note**: Run `bash setup-pgvector.sh` first!"
1039
   ]
1040
  },
1041
  {
1042
   "cell_type": "code",
1043
   "execution_count": null,
1044
   "id": "806b47aa-f3c1-44d3-b041-a3a4f0867653",
1045
   "metadata": {
1046
    "tags": []
1047
   },
1048
   "outputs": [
1049
    {
1050
     "name": "stdout",
1051
     "output_type": "stream",
1052
     "text": [
1053
      "DROP TABLE\n",
1054
      "CREATE TABLE\n",
1055
      "SET\n",
1056
      "SET\n",
1057
      "SET\n",
1058
      "SET\n",
1059
      "SET\n",
1060
      " set_config \n",
1061
      "------------\n",
1062
      " \n",
1063
      "(1 row)\n",
1064
      "\n",
1065
      "SET\n",
1066
      "SET\n",
1067
      "SET\n",
1068
      "SET\n",
1069
      "ALTER TABLE\n",
1070
      "ALTER TABLE\n"
1071
     ]
1072
    },
1073
    {
1074
     "name": "stderr",
1075
     "output_type": "stream",
1076
     "text": [
1077
      "psql:/efs/shared_storage/goku/sql_dumps/gte-base_300_50.sql:20: ERROR:  relation \"public.data_document\" does not exist\n",
1078
      "psql:/efs/shared_storage/goku/sql_dumps/gte-base_300_50.sql:22: ERROR:  relation \"public.data_document\" does not exist\n"
1079
     ]
1080
    },
1081
    {
1082
     "name": "stdout",
1083
     "output_type": "stream",
1084
     "text": [
1085
      "DROP SEQUENCE\n",
1086
      "DROP TABLE\n"
1087
     ]
1088
    },
1089
    {
1090
     "name": "stderr",
1091
     "output_type": "stream",
1092
     "text": [
1093
      "psql:/efs/shared_storage/goku/sql_dumps/gte-base_300_50.sql:25: ERROR:  sequence \"data_document_id_seq\" does not exist\n",
1094
      "psql:/efs/shared_storage/goku/sql_dumps/gte-base_300_50.sql:26: ERROR:  table \"data_document\" does not exist\n"
1095
     ]
1096
    },
1097
    {
1098
     "name": "stdout",
1099
     "output_type": "stream",
1100
     "text": [
1101
      "DROP EXTENSION\n",
1102
      "CREATE EXTENSION\n",
1103
      "COMMENT\n",
1104
      "SET\n",
1105
      "SET\n",
1106
      "CREATE TABLE\n",
1107
      "ALTER TABLE\n",
1108
      "CREATE SEQUENCE\n",
1109
      "ALTER SEQUENCE\n",
1110
      "ALTER SEQUENCE\n",
1111
      "CREATE TABLE\n",
1112
      "ALTER TABLE\n",
1113
      "CREATE SEQUENCE\n",
1114
      "ALTER SEQUENCE\n",
1115
      "ALTER SEQUENCE\n",
1116
      "ALTER TABLE\n",
1117
      "ALTER TABLE\n",
1118
      "COPY 40433\n",
1119
      "COPY 32276\n",
1120
      " setval \n",
1121
      "--------\n",
1122
      "  40433\n",
1123
      "(1 row)\n",
1124
      "\n",
1125
      " setval \n",
1126
      "--------\n",
1127
      "  32276\n",
1128
      "(1 row)\n",
1129
      "\n",
1130
      "ALTER TABLE\n",
1131
      "ALTER TABLE\n",
1132
      " count \n",
1133
      "-------\n",
1134
      " 32276\n",
1135
      "(1 row)\n",
1136
      "\n"
1137
     ]
1138
    }
1139
   ],
1140
   "source": [
1141
    "%%bash\n",
1142
    "# Drop table and load index\n",
1143
    "psql \"$DB_CONNECTION_STRING\" -c \"DROP TABLE IF EXISTS document;\"  # drop\n",
1144
    "sudo -u postgres psql -f $MIGRATION_FP  # create\n",
1145
    "psql \"$DB_CONNECTION_STRING\" -f $SQL_DUMP_FP  # load\n",
1146
    "psql \"$DB_CONNECTION_STRING\" -c \"SELECT count(*) FROM document;\"  # num rows"
1147
   ]
1148
  },
1149
  {
1150
   "cell_type": "markdown",
1151
   "id": "3fe36601-1269-482b-9bac-b52b32fb338e",
1152
   "metadata": {},
1153
   "source": [
1154
    "If we don't have an index saved already, we can index the data and save it:"
1155
   ]
1156
  },
1157
  {
1158
   "cell_type": "code",
1159
   "execution_count": null,
1160
   "id": "21480e47-4a17-49d8-b4e7-c301e8040a69",
1161
   "metadata": {
1162
    "tags": []
1163
   },
1164
   "outputs": [],
1165
   "source": [
    "class StoreResults:\n",
    "    def __call__(self, batch):\n",
    "        # The connection context manager commits on success and closes the connection\n",
    "        with psycopg.connect(os.environ[\"DB_CONNECTION_STRING\"]) as conn:\n",
    "            register_vector(conn)\n",
    "            with conn.cursor() as cur:\n",
    "                for text, source, embedding in zip(batch[\"text\"], batch[\"source\"], batch[\"embeddings\"]):\n",
    "                    cur.execute(\"INSERT INTO document (text, source, embedding) VALUES (%s, %s, %s)\", (text, source, embedding,),)\n",
    "        # Nothing to pass downstream: the rows now live in Postgres\n",
    "        return {}"
1174
   ]
1175
  },
1176
  {
1177
   "cell_type": "markdown",
1178
   "id": "4d4f119d-535d-4548-bb83-177e3bcefeed",
1179
   "metadata": {},
1180
   "source": [
1181
    "And once again, we can use Ray Data’s [map_batches](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html) to perform this indexing in parallel:"
1182
   ]
1183
  },
1184
  {
1185
   "cell_type": "code",
1186
   "execution_count": null,
1187
   "id": "9c4582d9-40ba-4a94-81ac-259b3851f837",
1188
   "metadata": {
1189
    "tags": []
1190
   },
1191
   "outputs": [
1192
    {
1193
     "name": "stderr",
1194
     "output_type": "stream",
1195
     "text": [
1196
      "2023-12-27 18:57:35,648\tINFO streaming_executor.py:93 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[FlatMap(extract_sections)->FlatMap(partial)] -> ActorPoolMapOperator[MapBatches(EmbedChunks)] -> ActorPoolMapOperator[MapBatches(StoreResults)]\n",
1197
      "2023-12-27 18:57:35,648\tINFO streaming_executor.py:94 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
1198
      "2023-12-27 18:57:35,649\tINFO streaming_executor.py:96 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n",
1199
      "2023-12-27 18:57:35,664\tINFO actor_pool_map_operator.py:106 -- MapBatches(EmbedChunks): Waiting for 1 pool actors to start...\n",
1200
      "2023-12-27 18:57:42,073\tINFO actor_pool_map_operator.py:106 -- MapBatches(StoreResults): Waiting for 6 pool actors to start...\n"
1201
     ]
1202
    },
1203
    {
1204
     "data": {
1205
      "application/vnd.jupyter.widget-view+json": {
1206
       "model_id": "6c8a60a6a0684709b2fe3ac7b81ebf9a",
1207
       "version_major": 2,
1208
       "version_minor": 0
1209
      },
1210
      "text/plain": [
1211
       "Running 0:   0%|          | 0/200 [00:00<?, ?it/s]"
1212
      ]
1213
     },
1214
     "metadata": {},
1215
     "output_type": "display_data"
1216
    }
1217
   ],
1218
   "source": [
1219
    "# Index data\n",
1220
    "embedded_chunks.map_batches(\n",
1221
    "    StoreResults,\n",
1222
    "    batch_size=128,\n",
1223
    "    num_cpus=1,\n",
1224
    "    compute=ActorPoolStrategy(size=6),\n",
1225
    ").count()"
1226
   ]
1227
  },
1228
  {
1229
   "cell_type": "code",
1230
   "execution_count": null,
1231
   "id": "40925cd6-41e8-4651-9692-aeb399b68af6",
1232
   "metadata": {
1233
    "tags": []
1234
   },
1235
   "outputs": [],
1236
   "source": [
    "%%bash\n",
    "# Save index\n",
    "rm -rf \"$SQL_DUMP_FP\"\n",
    "mkdir -p \"$(dirname \"$SQL_DUMP_FP\")\" && touch \"$SQL_DUMP_FP\"\n",
    "sudo -u postgres pg_dump -c > \"$SQL_DUMP_FP\"  # save"
1242
   ]
1243
  },
1244
  {
1245
   "cell_type": "markdown",
1246
   "id": "318798c3-d119-4eb5-ad81-b2834516151a",
1247
   "metadata": {},
1248
   "source": [
1249
    "# Retrieval"
1250
   ]
1251
  },
1252
  {
1253
   "cell_type": "markdown",
1254
   "id": "974cc146-c337-4478-a119-5daadedd340c",
1255
   "metadata": {},
1256
   "source": [
    "With our embedded chunks indexed in our vector database, we're ready to perform retrieval for a given query. We'll start by embedding the incoming query with the same embedding model we used to embed our text chunks.\n",
1258
    "\n",
1259
    "<img width=\"1000\" src=\"https://images.ctfassets.net/xjan103pcp94/1hKBrFU2lyR5LLebFyq2ZL/8845c36ff98eb47005338de6ab6dbf50/image14.png\">"
1260
   ]
1261
  },
1262
  {
1263
   "cell_type": "code",
1264
   "execution_count": null,
1265
   "id": "480d4c49-5870-471e-a617-86f7d3fa13d0",
1266
   "metadata": {
1267
    "tags": []
1268
   },
1269
   "outputs": [],
1270
   "source": [
1271
    "import json\n",
1272
    "import numpy as np"
1273
   ]
1274
  },
1275
  {
1276
   "cell_type": "code",
1277
   "execution_count": null,
1278
   "id": "39c3f410-89a2-4992-8cd1-63aca5bf9936",
1279
   "metadata": {
1280
    "tags": []
1281
   },
1282
   "outputs": [
1283
    {
1284
     "data": {
1285
      "application/vnd.jupyter.widget-view+json": {
1286
       "model_id": "e4980860fc4e4b33b6a4d2023da47571",
1287
       "version_major": 2,
1288
       "version_minor": 0
1289
      },
1290
      "text/plain": [
1291
       "Downloading .gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]"
1292
      ]
1293
     },
1294
     "metadata": {},
1295
     "output_type": "display_data"
1296
    },
1297
    {
1298
     "data": {
1299
      "application/vnd.jupyter.widget-view+json": {
1300
       "model_id": "6f6efbd4d4f44897b9d0afa4221e98ba",
1301
       "version_major": 2,
1302
       "version_minor": 0
1303
      },
1304
      "text/plain": [
1305
       "Downloading 1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
1306
      ]
1307
     },
1308
     "metadata": {},
1309
     "output_type": "display_data"
1310
    },
1311
    {
1312
     "data": {
1313
      "application/vnd.jupyter.widget-view+json": {
1314
       "model_id": "bb527f71a7074b8a959871bf6af3234c",
1315
       "version_major": 2,
1316
       "version_minor": 0
1317
      },
1318
      "text/plain": [
1319
       "Downloading README.md:   0%|          | 0.00/68.1k [00:00<?, ?B/s]"
1320
      ]
1321
     },
1322
     "metadata": {},
1323
     "output_type": "display_data"
1324
    },
1325
    {
1326
     "data": {
1327
      "application/vnd.jupyter.widget-view+json": {
1328
       "model_id": "ba6ab30492bc418db4a8c0c0e94b1807",
1329
       "version_major": 2,
1330
       "version_minor": 0
1331
      },
1332
      "text/plain": [
1333
       "Downloading config.json:   0%|          | 0.00/618 [00:00<?, ?B/s]"
1334
      ]
1335
     },
1336
     "metadata": {},
1337
     "output_type": "display_data"
1338
    },
1339
    {
1340
     "data": {
1341
      "application/vnd.jupyter.widget-view+json": {
1342
       "model_id": "1c69f83aa7244e7a82c373239579bcd8",
1343
       "version_major": 2,
1344
       "version_minor": 0
1345
      },
1346
      "text/plain": [
1347
       "Downloading model.safetensors:   0%|          | 0.00/219M [00:00<?, ?B/s]"
1348
      ]
1349
     },
1350
     "metadata": {},
1351
     "output_type": "display_data"
1352
    },
1353
    {
1354
     "data": {
1355
      "application/vnd.jupyter.widget-view+json": {
1356
       "model_id": "587dbb9e4596404eb4ea754d14ccf409",
1357
       "version_major": 2,
1358
       "version_minor": 0
1359
      },
1360
      "text/plain": [
1361
       "Downloading onnx/config.json:   0%|          | 0.00/630 [00:00<?, ?B/s]"
1362
      ]
1363
     },
1364
     "metadata": {},
1365
     "output_type": "display_data"
1366
    },
1367
    {
1368
     "data": {
1369
      "application/vnd.jupyter.widget-view+json": {
1370
       "model_id": "54c33bf01b874262bf676416647b8790",
1371
       "version_major": 2,
1372
       "version_minor": 0
1373
      },
1374
      "text/plain": [
1375
       "Downloading model.onnx:   0%|          | 0.00/436M [00:00<?, ?B/s]"
1376
      ]
1377
     },
1378
     "metadata": {},
1379
     "output_type": "display_data"
1380
    },
1381
    {
1382
     "data": {
1383
      "application/vnd.jupyter.widget-view+json": {
1384
       "model_id": "79ffe43a2e0e4fa68acae617d41c2ad6",
1385
       "version_major": 2,
1386
       "version_minor": 0
1387
      },
1388
      "text/plain": [
1389
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]"
1390
      ]
1391
     },
1392
     "metadata": {},
1393
     "output_type": "display_data"
1394
    },
1395
    {
1396
     "data": {
1397
      "application/vnd.jupyter.widget-view+json": {
1398
       "model_id": "0bc1532dc584492bbfd21ec03cdfa9b9",
1399
       "version_major": 2,
1400
       "version_minor": 0
1401
      },
1402
      "text/plain": [
1403
       "Downloading onnx/tokenizer.json:   0%|          | 0.00/712k [00:00<?, ?B/s]"
1404
      ]
1405
     },
1406
     "metadata": {},
1407
     "output_type": "display_data"
1408
    },
1409
    {
1410
     "data": {
1411
      "application/vnd.jupyter.widget-view+json": {
1412
       "model_id": "0ee54809d1bc4740a573b52c80b53337",
1413
       "version_major": 2,
1414
       "version_minor": 0
1415
      },
1416
      "text/plain": [
1417
       "Downloading (…)okenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]"
1418
      ]
1419
     },
1420
     "metadata": {},
1421
     "output_type": "display_data"
1422
    },
1423
    {
1424
     "data": {
1425
      "application/vnd.jupyter.widget-view+json": {
1426
       "model_id": "f0fe9fcc9c114b5eb0e93a5425057e7d",
1427
       "version_major": 2,
1428
       "version_minor": 0
1429
      },
1430
      "text/plain": [
1431
       "Downloading onnx/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
1432
      ]
1433
     },
1434
     "metadata": {},
1435
     "output_type": "display_data"
1436
    },
1437
    {
1438
     "data": {
1439
      "application/vnd.jupyter.widget-view+json": {
1440
       "model_id": "f666e1d6c9434e7e85dcb944cc484412",
1441
       "version_major": 2,
1442
       "version_minor": 0
1443
      },
1444
      "text/plain": [
1445
       "Downloading pytorch_model.bin:   0%|          | 0.00/219M [00:00<?, ?B/s]"
1446
      ]
1447
     },
1448
     "metadata": {},
1449
     "output_type": "display_data"
1450
    },
1451
    {
1536
     "data": {
1537
      "text/plain": [
1538
       "768"
1539
      ]
1540
     },
1541
     "execution_count": null,
1542
     "metadata": {},
1543
     "output_type": "execute_result"
1544
    }
1545
   ],
1546
   "source": [
1547
    "# Embed query\n",
1548
    "embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)\n",
1549
    "query = \"What is the default batch size for map_batches?\"\n",
1550
    "embedding = np.array(embedding_model.embed_query(query))\n",
1551
    "len(embedding)"
1552
   ]
1553
  },
1554
  {
1555
   "cell_type": "markdown",
1556
   "id": "cd3d604e-437c-4917-9503-cdc6a74df982",
1557
   "metadata": {},
1558
   "source": [
1559
    "Then, we'll retrieve the most relevant chunks by finding the embedded chunks that are closest to our embedded query. We use cosine distance (`<=>`), but there are [many options](https://github.com/pgvector/pgvector#vector-operators) to choose from. Once we retrieve the top `num_chunks`, we can collect the text for each chunk and use it as context to generate a response."
1560
   ]
1561
  },
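  {
   "cell_type": "markdown",
   "id": "5b0c1e2a-7d34-4f61-9a8e-2c1d3e4f5a60",
   "metadata": {},
   "source": [
    "To make the ordering concrete, here is a minimal NumPy sketch of cosine distance, which is what pgvector's `<=>` operator computes (1 minus cosine similarity), so smaller values mean closer vectors. The function name `cosine_distance` and the toy vectors are illustrative assumptions and are not part of the application code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e4f7a13-2b6c-4d58-8f01-6a7b8c9d0e12",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cosine_distance(a, b):\n",
    "    \"\"\"Cosine distance = 1 - cosine similarity (smaller means closer).\"\"\"\n",
    "    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)\n",
    "    return 1.0 - float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
    "\n",
    "# Toy 2-D vectors (made up for illustration, not real embeddings)\n",
    "toy_query = [1.0, 0.0]\n",
    "toy_chunks = {\"same direction\": [2.0, 0.0], \"orthogonal\": [0.0, 3.0]}\n",
    "# Sort ascending by distance, mirroring ORDER BY embedding <=> query in SQL\n",
    "ranked = sorted(toy_chunks, key=lambda name: cosine_distance(toy_query, toy_chunks[name]))\n",
    "print(ranked)"
   ]
  },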
1562
  {
1563
   "cell_type": "code",
1564
   "execution_count": null,
1565
   "id": "073a796f-e0e7-46d3-8151-01f47befec4c",
1566
   "metadata": {
1567
    "tags": []
1568
   },
1569
   "outputs": [],
1570
   "source": [
1571
    "# Get context\n",
1572
    "num_chunks = 5\n",
1573
    "with psycopg.connect(os.environ[\"DB_CONNECTION_STRING\"]) as conn:\n",
1574
    "    register_vector(conn)\n",
1575
    "    with conn.cursor() as cur:\n",
1576
    "        # cur.execute(\"SELECT * FROM document ORDER BY embedding <=> %s LIMIT %s\", (embedding, num_chunks))\n",
1577
    "        cur.execute(\"SELECT *, (embedding <=> %s) AS similarity_score FROM document ORDER BY similarity_score LIMIT %s\", (embedding, num_chunks))\n",
1578
    "        rows = cur.fetchall()\n",
1579
    "        ids = [row[0] for row in rows]\n",
1580
    "        context = [{\"text\": row[1]} for row in rows]\n",
1581
    "        sources = [row[2] for row in rows]\n",
1582
    "        scores = [row[4] for row in rows]  # cosine distance: lower means more similar"
1583
   ]
1584
  },
1585
  {
1586
   "cell_type": "code",
1587
   "execution_count": null,
1588
   "id": "4a3cfb58-4c25-4a14-9634-f6a24fac970b",
1589
   "metadata": {
1590
    "tags": []
1591
   },
1592
   "outputs": [
1593
    {
1594
     "name": "stdout",
1595
     "output_type": "stream",
1596
     "text": [
1597
      "14887\n",
1598
      "0.06207154583656549\n",
1599
      "https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.map_batches.html#ray-data-dataset-map-batches\n",
1600
      "entire blocks as batches (blocks may contain different numbers of rows).\n",
1601
      "The actual size of the batch provided to fn may be smaller than\n",
1602
      "batch_size if batch_size doesn’t evenly divide the block(s) sent\n",
1603
      "to a given map task. Default batch_size is 4096 with “default”.\n",
1604
      "\n",
1605
      "15009\n",
1606
      "0.06654224236622186\n",
1607
      "https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-size\n",
1608
      "batch_size.\n",
1609
      "Note\n",
1610
      "The default batch size depends on your resource type. If you’re using CPUs,\n",
1611
      "the default batch size is 4096. If you’re using GPUs, you must specify an explicit\n",
1612
      "batch size.\n",
1613
      "\n",
1614
      "15051\n",
1615
      "0.07562640027010847\n",
1616
      "https://docs.ray.io/en/master/data/batch_inference.html#configuring-batch-size\n",
1617
      "# Specify that each input batch should be of size 2.\n",
1618
      "ds.map_batches(assert_batch, batch_size=2)\n",
1619
      "Caution\n",
1620
      "The default batch_size of 4096 may be too large for datasets with large rows\n",
1621
      "(for example, tables with many columns or a collection of large images).\n",
1622
      "\n",
1623
      "15035\n",
1624
      "0.08958362418550658\n",
1625
      "https://docs.ray.io/en/master/data/batch_inference.html#configuring-batch-size\n",
1626
      "Configuring Batch Size#\n",
1627
      "Configure the size of the input batch that’s passed to __call__ by setting the batch_size argument for ds.map_batches()\n",
1628
      "\n",
1629
      "2283\n",
1630
      "0.09523273435425617\n",
1631
      "https://docs.ray.io/en/master/tune/getting-started.html#setting-up-a-tuner-for-a-training-run-with-tune\n",
1632
      "batch_size=64,\n",
1633
      "        shuffle=True)\n",
1634
      "\n"
1635
     ]
1636
    }
1637
   ],
1638
   "source": [
1639
    "for i, item in enumerate(context):\n",
1640
    "    print(ids[i])\n",
    "    print(scores[i])\n",
    "    print(sources[i])\n",
    "    print(item[\"text\"])\n",
    "    print()"
1645
   ]
1646
  },
1647
  {
1648
   "cell_type": "markdown",
1649
   "id": "6fc7d446-3936-4c4a-aa01-f6dc46991416",
1650
   "metadata": {},
1651
   "source": [
1652
    "Let's wrap this into a convenient function:"
1653
   ]
1654
  },
1655
  {
1656
   "cell_type": "code",
1657
   "execution_count": null,
1658
   "id": "d26a88bd-fbf0-4023-bfc7-e0501073caeb",
1659
   "metadata": {
1660
    "tags": []
1661
   },
1662
   "outputs": [],
1663
   "source": [
1664
    "def semantic_search(query, embedding_model, k):\n",
1665
    "    embedding = np.array(embedding_model.embed_query(query))\n",
1666
    "    with psycopg.connect(os.environ[\"DB_CONNECTION_STRING\"]) as conn:\n",
1667
    "        register_vector(conn)\n",
1668
    "        with conn.cursor() as cur:\n",
1669
    "            cur.execute(\"SELECT * FROM document ORDER BY embedding <=> %s LIMIT %s\", (embedding, k))\n",
1670
    "            rows = cur.fetchall()\n",
1671
    "            semantic_context = [{\"id\": row[0], \"text\": row[1], \"source\": row[2]} for row in rows]\n",
1672
    "    return semantic_context"
1673
   ]
1674
  },
1675
  {
1676
   "cell_type": "markdown",
1677
   "id": "f5738d23-91e3-4016-826e-716872f76b62",
1678
   "metadata": {
1679
    "tags": []
1680
   },
1681
   "source": [
1682
    "# Generation"
1683
   ]
1684
  },
1685
  {
1686
   "cell_type": "markdown",
1687
   "id": "0bec2b7c-35be-40d1-9e12-29d7d6178d4d",
1688
   "metadata": {},
1689
   "source": [
1690
    "We can now use the context to generate a response from our LLM. Without the relevant context that we retrieved, the LLM may not be able to answer our question accurately. And as our data grows, we can just as easily embed and index any new data and retrieve it to answer questions.\n",
1691
    "\n",
1692
    "<img width=\"500\" src=\"https://images.ctfassets.net/xjan103pcp94/38I8en8Tyf0cM4LUhjygoq/739d456c80841b4c28fe80f73ea5856b/image16.png\">"
1693
   ]
1694
  },
1695
  {
1696
   "cell_type": "code",
1697
   "execution_count": null,
1698
   "id": "b55cc1d7-e110-4d9d-abc4-36576db25f92",
1699
   "metadata": {
1700
    "tags": []
1701
   },
1702
   "outputs": [],
1703
   "source": [
1704
    "import openai\n",
1705
    "import time"
1706
   ]
1707
  },
1708
  {
1709
   "cell_type": "code",
1710
   "execution_count": null,
1711
   "id": "6510f9dc-5870-4c9a-8ea0-7661d1914f8a",
1712
   "metadata": {
1713
    "tags": []
1714
   },
1715
   "outputs": [
1716
1772
   ],
1773
   "source": [
1774
    "from rag.generate import prepare_response\n",
1775
    "from rag.utils import get_client"
1776
   ]
1777
  },
1778
  {
1779
   "cell_type": "code",
1780
   "execution_count": null,
1781
   "id": "697e65e8-4d69-4870-9f09-0e128008b94e",
1782
   "metadata": {
1783
    "tags": []
1784
   },
1785
   "outputs": [],
1786
   "source": [
1787
    "def generate_response(\n",
1788
    "    llm, temperature=0.0, stream=True,\n",
1789
    "    system_content=\"\", assistant_content=\"\", user_content=\"\", \n",
1790
    "    max_retries=1, retry_interval=60):\n",
1791
    "    \"\"\"Generate response from an LLM.\"\"\"\n",
1792
    "    retry_count = 0\n",
1793
    "    client = get_client(llm=llm)\n",
1794
    "    messages = [{\"role\": role, \"content\": content} for role, content in [\n",
1795
    "        (\"system\", system_content), \n",
1796
    "        (\"assistant\", assistant_content), \n",
1797
    "        (\"user\", user_content)] if content]\n",
1798
    "    while retry_count <= max_retries:\n",
1799
    "        try:\n",
1800
    "            chat_completion = client.chat.completions.create(\n",
1801
    "                model=llm,\n",
1802
    "                temperature=temperature,\n",
1803
    "                stream=stream,\n",
1804
    "                messages=messages,\n",
1805
    "            )\n",
1806
    "            return prepare_response(chat_completion, stream=stream)\n",
1807
    "\n",
1808
    "        except Exception as e:\n",
1809
    "            print(f\"Exception: {e}\")\n",
1810
    "            time.sleep(retry_interval)  # default is per-minute rate limits\n",
1811
    "            retry_count += 1\n",
1812
    "    return \"\""
1813
   ]
1814
  },
1815
  {
1816
   "cell_type": "markdown",
1817
   "id": "8dfe6bd9-bad1-4b21-9aea-fc2e127ba3fc",
1818
   "metadata": {},
1819
   "source": [
1820
    "**Note**: We’re using a temperature of 0.0 to enable reproducible experiments but you should adjust this based on your use case. For use cases that need to always be factually grounded, we recommend very low temperature values while more creative tasks can benefit from higher temperatures."
1821
   ]
1822
  },
1823
  {
1824
   "cell_type": "code",
1825
   "execution_count": null,
1826
   "id": "7864459b-8bb3-469b-9583-ae6fa527b124",
1827
   "metadata": {
1828
    "tags": []
1829
   },
1830
   "outputs": [
1831
    {
1832
     "name": "stdout",
1833
     "output_type": "stream",
1834
     "text": [
1835
      "['entire blocks as batches (blocks may contain different numbers of rows).\\nThe actual size of the batch provided to fn may be smaller than\\nbatch_size if batch_size doesn’t evenly divide the block(s) sent\\nto a given map task. Default batch_size is 4096 with “default”.', 'batch_size.\\nNote\\nThe default batch size depends on your resource type. If you’re using CPUs,\\nthe default batch size is 4096. If you’re using GPUs, you must specify an explicit\\nbatch size.', '# Specify that each input batch should be of size 2.\\nds.map_batches(assert_batch, batch_size=2)\\nCaution\\nThe default batch_size of 4096 may be too large for datasets with large rows\\n(for example, tables with many columns or a collection of large images).', 'Configuring Batch Size#\\nConfigure the size of the input batch that’s passed to __call__ by setting the batch_size argument for ds.map_batches()', 'batch_size=64,\\n        shuffle=True)']\n"
1836
     ]
1837
    }
1838
   ],
1839
   "source": [
1840
    "context_results = semantic_search(query=query, embedding_model=embedding_model, k=5)\n",
1841
    "context = [item[\"text\"] for item in context_results]\n",
1842
    "print(context)"
1843
   ]
1844
  },
1845
  {
1846
   "cell_type": "code",
1847
   "execution_count": null,
1848
   "id": "9de3685b-6839-445f-9baa-68a5e863562a",
1849
   "metadata": {
1850
    "tags": []
1851
   },
1852
   "outputs": [
1853
    {
1854
     "name": "stdout",
1855
     "output_type": "stream",
1856
     "text": [
1857
      "  The default batch size for map_batches is 4096."
1858
     ]
1859
    }
1860
   ],
1861
   "source": [
1862
    "# Generate response\n",
1863
    "query = \"What is the default batch size for map_batches?\"\n",
1864
    "response = generate_response(\n",
1865
    "    llm=\"meta-llama/Llama-2-70b-chat-hf\",\n",
1866
    "    temperature=0.0,\n",
1867
    "    stream=True,\n",
1868
    "    system_content=\"Answer the query using the context provided. Be succinct.\",\n",
1869
    "    user_content=f\"query: {query}, context: {context}\")\n",
1870
    "# Stream response\n",
1871
    "for content in response:\n",
1872
    "    print(content, end='', flush=True)"
1873
   ]
1874
  },
1875
  {
1876
   "cell_type": "markdown",
1877
   "id": "77e0511e-d79c-4d54-a50f-163c11fc8643",
1878
   "metadata": {},
1879
   "source": [
1880
    "## Agent"
1881
   ]
1882
  },
1883
  {
1884
   "cell_type": "markdown",
1885
   "id": "bad51191-19d1-40d5-bbc8-72245fafd154",
1886
   "metadata": {},
1887
   "source": [
1888
    "Let's combine the context retrieval and response generation into a convenient query agent that we can use to easily generate our responses. The agent sets up the embedding model and the LLM, retrieves the context for a query, and passes that context to the LLM for response generation."
1889
   ]
1890
  },
1891
  {
1892
   "cell_type": "code",
1893
   "execution_count": null,
1894
   "id": "dc002554-a525-48e0-8389-e60594d29221",
1895
   "metadata": {
1896
    "tags": []
1897
   },
1898
   "outputs": [],
1899
   "source": [
1900
    "from rag.embed import get_embedding_model\n",
1901
    "from rag.utils import get_num_tokens, trim"
1902
   ]
1903
  },
1904
  {
1905
   "cell_type": "code",
1906
   "execution_count": null,
1907
   "id": "ffcc892a-ceee-487a-aecf-db116f53e89f",
1908
   "metadata": {
1909
    "tags": []
1910
   },
1911
   "outputs": [],
1912
   "source": [
1913
    "class QueryAgent:\n",
1914
    "    def __init__(self, embedding_model_name=\"thenlper/gte-base\",\n",
1915
    "                 llm=\"meta-llama/Llama-2-70b-chat-hf\", temperature=0.0, \n",
1916
    "                 max_context_length=4096, system_content=\"\", assistant_content=\"\"):\n",
1917
    "        \n",
1918
    "        # Embedding model\n",
1919
    "        self.embedding_model = get_embedding_model(\n",
1920
    "            embedding_model_name=embedding_model_name, \n",
1921
    "            model_kwargs={\"device\": \"cuda\"}, \n",
1922
    "            encode_kwargs={\"device\": \"cuda\", \"batch_size\": 100})\n",
1923
    "        \n",
1924
    "        # Context length (restrict input length to 50% of total context length)\n",
1925
    "        max_context_length = int(0.5*max_context_length)\n",
1926
    "        \n",
1927
    "        # LLM\n",
1928
    "        self.llm = llm\n",
1929
    "        self.temperature = temperature\n",
1930
    "        self.context_length = max_context_length - get_num_tokens(system_content + assistant_content)\n",
1931
    "        self.system_content = system_content\n",
1932
    "        self.assistant_content = assistant_content\n",
1933
    "\n",
1934
    "    def __call__(self, query, num_chunks=5, stream=True):\n",
1935
    "        # Get sources and context\n",
1936
    "        context_results = semantic_search(\n",
1937
    "            query=query, \n",
1938
    "            embedding_model=self.embedding_model, \n",
1939
    "            k=num_chunks)\n",
1940
    "            \n",
1941
    "        # Generate response\n",
1942
    "        context = [item[\"text\"] for item in context_results]\n",
1943
    "        sources = [item[\"source\"] for item in context_results]\n",
1944
    "        user_content = f\"query: {query}, context: {context}\"\n",
1945
    "        answer = generate_response(\n",
1946
    "            llm=self.llm,\n",
1947
    "            temperature=self.temperature,\n",
1948
    "            stream=stream,\n",
1949
    "            system_content=self.system_content,\n",
1950
    "            assistant_content=self.assistant_content,\n",
1951
    "            user_content=trim(user_content, self.context_length))\n",
1952
    "\n",
1953
    "        # Result\n",
1954
    "        result = {\n",
1955
    "            \"question\": query,\n",
1956
    "            \"sources\": sources,\n",
1957
    "            \"answer\": answer,\n",
1958
    "            \"llm\": self.llm,\n",
1959
    "        }\n",
1960
    "        return result"
1961
   ]
1962
  },
1963
  {
1964
   "cell_type": "markdown",
1965
   "id": "29fc2167-3f0b-4cac-96a0-67ecb9854f2c",
1966
   "metadata": {},
1967
   "source": [
1968
    "With this, we can use our RAG application in just a few lines:"
1969
   ]
1970
  },
1971
  {
1972
   "cell_type": "code",
1973
   "execution_count": null,
1974
   "id": "e6e0125b-cc86-4811-aded-e77b661e068a",
1975
   "metadata": {
1976
    "tags": []
1977
   },
1978
   "outputs": [],
1979
   "source": [
1980
    "embedding_model_name = \"thenlper/gte-base\"\n",
1981
    "llm = \"meta-llama/Llama-2-7b-chat-hf\""
1982
   ]
1983
  },
1984
  {
1985
   "cell_type": "code",
1986
   "execution_count": null,
1987
   "id": "ad3b1224-922b-40b5-9979-9075c6ef100e",
1988
   "metadata": {
1989
    "tags": []
1990
   },
1991
   "outputs": [
1992
    {
1993
     "name": "stdout",
1994
     "output_type": "stream",
1995
     "text": [
1996
      "{\n",
1997
      "  \"question\": \"What is the default batch size for map_batches?\",\n",
1998
      "  \"sources\": [\n",
1999
      "    \"https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.map_batches.html#ray-data-dataset-map-batches\",\n",
2000
      "    \"https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-size\",\n",
2001
      "    \"https://docs.ray.io/en/master/data/batch_inference.html#configuring-batch-size\",\n",
2002
      "    \"https://docs.ray.io/en/master/data/batch_inference.html#configuring-batch-size\",\n",
2003
      "    \"https://docs.ray.io/en/master/tune/getting-started.html#setting-up-a-tuner-for-a-training-run-with-tune\"\n",
2004
      "  ],\n",
2005
      "  \"answer\": \"  The default batch size for `map_batches` is 4096. However, this can vary depending on the resource type, with a default of 4096 for CPUs and a specific batch size required for GPUs. It's important to note that the default batch size may be too large for datasets with large rows, and configuring the batch size can help improve performance.\",\n",
2006
      "  \"llm\": \"meta-llama/Llama-2-7b-chat-hf\"\n",
2007
      "}\n"
2008
     ]
2009
    }
2010
   ],
2011
   "source": [
2012
    "query = \"What is the default batch size for map_batches?\"\n",
2013
    "system_content = \"Answer the query using the context provided. Be succinct.\"\n",
2014
    "agent = QueryAgent(\n",
2015
    "    embedding_model_name=embedding_model_name,\n",
2016
    "    llm=llm,\n",
2017
    "    max_context_length=MAX_CONTEXT_LENGTHS[llm],\n",
2018
    "    system_content=system_content)\n",
2019
    "result = agent(query=query, stream=False)\n",
2020
    "print(json.dumps(result, indent=2))"
2021
   ]
2022
  },
2023
  {
2024
   "cell_type": "markdown",
2025
   "id": "34feb8ca-5e72-4b12-8ab3-55ca5af05109",
2026
   "metadata": {},
2027
   "source": [
2028
    "# Evaluation"
2029
   ]
2030
  },
2031
  {
2032
   "cell_type": "markdown",
2033
   "id": "6676dbb2-a143-463c-8de2-9a2763e12059",
2034
   "metadata": {},
2035
   "source": [
2036
    "So far, we've chosen typical/arbitrary values for the various parts of our RAG application. But if we were to change something, such as our chunking logic, embedding model, LLM, etc. how can we know that we have a better configuration than before? A generative task like this is very difficult to quantitatively assess and so we need to develop reliable ways to do so.\n",
2037
    "\n",
2038
    "Because we have many moving parts in our application, we need to perform both unit/component and end-to-end evaluation. Component-wise evaluation can involve evaluating our retrieval in isolation (is the best source in our set of retrieved chunks?) and evaluating our LLM's response (given the best source, is the LLM able to produce a quality answer?). And for end-to-end evaluation, we can assess the quality of the entire system (given the data sources, what is the quality of the response?).\n",
    "We'll be asking our evaluator LLM to score the quality of the response from 1 to 5 using the context. However, we could also have it produce scores for other dimensions such as hallucination (is the generated answer using information only from the provided context?), toxicity, etc.\n",
2040
    "\n",
2041
    "**Note**: We could have constrained the score to be binary (0/1), which might be more interpretable (ex. the response was either correct or incorrect). However, we chose a wider score range to develop a deeper, more fine-grained understanding of how LLMs score responses (ex. LLM bias towards responses).\n",
2042
    "\n",
2043
    "\n",
2044
    "\n",
2045
    "<img width=\"1000\" src=\"https://images.ctfassets.net/xjan103pcp94/17UQdsEImsXOOdDlT06bvi/4a9b9e46e157541a1178b6938624176a/llm_evaluations.png\">"
2046
   ]
2047
  },
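  {
   "cell_type": "markdown",
   "id": "c3d2a1b0-8e7f-4a65-b4c3-1f2e3d4c5b6a",
   "metadata": {},
   "source": [
    "As a rough illustration of the retrieval component check described above (is the best source in our set of retrieved chunks?), here is a minimal sketch. The helper name `retrieval_hit_rate`, the dictionary keys and the toy samples are assumptions made for illustration. They are not the evaluation code used in this guide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e8f9a0-1b2c-4d3e-9f4a-5b6c7d8e9f01",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def retrieval_hit_rate(samples):\n",
    "    \"\"\"Fraction of samples whose reference source appears among the retrieved sources.\"\"\"\n",
    "    if not samples:\n",
    "        return 0.0\n",
    "    hits = sum(1 for s in samples if s[\"reference_source\"] in s[\"retrieved_sources\"])\n",
    "    return hits / len(samples)\n",
    "\n",
    "# Toy samples (hypothetical): the first is a hit, the second is a miss\n",
    "toy_samples = [\n",
    "    {\"reference_source\": \"docs/a.html\", \"retrieved_sources\": [\"docs/a.html\", \"docs/b.html\"]},\n",
    "    {\"reference_source\": \"docs/c.html\", \"retrieved_sources\": [\"docs/b.html\", \"docs/d.html\"]},\n",
    "]\n",
    "print(retrieval_hit_rate(toy_samples))"
   ]
  },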
2048
  {
2049
   "cell_type": "code",
2050
   "execution_count": null,
2051
   "id": "7c6f7b9e-4b77-4b0c-8660-ba8047356415",
2052
   "metadata": {
2053
    "tags": []
2054
   },
2055
   "outputs": [],
2056
   "source": [
2057
    "# If running tests / small samples, set NUM_SAMPLES to a small value (ex. <10)\n",
2058
    "EXPERIMENTS_DIR = Path(ROOT_DIR, \"experiments\")\n",
2059
    "NUM_SAMPLES = None  # None = all samples"
2060
   ]
2061
  },
2062
  {
2063
   "cell_type": "markdown",
2064
   "id": "4fa7a068-fc77-4928-bf58-52321616f9d5",
2065
   "metadata": {},
2066
   "source": [
2067
    "## Evaluator"
2068
   ]
2069
  },
2070
  {
2071
   "cell_type": "markdown",
2072
   "id": "0a19d3e8-b7e8-4d69-aa70-118af1fb1708",
2073
   "metadata": {},
2074
   "source": [
2075
    "We're going to start by determining our evaluator. Given a response to a query and relevant context, our evaluator should be a trusted way to score/assess the quality of the response. But before we can determine our evaluator, we need a dataset of questions and the source where the answer comes from. We can use this dataset to ask our different evaluators to provide an answer and then rate their answer (ex. score between 1-5). We can then inspect this dataset to determine if our evaluator is unbiased and has sound reasoning for the scores that are assigned.\n",
2076
    "\n",
2077
    "**Note**: We’re evaluating the ability of our LLM to generate a response given the relevant context. This is a component-level evaluation (`quality_score (LLM)`) because we aren’t using retrieval to fetch the relevant context."
2078
   ]
2079
  },
2080
  {
2081
   "cell_type": "markdown",
2082
   "id": "03acb7e0-7bcb-4e2f-8552-5fe182a278db",
2083
   "metadata": {},
2084
   "source": [
2085
    "We'll start by manually creating our dataset (keep reading if you can’t manually create a dataset). We have a list of user queries and the ideal source to answer each query in [datasets/eval-dataset-v1.jsonl](https://github.com/ray-project/llm-applications/blob/main/datasets/eval-dataset-v1.jsonl). We will use our LLM app above to generate a reference answer for each query/source pair using `gpt-4`."
2086
   ]
2087
  },
2088
  {
2089
   "cell_type": "code",
2090
   "execution_count": null,
2091
   "id": "fa6cab00-d68d-4745-a735-048d72d8e5d6",
2092
   "metadata": {
2093
    "tags": []
2094
   },
2095
   "outputs": [],
2096
   "source": [
2097
    "from bs4 import BeautifulSoup\n",
2098
    "from IPython.display import JSON, clear_output, display\n",
2099
    "from tqdm import tqdm\n",
2100
    "import urllib.parse"
2101
   ]
2102
  },
2103
  {
2104
   "cell_type": "code",
2105
   "execution_count": null,
2106
   "id": "8c58d5dd-bd0f-40eb-99e3-ea305a925127",
2107
   "metadata": {
2108
    "tags": []
2109
   },
2110
   "outputs": [],
2111
   "source": [
2112
    "from rag.evaluate import extract_from_response\n",
2113
    "from rag.data import fetch_text"
2114
   ]
2115
  },
2116
  {
2117
   "cell_type": "code",
2118
   "execution_count": null,
2119
   "id": "ca019cc9-3241-4453-b63d-ef7c44ba584a",
2120
   "metadata": {
2121
    "tags": []
2122
   },
2123
   "outputs": [],
2124
   "source": [
2125
    "# Load dataset\n",
2126
    "with open(Path(ROOT_DIR, \"datasets/eval-dataset-v1.jsonl\"), \"r\") as f:\n",
2127
    "    data = [json.loads(item) for item in list(f)]"
2128
   ]
2129
  },
2130
  {
2131
   "cell_type": "code",
2132
   "execution_count": null,
2133
   "id": "3cd9e509-1fe6-42ad-886f-0e118e8dd35d",
2134
   "metadata": {
2135
    "tags": []
2136
   },
2137
   "outputs": [
2138
    {
2139
     "data": {
2140
      "text/plain": [
2141
       "[{'question': 'I’m struggling a bit with Ray Data type conversions when I do map_batches. Any advice?',\n",
2142
       "  'source': 'https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-format'},\n",
2143
       " {'question': 'How does autoscaling work in a Ray Serve application?',\n",
2144
       "  'source': 'https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#autoscaling'},\n",
2145
       " {'question': 'how do I get the address of a ray node',\n",
2146
       "  'source': 'https://docs.ray.io/en/master/ray-core/miscellaneous.html#node-information'},\n",
2147
       " {'question': 'Does Ray support NCCL?',\n",
2148
       "  'source': 'https://docs.ray.io/en/master/ray-more-libs/ray-collective.html'},\n",
2149
       " {'question': 'Is Ray integrated with DeepSpeed?',\n",
2150
       "  'source': 'https://docs.ray.io/en/master/ray-air/examples/gptj_deepspeed_fine_tuning.html#fine-tuning-the-model-with-ray-air-a-name-train-a'}]"
2151
      ]
2152
     },
2153
     "execution_count": null,
2154
     "metadata": {},
2155
     "output_type": "execute_result"
2156
    }
2157
   ],
2158
   "source": [
2159
    "data[:5]"
2160
   ]
2161
  },
2162
  {
2163
   "cell_type": "code",
2164
   "execution_count": null,
2165
   "id": "eaeb8f67-0fef-4b0e-a362-cceb04f3295b",
2166
   "metadata": {
2167
    "tags": []
2168
   },
2169
   "outputs": [
2170
    {
2171
     "data": {
2172
      "text/plain": [
2173
       "'\\nConfiguring batch format#\\nRay Data represents batches as dicts of NumPy ndarrays or pandas DataFrames. By\\ndefault, Ray Data represents batches as dicts of NumPy ndarrays.\\nTo configure the batch type, specify batch_format in\\nmap_batches(). You can return either format from your function.\\n\\n\\n\\nNumPy\\nfrom typing import Dict\\nimport numpy as np\\nimport ray\\n\\ndef increase_brightness(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:\\n    batch[\"image\"] = np.clip(batch[\"image\"] + 4, 0, 255)\\n    return batch\\n\\nds = (\\n    ray.data.read_images(\"s3://anonymous@ray-example-data/image-datasets/simple\")\\n    .map_batches(increase_brightness, batch_format=\"numpy\")\\n)\\n\\n\\n\\n\\n\\npandas\\nimport pandas as pd\\nimport ray\\n\\ndef drop_nas(batch: pd.DataFrame) -> pd.DataFrame:\\n    return batch.dropna()\\n\\nds = (\\n    ray.data.read_csv(\"s3://anonymous@air-example-data/iris.csv\")\\n    .map_batches(drop_nas, batch_format=\"pandas\")\\n)\\n\\n\\n\\n\\n'"
2174
      ]
2175
     },
2176
     "execution_count": null,
2177
     "metadata": {},
2178
     "output_type": "execute_result"
2179
    }
2180
   ],
2181
   "source": [
2182
    "# Sample\n",
2183
    "uri = \"https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-format\"\n",
2184
    "fetch_text(uri=uri)"
2185
   ]
2186
  },
2187
  {
2188
   "cell_type": "code",
2189
   "execution_count": null,
2190
   "id": "dec3460f-c07a-4a2f-95f7-d1ef85d1a064",
2191
   "metadata": {
2192
    "tags": []
2193
   },
2194
   "outputs": [],
2195
   "source": [
2196
    "# Content for inference\n",
2197
    "system_content = \"\"\"\n",
2198
    "    Answer the query using the context provided. Be succinct.\n",
2199
    "    Then, you must {score} your response between 1 and 5.\n",
2200
    "    You must return your response in a line with only the score.\n",
2201
    "    Do not add any more details.\n",
2202
    "    On a separate line provide your {reasoning} for the score as well.\n",
2203
    "    Return your response following the exact format outlined below.\n",
2204
    "    Do not add or remove anything.\n",
2205
    "    And all of this must be in a valid JSON format.\n",
2206
    "    \n",
2207
    "    {\"answer\": answer,\n",
2208
    "     \"score\": score,\n",
2209
    "     \"reasoning\": reasoning}\n",
2210
    "    \"\"\"\n",
2211
    "assistant_content = \"\""
2212
   ]
2213
  },
2214
  {
2215
   "cell_type": "markdown",
2216
   "id": "f95f9309-1d92-4e43-8de4-f5286d11d4af",
2217
   "metadata": {},
2218
   "source": [
2219
    "We can extract the text from this context and pass it to our LLM to generate a response to the question. We’re also going to ask it to score the quality of its response for the query. To do this, we’ve defined a `QueryAgentWithContext` that inherits from `QueryAgent`, with the change that we’re providing the context and it doesn’t need to retrieve it."
2220
   ]
2221
  },
2222
  {
2223
   "cell_type": "code",
2224
   "execution_count": null,
2225
   "id": "7dc33d2e-51b1-4417-8f8d-3699d5840fc0",
2226
   "metadata": {
2227
    "tags": []
2228
   },
2229
   "outputs": [],
2230
   "source": [
2231
    "class QueryAgentWithContext(QueryAgent):\n",
2232
    "    def __call__(self, query, context):\n",
2233
    "        user_content = f\"query: {query}, context: {context}\"\n",
2234
    "        response = generate_response(\n",
2235
    "            llm=self.llm,\n",
2236
    "            temperature=self.temperature,\n",
2237
    "            stream=False,\n",
2238
    "            system_content=self.system_content,\n",
2239
    "            assistant_content=self.assistant_content,\n",
2240
    "            user_content=trim(user_content, self.context_length))\n",
2241
    "        return response"
2242
   ]
2243
  },
2244
  {
2245
   "cell_type": "code",
2246
   "execution_count": null,
2247
   "id": "63c2db65-36fc-48cb-8e86-18b96554f977",
2248
   "metadata": {
2249
    "tags": []
2250
   },
2251
   "outputs": [],
2252
   "source": [
2253
    "def get_references(data, llm, temperature, system_content, assistant_content, num_samples=None):\n",
2254
    "    # Initialize agent\n",
2255
    "    agent = QueryAgentWithContext(\n",
2256
    "        llm=llm, \n",
2257
    "        temperature=temperature,\n",
2258
    "        system_content=system_content,\n",
2259
    "        assistant_content=assistant_content)\n",
2260
    "    \n",
2261
    "    results = []\n",
2262
    "    for row in tqdm(data[:num_samples]):\n",
2263
    "        # Generate response\n",
2264
    "        query = row[\"question\"]\n",
2265
    "        context = fetch_text(uri=row[\"source\"])\n",
2266
    "        response = agent(query=query, context=context)\n",
2267
    "\n",
2268
    "        # Extract from response\n",
2269
    "        answer, score, reasoning = extract_from_response(response=response)\n",
2270
    "        result = ({\n",
2271
    "                \"question\": query,\n",
2272
    "                \"source\": row[\"source\"],\n",
2273
    "                \"answer\": answer,\n",
2274
    "                \"score\": score,\n",
2275
    "                \"reasoning\": reasoning,\n",
2276
    "            })\n",
2277
    "        results.append(result)\n",
2278
    "        clear_output(wait=True)\n",
2279
    "        display(JSON(json.dumps(result, indent=2)))\n",
2280
    "    return results"
2281
   ]
2282
  },
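`get_references` relies on `extract_from_response`, which is defined elsewhere in the notebook. Below is a minimal sketch of what such a parser might look like, assuming the LLM returns the JSON object requested in the prompt. The name `parse_scored_response` and its empty-string fallback are hypothetical and are not the notebook's actual implementation.

```python
import json


def parse_scored_response(response):
    """Parse a {"answer", "score", "reasoning"} JSON string from an LLM.

    Hypothetical helper: returns empty strings when the response is not
    valid JSON, so downstream code can filter out unscored rows.
    """
    try:
        # Keep only the outermost {...} in case the model adds stray text.
        start, end = response.index("{"), response.rindex("}") + 1
        parsed = json.loads(response[start:end])
    except ValueError:  # missing brace, or json.JSONDecodeError (a ValueError subclass)
        return "", "", ""
    return parsed.get("answer", ""), parsed.get("score", ""), parsed.get("reasoning", "")


raw = 'Sure! {"answer": "Use map_batches.", "score": 5, "reasoning": "Covered by the docs."}'
answer, score, reasoning = parse_scored_response(raw)
```

Returning empty values on failure pairs naturally with the `if result["score"]` filter used when averaging scores, since rows that could not be parsed are simply skipped.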
2283
  {
2284
   "cell_type": "code",
2285
   "execution_count": null,
2286
   "id": "bd52ee15-d3fa-40a0-b3a1-2b28eca8db2e",
2287
   "metadata": {
2288
    "tags": []
2289
   },
2290
   "outputs": [],
2291
   "source": [
2292
    "# References\n",
2293
    "REFERENCES_FILE_PATH = Path(EXPERIMENTS_DIR, \"references\", \"gpt-4.json\")\n",
2294
    "REFERENCES_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)"
2295
   ]
2296
  },
2297
  {
2298
   "cell_type": "code",
2299
   "execution_count": null,
2300
   "id": "25cb4b98-ddc7-462e-8d6f-92f703bfc838",
2301
   "metadata": {
2302
    "tags": []
2303
   },
2304
   "outputs": [
2305
    {
2306
     "data": {
2307
      "application/json": {
2308
       "answer": "You can specify the batch format in the map_batches() function. If you're working with NumPy ndarrays, your function should accept and return a dictionary of ndarrays. If you're working with pandas DataFrames, your function should accept and return a DataFrame. Make sure your function is compatible with the batch format you're using.",
2309
       "question": "I’m struggling a bit with Ray Data type conversions when I do map_batches. Any advice?",
2310
       "reasoning": "The context provides clear instructions on how to handle data type conversions when using the map_batches function in Ray Data. It explains how to configure the batch type and how to ensure your function is compatible with the chosen batch format.",
2311
       "score": 5,
2312
       "source": "https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-format"
2313
      },
2314
      "text/plain": [
2315
       "<IPython.core.display.JSON object>"
2316
      ]
2317
     },
2318
     "metadata": {
2319
      "application/json": {
2320
       "expanded": false,
2321
       "root": "root"
2322
      }
2323
     },
2324
     "output_type": "display_data"
2325
    },
2326
    {
2327
     "name": "stderr",
2328
     "output_type": "stream",
2329
     "text": [
2330
      "100%|██████████| 1/1 [00:13<00:00, 13.85s/it]"
2331
     ]
2332
    },
2333
    {
2334
     "name": "stdout",
2335
     "output_type": "stream",
2336
     "text": [
2337
      "4.519774011299435\n"
2338
     ]
2339
    },
2340
    {
2341
     "name": "stderr",
2342
     "output_type": "stream",
2343
     "text": [
2344
      "\n"
2345
     ]
2346
    }
2347
   ],
2348
   "source": [
2349
    "# gpt-4\n",
2350
    "results = get_references(\n",
2351
    "    data=data, num_samples=NUM_SAMPLES, llm=\"gpt-4\", temperature=0.0, \n",
2352
    "    system_content=system_content, assistant_content=assistant_content)\n",
2353
    "print (np.mean([float(result[\"score\"]) for result in results if result[\"score\"]]))"
2354
   ]
2355
  },
2356
  {
2357
   "cell_type": "code",
2358
   "execution_count": null,
2359
   "id": "0b8356fb-2a34-4697-b478-36911b582e99",
2360
   "metadata": {
2361
    "tags": []
2362
   },
2363
   "outputs": [],
2364
   "source": [
2365
    "# Save to file\n",
2366
    "with open(REFERENCES_FILE_PATH, \"w\") as fp:\n",
2367
    "    json.dump(results, fp, indent=4)"
2368
   ]
2369
  },
2370
  {
2371
   "cell_type": "markdown",
2372
   "id": "4c522532-98c2-46e9-b262-1747145f34e1",
2373
   "metadata": {},
2374
   "source": [
2375
    "We can now create a dataset with query, source, response, score and reasoning. We can inspect this to determine if our evaluator is of high quality. We found that `gpt-4` was a high quality evaluator based on the scores and reasonings it provided. We performed the same evaluation with other LLMs (ex. `Llama-2-70b`) and we found that they lacked the appropriate reasoning and were very generous when scoring their own responses.\n",
2376
    "\n",
2377
    "**Note**: A more thorough evaluation would also test for the following biases by asking the evaluator to compare responses from different LLMs:\n",
2378
    "- position (which responses we show first) \n",
2379
    "- verbosity (longer responses are favored) \n",
2380
    "- nepotism (ex. GPT4 prefers GPT 3.5, etc.)\n"
2381
   ]
2382
  },
2383
  {
2384
   "cell_type": "code",
2385
   "execution_count": null,
2386
   "id": "dceefac3-e86c-4f2b-96ee-99e3f026692f",
2387
   "metadata": {
2388
    "tags": []
2389
   },
2390
   "outputs": [],
2391
   "source": [
2392
    "EVALUATOR = \"gpt-4\"\n",
2393
    "REFERENCES_FILE_PATH = Path(EXPERIMENTS_DIR, \"references\", \"gpt-4.json\")"
2394
   ]
2395
  },
2396
  {
2397
   "cell_type": "markdown",
2398
   "id": "36a5e137-1f4a-4c31-b1af-0c9e48131f82",
2399
   "metadata": {},
2400
   "source": [
2401
    "## Cold start"
2402
   ]
2403
  },
2404
  {
2405
   "cell_type": "markdown",
2406
   "id": "67a7f87b-9f90-40c7-85df-2e0c287a5144",
2407
   "metadata": {},
2408
   "source": [
2409
    "We may not always have a prepared dataset of questions paired with the best source to answer each one. To address this cold start problem, we could use an LLM to look at our text chunks and generate questions that each specific chunk would answer. This provides us with quality questions and the exact source the answer is in. However, this dataset generation method could be a bit noisy. The generated questions may not always align closely with what our users would ask, and the information in the chunk we label as the best source may also appear in other chunks. Nonetheless, this is a great way to start our development process while we collect and manually label a high quality dataset.\n",
2410
    "\n",
2411
    "<img width=\"800\" src=\"https://images.ctfassets.net/xjan103pcp94/3QR9zkjtpgeqK8XKPteTav/76aa9e7743330e7fcf73b07332a7ddf2/image10.png\">"
2412
   ]
2413
  },
2414
  {
2415
   "cell_type": "code",
2416
   "execution_count": null,
2417
   "id": "30844ad6-adb6-43d8-af54-0604e7942816",
2418
   "metadata": {
2419
    "tags": []
2420
   },
2421
   "outputs": [],
2422
   "source": [
2423
    "# Prompt\n",
2424
    "num_questions = 3\n",
2425
    "system_content = f\"\"\"\n",
2426
    "Create {num_questions} questions using only the context provided.\n",
2427
    "End each question with a '?' character and then in a newline write the answer to that question using only the context provided.\n",
2428
    "Separate each question/answer pair by a newline.\n",
2429
    "\"\"\""
2430
   ]
2431
  },
2432
  {
2433
   "cell_type": "code",
2434
   "execution_count": null,
2435
   "id": "51b0e827-5eaa-4c52-b52e-0d46f39a3666",
2436
   "metadata": {
2437
    "tags": []
2438
   },
2439
   "outputs": [],
2440
   "source": [
2441
    "# Generate questions\n",
2442
    "synthetic_data = []\n",
2443
    "for chunk in chunks[:1]:  # small samples\n",
2444
    "    response = generate_response(\n",
2445
    "        llm=\"gpt-4\",\n",
2446
    "        temperature=0.0,\n",
2447
    "        stream=False,\n",
2448
    "        system_content=system_content,\n",
2449
    "        user_content=f\"context: {chunk.page_content}\")\n",
2450
    "    entries = response.split(\"\\n\\n\")\n",
2451
    "    for entry in entries:\n",
2452
    "        question, answer = entry.split(\"\\n\")\n",
2453
    "        synthetic_data.append({\"question\": question, \"source\": chunk.metadata[\"source\"], \"answer\": answer})"
2454
   ]
2455
  },
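The `question, answer = entry.split("\n")` unpacking above raises a `ValueError` whenever the model emits a multi-line answer or an entry with no answer at all. Below is a more defensive sketch, assuming the same question-then-answer layout requested in the prompt. `parse_qa_pairs` is a hypothetical helper that is not part of the notebook.

```python
def parse_qa_pairs(response, source):
    """Turn 'question?' + answer blocks into dataset rows, skipping malformed ones."""
    rows = []
    for entry in response.strip().split("\n\n"):
        lines = [line.strip() for line in entry.splitlines() if line.strip()]
        # Expect a question ending in '?' followed by at least one answer line.
        if len(lines) < 2 or not lines[0].endswith("?"):
            continue
        rows.append({
            "question": lines[0],
            "source": source,
            "answer": " ".join(lines[1:]),  # join multi-line answers
        })
    return rows


sample = (
    "What is the Ray Dashboard used for?\n"
    "Monitoring and debugging Ray applications.\n\n"
    "Malformed entry without an answer\n\n"
    "Is the Ray Dashboard web-based?\n"
    "Yes."
)
pairs = parse_qa_pairs(
    sample,
    source="https://docs.ray.io/en/master/ray-observability/getting-started.html#ray-dashboard",
)
```

Silently dropping malformed entries trades a little recall for robustness, which seems acceptable here since the synthetic dataset is already expected to be noisy.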
2456
  {
2457
   "cell_type": "code",
2458
   "execution_count": null,
2459
   "id": "f083eaf6-6c0e-4e8d-a4a2-cc32bea470af",
2460
   "metadata": {
2461
    "tags": []
2462
   },
2463
   "outputs": [
2464
    {
2465
     "data": {
2466
      "text/plain": [
2467
       "[{'question': 'What is the Ray Dashboard used for?',\n",
2468
       "  'source': 'https://docs.ray.io/en/master/ray-observability/getting-started.html#ray-dashboard',\n",
2469
       "  'answer': 'The Ray Dashboard is used for monitoring and debugging Ray applications.'},\n",
2470
       " {'question': 'What does the visual representation of the Ray Dashboard allow users to do?',\n",
2471
       "  'source': 'https://docs.ray.io/en/master/ray-observability/getting-started.html#ray-dashboard',\n",
2472
       "  'answer': 'The visual representation of the Ray Dashboard allows users to track the performance of applications and troubleshoot issues.'},\n",
2473
       " {'question': 'Is the Ray Dashboard web-based?',\n",
2474
       "  'source': 'https://docs.ray.io/en/master/ray-observability/getting-started.html#ray-dashboard',\n",
2475
       "  'answer': 'Yes, the Ray Dashboard is web-based.'}]"
2476
      ]
2477
     },
2478
     "execution_count": null,
2479
     "metadata": {},
2480
     "output_type": "execute_result"
2481
    }
2482
   ],
2483
   "source": [
2484
    "synthetic_data[:3]"
2485
   ]
2486
  },
2487
  {
2488
   "cell_type": "markdown",
2489
   "id": "5b029edf-e018-427f-a1c9-2b4bf689d09c",
2490
   "metadata": {},
2491
   "source": [
2492
    "## Experiments"
2493
   ]
2494
  },
2495
  {
2496
   "cell_type": "markdown",
2497
   "id": "8dea24ed-0a47-47a6-9881-0af1ea8c5691",
2498
   "metadata": {},
2499
   "source": [
2500
    "With our evaluator set, we're ready to start experimenting with the various components in our LLM application. While we could perform this as a large [hyperparameter tuning experiment](https://docs.ray.io/en/latest/tune/index.html), where we can search across promising combinations of values/decisions, we're going to evaluate one decision at a time and set the best value for the next experiment.\n",
2501
    "\n",
2502
    "**Note**: this approach is slightly imperfect because many of our decisions are not independent (ex. `chunk_size` and `num_chunks` should ideally be evaluated across many combinations of values).\n",
2503
    "\n",
2504
    "<img width=\"700\" src=\"https://images.ctfassets.net/xjan103pcp94/2LlTUhNFzfLM775IVSxjkX/af49d7b4e0fdd4a482d29cf6eab5067f/image13.png\">"
2505
   ]
2506
  },
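To make the "one decision at a time" strategy concrete, here is a minimal sketch of a greedy search over a configuration space. The `greedy_search` helper and the `toy_quality` objective are illustrative assumptions. In the notebook, each evaluation corresponds to a full `run_experiment` call followed by reading the resulting quality score.

```python
def greedy_search(search_space, defaults, evaluate):
    """Tune one parameter at a time, fixing the best value before moving on."""
    config = dict(defaults)
    for name, candidates in search_space.items():
        # Score every candidate for this parameter with the others held fixed.
        scored = [(evaluate({**config, name: value}), value) for value in candidates]
        best_score, best_value = max(scored, key=lambda pair: pair[0])
        config[name] = best_value
    return config


# Toy objective standing in for a real quality score (purely illustrative).
def toy_quality(config):
    return -abs(config["chunk_size"] - 700) - 10 * abs(config["num_chunks"] - 7)


best = greedy_search(
    search_space={"chunk_size": [100, 300, 500, 700, 900], "num_chunks": [1, 3, 5, 7, 9]},
    defaults={"chunk_size": 300, "num_chunks": 5},
    evaluate=toy_quality,
)
```

With five candidates for each of two parameters, the greedy pass needs 10 evaluations where a full grid would need 25. The cost is that interactions between parameters can be missed, which is the caveat raised in the note above.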
2507
  {
2508
   "cell_type": "markdown",
2509
   "id": "f8b51fd3-7a86-41dd-91bb-adcb83bf7269",
2510
   "metadata": {},
2511
   "source": [
2512
    "### Utilities"
2513
   ]
2514
  },
2515
  {
2516
   "cell_type": "markdown",
2517
   "id": "07b3f59f-f79b-4cdb-b55f-6d1d5d245628",
2518
   "metadata": {},
2519
   "source": [
2520
    "Before we start our experiments, we’re going to define a few more utility functions. Our evaluation workflow will use our evaluator to assess the end-to-end quality (`quality_score (overall)`) of our application since the response depends on the retrieved context and the LLM. But we’ll also include a `retrieval_score` to measure the quality of our retrieval process (chunking + embedding). Our logic for determining the `retrieval_score` registers a success if the best source is anywhere in our retrieved `num_chunks` sources. We don't account for order, exact page section, etc. but we could add those constraints to have a more conservative retrieval score.\n",
2521
    "\n",
2522
    "\n",
2523
    "<img width=\"700\" src=\"https://images.ctfassets.net/xjan103pcp94/2lhpSUNrMmi7WAHpd3wslR/15facf649e30571e8d806d354f475f0b/image6.png\">"
2524
   ]
2525
  },
2526
  {
2527
   "cell_type": "markdown",
2528
   "id": "a42a1ebb-b3d2-48d7-9d9f-bf5231cfe5b8",
2529
   "metadata": {},
2530
   "source": [
2531
    "We'll set where our labeled data and reference reports are located. We'll be using the former to generate responses and the latter dataset to evaluate those responses."
2532
   ]
2533
  },
2534
  {
2535
   "cell_type": "code",
2536
   "execution_count": null,
2537
   "id": "c084e6dd-b8b7-4492-9346-a3455e26e96f",
2538
   "metadata": {
2539
    "tags": []
2540
   },
2541
   "outputs": [],
2542
   "source": [
2543
    "import matplotlib.pyplot as plt\n",
2544
    "from rag.generate import generate_responses\n",
2545
    "from rag.evaluate import evaluate_responses"
2546
   ]
2547
  },
2548
  {
2549
   "cell_type": "markdown",
2550
   "id": "3734e048-033c-46c3-834f-5cc4b380e00a",
2551
   "metadata": {},
2552
   "source": [
2553
    "Let's define a function to determine our retrieval score, which registers a success if the best source is anywhere in our retrieved `num_chunks` sources. We don't account for order, exact page section, etc. but we could add those constraints to have a more conservative retrieval score."
2554
   ]
2555
  },
2556
  {
2557
   "cell_type": "code",
2558
   "execution_count": null,
2559
   "id": "73a57c87-8ef0-4d45-9594-34ca07ff7210",
2560
   "metadata": {
2561
    "tags": []
2562
   },
2563
   "outputs": [],
2564
   "source": [
2565
    "def get_retrieval_score(references, generated):\n",
2566
    "    matches = np.zeros(len(references))\n",
2567
    "    for i in range(len(references)):\n",
2568
    "        reference_source = references[i][\"source\"].split(\"#\")[0]\n",
2569
    "        if not reference_source:\n",
2570
    "            matches[i] = 1\n",
2571
    "            continue\n",
2572
    "        for source in generated[i][\"sources\"]:\n",
2573
    "            # sections don't have to perfectly match\n",
2574
    "            if reference_source == source.split(\"#\")[0]:\n",
2575
    "                matches[i] = 1\n",
2576
    "                break\n",
2577
    "    retrieval_score = np.mean(matches)\n",
2578
    "    return retrieval_score"
2579
   ]
2580
  },
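As a quick worked example of the retrieval score, the sketch below applies the same page-level matching rule to two toy queries. `retrieval_hit_rate` is a compact, hypothetical restatement of the logic above, and the URLs are made up for illustration.

```python
def retrieval_hit_rate(references, generated):
    """Fraction of queries whose reference page appears among the retrieved sources."""
    hits = 0
    for ref, gen in zip(references, generated):
        ref_page = ref["source"].split("#")[0]  # ignore the section anchor
        retrieved_pages = {s.split("#")[0] for s in gen["sources"]}
        # An empty reference source counts as a hit, mirroring the function above.
        if not ref_page or ref_page in retrieved_pages:
            hits += 1
    return hits / len(references)


references = [
    {"source": "https://docs.ray.io/a.html#intro"},
    {"source": "https://docs.ray.io/b.html#setup"},
]
generated = [
    # Same page, different section: still a hit.
    {"sources": ["https://docs.ray.io/a.html#other-section", "https://docs.ray.io/c.html"]},
    # Reference page b.html was not retrieved: a miss.
    {"sources": ["https://docs.ray.io/c.html#x"]},
]
print(retrieval_hit_rate(references, generated))  # 0.5
```

The first query counts as a success even though the retrieved section differs from the reference section, which is why this metric is on the lenient side.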
2581
  {
2582
   "cell_type": "markdown",
2583
   "id": "7a7c708e-8fd1-4560-b65f-cf62d0034722",
2584
   "metadata": {},
2585
   "source": [
2586
    "We'll define one encompassing function that will generate and evaluate the responses so that we can run these experiments with one function call. Regardless of what configuration(s) we want to evaluate, we’ll need to first generate responses using that configuration and then evaluate those responses using our evaluator:"
2587
   ]
2588
  },
2589
  {
2590
   "cell_type": "code",
2591
   "execution_count": null,
2592
   "id": "66a76e95-9c1b-488f-bd9f-750a176d3d77",
2593
   "metadata": {
2594
    "tags": []
2595
   },
2596
   "outputs": [],
2597
   "source": [
2598
    "def run_experiment(\n",
2599
    "    experiment_name,\n",
2600
    "    chunk_size, chunk_overlap, num_chunks,\n",
2601
    "    embedding_model_name, embedding_dim,\n",
2602
    "    llm, evaluator,\n",
2603
    "    docs_dir, experiments_dir, references_fp,\n",
2604
    "    system_content=\"Answer the query using the context provided. Be succinct.\",\n",
2605
    "    use_lexical_search=False,\n",
2606
    "    lexical_search_k=1,\n",
2607
    "    use_reranking=False,\n",
2608
    "    rerank_threshold=0.0,\n",
2609
    "    rerank_k=7,\n",
2610
    "    num_samples=None,\n",
2611
    "    sql_dump_fp=None):\n",
2612
    "    \"\"\"Generate responses and evaluate them.\"\"\"\n",
2613
    "    \n",
2614
    "    # Generate responses\n",
2615
    "    generate_responses(\n",
2616
    "        experiment_name=experiment_name, \n",
2617
    "        chunk_size=chunk_size, \n",
2618
    "        chunk_overlap=chunk_overlap, \n",
2619
    "        num_chunks=num_chunks,\n",
2620
    "        embedding_model_name=embedding_model_name,\n",
2621
    "        embedding_dim=embedding_dim,\n",
2622
    "        use_lexical_search=use_lexical_search,\n",
2623
    "        lexical_search_k=lexical_search_k,\n",
2624
    "        use_reranking=use_reranking,\n",
2625
    "        rerank_threshold=rerank_threshold,\n",
2626
    "        rerank_k=rerank_k,\n",
2627
    "        llm=llm, \n",
2628
    "        temperature=0.0, \n",
2629
    "        max_context_length=MAX_CONTEXT_LENGTHS[llm], \n",
2630
    "        system_content=system_content,\n",
2631
    "        assistant_content=\"\",\n",
2632
    "        docs_dir=docs_dir,\n",
2633
    "        experiments_dir=experiments_dir,\n",
2634
    "        references_fp=references_fp,\n",
2635
    "        num_samples=num_samples,\n",
2636
    "        sql_dump_fp=sql_dump_fp)\n",
2637
    "\n",
2638
    "    # Evaluate responses\n",
2639
    "    evaluation_system_content = \"\"\"\n",
2640
    "        Your job is to rate the quality of our generated answer {generated_answer}\n",
2641
    "        given a query {query} and a reference answer {reference_answer}.\n",
2642
    "        Your score has to be between 1 and 5.\n",
2643
    "        You must return your response in a line with only the score.\n",
2644
    "        Do not return answers in any other format.\n",
2645
    "        On a separate line provide your reasoning for the score as well.\n",
2646
    "        \"\"\"\n",
2647
    "    evaluate_responses(\n",
2648
    "        experiment_name=experiment_name,\n",
2649
    "        evaluator=evaluator, \n",
2650
    "        temperature=0.0, \n",
2651
    "        max_context_length=MAX_CONTEXT_LENGTHS[evaluator],\n",
2652
    "        system_content=evaluation_system_content,\n",
2653
    "        assistant_content=\"\",\n",
2654
    "        experiments_dir=experiments_dir,\n",
2655
    "        references_fp=references_fp,\n",
2656
    "        responses_fp=str(Path(experiments_dir, \"responses\", f\"{experiment_name}.json\")),\n",
2657
    "        num_samples=num_samples)"
2658
   ]
2659
  },
2660
  {
2661
   "cell_type": "code",
2662
   "execution_count": null,
2663
   "id": "169201fd-aa78-4459-a55f-04c97f1938aa",
2664
   "metadata": {
2665
    "tags": []
2666
   },
2667
   "outputs": [],
2668
   "source": [
2669
    "def print_experiment(experiment_name, experiments_dir, evaluator=EVALUATOR, verbose=True):\n",
2670
    "    eval_fp = Path(experiments_dir, \"evaluations\", f\"{experiment_name}_{evaluator}.json\")\n",
2671
    "    with open(eval_fp, \"r\") as fp:\n",
2672
    "        d = json.load(fp)\n",
2673
    "    retrieval_score = d[\"retrieval_score\"]\n",
2674
    "    quality_score = d[\"quality_score\"]\n",
2675
    "    if verbose:\n",
2676
    "        print (experiment_name)\n",
2677
    "        print (\"  retrieval score:\", retrieval_score)\n",
2678
    "        print (\"  quality score:\", quality_score)\n",
2679
    "        print ()\n",
2680
    "    return {\"retrieval_score\": retrieval_score, \"quality_score\": quality_score}"
2681
   ]
2682
  },
2683
  {
2684
   "cell_type": "code",
2685
   "execution_count": null,
2686
   "id": "6f44e646-740e-47b7-a878-e15fb432394d",
2687
   "metadata": {
2688
    "tags": []
2689
   },
2690
   "outputs": [],
2691
   "source": [
2692
    "def plot_scores(scores):\n",
2693
    "    # Prepare data for plotting\n",
2694
    "    experiment_names = list(scores.keys())\n",
2695
    "    retrieval_scores = [scores[experiment_name][\"retrieval_score\"] for experiment_name in experiment_names]\n",
2696
    "    quality_scores = [scores[experiment_name][\"quality_score\"] for experiment_name in experiment_names]\n",
2697
    "    \n",
2698
    "    # Plotting\n",
2699
    "    plt.figure(figsize=(10, 3))\n",
2700
    "    for i, experiment_name in enumerate(experiment_names):\n",
2701
    "        plt.scatter(quality_scores[i], retrieval_scores[i], label=experiment_name)\n",
2702
    "        plt.text(quality_scores[i]+0.005, retrieval_scores[i]+0.005, experiment_name, ha=\"right\")\n",
2703
    "        \n",
2704
    "    # Add labels and title\n",
2705
    "    plt.xlabel(\"Quality Score\")\n",
2706
    "    plt.ylabel(\"Retrieval Score\")\n",
2707
    "    plt.legend(title=\"Experiments\")\n",
2708
    "    \n",
2709
    "    # Show the plot\n",
2710
    "    plt.show()"
2711
   ]
2712
  },
2713
  {
2714
   "cell_type": "code",
2715
   "execution_count": null,
2716
   "id": "8accfe8a-5041-4b94-a225-9e34a73ac5b8",
2717
   "metadata": {
2718
    "tags": []
2719
   },
2720
   "outputs": [],
2721
   "source": [
2722
    "llm = \"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n",
2723
    "embedding_model_name = \"thenlper/gte-base\""
2724
   ]
2725
  },
2726
  {
2727
   "cell_type": "markdown",
2728
   "id": "6b31ebee-f839-4aaa-b328-72cee088c830",
2729
   "metadata": {
2730
    "tags": []
2731
   },
2732
   "source": [
2733
    "### Context"
2734
   ]
2735
  },
2736
  {
2737
   "cell_type": "markdown",
2738
   "id": "e89f56c4-2629-4ff7-b57c-f319170937de",
2739
   "metadata": {},
2740
   "source": [
2741
    "We're first going to test if the additional context we provide is helpful at all. This is to validate that the RAG system is indeed worth the effort. We can do this by setting `num_chunks=0` (no context) and comparing that to `num_chunks=5`."
2742
   ]
2743
  },
2744
  {
2745
   "cell_type": "code",
2746
   "execution_count": null,
2747
   "id": "91cefda6-0ec7-40a2-afc0-b8af2bdd3332",
2748
   "metadata": {
2749
    "tags": []
2750
   },
2751
   "outputs": [],
2752
   "source": [
2753
    "# Without context\n",
2754
    "num_chunks = 0\n",
2755
    "experiment_name = \"without-context\"\n",
2756
    "embedding_model_name = \"thenlper/gte-base\"\n",
2757
    "run_experiment(\n",
2758
    "    experiment_name=experiment_name, \n",
2759
    "    chunk_size=300, \n",
2760
    "    chunk_overlap=50,\n",
2761
    "    num_chunks=num_chunks,\n",
2762
    "    embedding_model_name=embedding_model_name,\n",
2763
    "    embedding_dim=EMBEDDING_DIMENSIONS[embedding_model_name],\n",
2764
    "    llm=llm,\n",
2765
    "    evaluator=EVALUATOR,\n",
2766
    "    docs_dir=DOCS_DIR, \n",
2767
    "    experiments_dir=EXPERIMENTS_DIR, \n",
2768
    "    references_fp=REFERENCES_FILE_PATH,\n",
2769
    "    num_samples=NUM_SAMPLES)"
2770
   ]
2771
  },
2772
  {
2773
   "cell_type": "code",
2774
   "execution_count": null,
2775
   "id": "a819c26a-9e05-484b-8e57-edfd26a84d4d",
2776
   "metadata": {
2777
    "tags": []
2778
   },
2779
   "outputs": [
2780
    {
2781
     "data": {
2782
      "application/json": {
2783
       "generated_answer": " To kill a specific serve replica, you can follow these steps:\n\n1. Get a handle to the serve replica using `ray.get_actor()` with the appropriate NAME and namespace.\n2. Kill the replica using `ray.kill()` and set `no_restart=True`.\n\nHere is an example:\n```python\nimport ray\n\nreplica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#RlRptP\", namespace=\"serve\")\nray.kill(replica_handle, no_restart=True)\n```\nNote that you can get the NAME of the replica from the output of `ray list actors` command.\n\nIf you are running KubeRay, you can exec into a Ray pod before running these commands.\n\nAlso, you can simulate replica failures by manually killing deployment replicas using the `ray summary actors` command.",
2784
       "question": "how do I kill a specific serve replica",
2785
       "reasoning": "The generated answer is highly detailed and provides a step-by-step guide on how to kill a specific serve replica. It even includes a code example for better understanding. The reference answer does not provide any useful information, making the generated answer superior.",
2786
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
2787
       "score": 5,
2788
       "sources": [
2789
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
2790
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
2791
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
2792
        "https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#scaling-horizontally-with-num-replicas",
2793
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure"
2794
       ]
2795
      },
2796
      "text/plain": [
2797
       "<IPython.core.display.JSON object>"
2798
      ]
2799
     },
2800
     "metadata": {
2801
      "application/json": {
2802
       "expanded": false,
2803
       "root": "root"
2804
      }
2805
     },
2806
     "output_type": "display_data"
2807
    },
2808
    {
2809
     "name": "stderr",
2810
     "output_type": "stream",
2811
     "text": [
2812
      "100%|██████████| 177/177 [25:46<00:00,  8.74s/it]\n"
2813
     ]
2814
    }
2815
   ],
2816
   "source": [
2817
    "# With context\n",
2818
    "num_chunks = 5\n",
2819
    "experiment_name = \"with-context\"\n",
2820
    "embedding_model_name = \"thenlper/gte-base\"\n",
2821
    "run_experiment(\n",
2822
    "    experiment_name=experiment_name, \n",
2823
    "    chunk_size=300, \n",
2824
    "    chunk_overlap=50, \n",
2825
    "    num_chunks=num_chunks,\n",
2826
    "    embedding_model_name=embedding_model_name,\n",
2827
    "    embedding_dim=EMBEDDING_DIMENSIONS[embedding_model_name],\n",
2828
    "    llm=llm,\n",
2829
    "    evaluator=EVALUATOR,\n",
2830
    "    docs_dir=DOCS_DIR, \n",
2831
    "    experiments_dir=EXPERIMENTS_DIR, \n",
2832
    "    references_fp=REFERENCES_FILE_PATH,\n",
2833
    "    num_samples=NUM_SAMPLES)"
2834
   ]
2835
  },
2836
  {
2837
   "cell_type": "code",
2838
   "execution_count": null,
2839
   "id": "8327cf28-5114-4e93-b769-014108dd743a",
2840
   "metadata": {
2841
    "tags": []
2842
   },
2843
   "outputs": [
2844
    {
2845
     "name": "stdout",
2846
     "output_type": "stream",
2847
     "text": [
2848
      "without-context\n",
2849
      "  retrieval score: 0.0\n",
2850
      "  quality score: 3.194915254237288\n",
2851
      "\n",
2852
      "with-context\n",
2853
      "  retrieval score: 0.5254237288135594\n",
2854
      "  quality score: 3.5112994350282487\n",
2855
      "\n"
2856
     ]
2857
    },
2858
    {
2859
     "data": {
2860
      "image/png": "",
2861
      "text/plain": [
2862
       "<Figure size 1000x300 with 1 Axes>"
2863
      ]
2864
     },
2865
     "metadata": {},
2866
     "output_type": "display_data"
2867
    }
2868
   ],
2869
   "source": [
2870
    "scores = {}\n",
2871
    "for experiment_name in [\"without-context\", \"with-context\"]:\n",
2872
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)\n",
2873
    "plot_scores(scores=scores)"
2874
   ]
2875
  },
2876
  {
2877
   "cell_type": "markdown",
2878
   "id": "df01f761-d3a7-4783-9fa7-57f62679b548",
2879
   "metadata": {},
2880
   "source": [
2881
    "**Sanity check**: the retrieval score for without-context is zero since we’re not using any context.\n",
2882
    "\n",
2883
    "As we can see, using context (RAG) does indeed improve the quality of our answers: the quality score rises from roughly 3.19 without context to 3.51 with context."
2884
   ]
2885
  },
2886
  {
2887
   "cell_type": "markdown",
2888
   "id": "6acc3a24-007d-4add-b0dd-5832351c6d63",
2889
   "metadata": {
2890
    "tags": []
2891
   },
2892
   "source": [
2893
    "### Chunk size"
2894
   ]
2895
  },
2896
  {
2897
   "cell_type": "markdown",
2898
   "id": "11775f23-89ef-432d-a00d-2a11ffaebf67",
2899
   "metadata": {},
2900
   "source": [
2901
    "Next, we'll assess various chunk sizes. Smaller chunks (but not too small!) are able to encapsulate atomic concepts, which yields more precise retrieval, while larger chunks may be noisier. Popular strategies include using small chunks but retrieving a bit of the [surrounding chunks](https://gpt-index.readthedocs.io/en/latest/end_to_end_tutorials/dev_practices/production_rag.html#decoupling-chunks-used-for-retrieval-vs-chunks-used-for-synthesis) around them (since they may have relevant info) or storing [multiple embeddings](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector) per document (ex. summary embedding per document)."
2902
   ]
2903
  },
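To see how `chunk_size` and `chunk_overlap` interact, here is a simplified character-level sliding-window chunker. This is only a sketch under the assumption of fixed-size windows. `chunk_text` is a hypothetical helper, and the text splitter used by the actual pipeline may pick boundaries differently (for example, at separators).

```python
def chunk_text(text, chunk_size, chunk_overlap):
    """Split text into fixed-size character windows that overlap by chunk_overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    last_start = max(len(text) - chunk_overlap, 1)
    return [text[i:i + chunk_size] for i in range(0, last_start, step)]


text = "abcdefghijklmnopqrstuvwxyz"
chunks_demo = chunk_text(text, chunk_size=10, chunk_overlap=2)
# ['abcdefghij', 'ijklmnopqr', 'qrstuvwxyz']
```

Each chunk repeats the last two characters of the previous one, so a concept that straddles a boundary still appears intact in at least one chunk. Larger values of `chunk_size` mean fewer, broader chunks per document.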
2904
  {
2905
   "cell_type": "code",
2906
   "execution_count": null,
2907
   "id": "64b691b8-b2ea-4f77-857a-09e50053699c",
2908
   "metadata": {
2909
    "tags": []
2910
   },
2911
   "outputs": [],
2912
   "source": [
2913
    "experiment_names = []\n",
2914
    "chunk_sizes = [100, 300, 500, 700, 900]"
2915
   ]
2916
  },
2917
  {
2918
   "cell_type": "code",
2919
   "execution_count": null,
2920
   "id": "53c054d1-9e4d-4688-b4b9-cf5a39b382c3",
2921
   "metadata": {
2922
    "tags": []
2923
   },
2924
   "outputs": [
2925
    {
2926
     "data": {
2927
      "application/json": {
2928
       "generated_answer": " To kill a specific serve replica, you can follow these steps:\n\n1. First, you need to get the handle of the replica you want to kill. You can do this by using the `ray.get_actor()` function and providing the ACTOR\\_ID of the replica. In your case, the ACTOR\\_ID is \"SERVE\\_REPLICA::SleepyPid#RlRptP\".\n2. Once you have the replica handle, you can use the `ray.kill()` function to kill the replica. Set the `no_restart` parameter to `True` to prevent the replica from automatically restarting.\n\nHere's an example command to kill the specific replica:\n```python\nreplica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#RlRptP\", namespace=\"serve\")\nray.kill(replica_handle, no_restart=True)\n```\nNote that you can get the list of all actors and their corresponding ACTOR\\_IDs by running the command `ray list actors`.",
2929
       "question": "how do I kill a specific serve replica",
2930
       "reasoning": "The generated answer is detailed, accurate, and provides a step-by-step guide on how to kill a specific serve replica. It also includes a code snippet for better understanding. The reference answer does not provide any useful information, making the generated answer superior.",
2931
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
2932
       "score": 5,
2933
       "sources": [
2934
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
2935
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
2936
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
2937
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
2938
        "https://docs.ray.io/en/master/serve/production-guide/kubernetes.html#next-steps"
2939
       ]
2940
      },
2941
      "text/plain": [
2942
       "<IPython.core.display.JSON object>"
2943
      ]
2944
     },
2945
     "metadata": {
2946
      "application/json": {
2947
       "expanded": false,
2948
       "root": "root"
2949
      }
2950
     },
2951
     "output_type": "display_data"
2952
    },
2953
    {
2954
     "name": "stderr",
2955
     "output_type": "stream",
2956
     "text": [
2957
      "100%|██████████| 177/177 [24:52<00:00,  8.43s/it]\n"
2958
     ]
2959
    }
2960
   ],
2961
   "source": [
2962
    "for chunk_size in chunk_sizes:\n",
2963
    "    experiment_name = f\"chunk-size-{chunk_size}\"\n",
2964
    "    experiment_names.append(experiment_name)\n",
2965
    "    run_experiment(\n",
2966
    "        experiment_name=experiment_name, \n",
2967
    "        chunk_size=chunk_size, \n",
2968
    "        chunk_overlap=50, \n",
2969
    "        num_chunks=5,\n",
2970
    "        embedding_model_name=embedding_model_name,\n",
2971
    "        embedding_dim=EMBEDDING_DIMENSIONS[embedding_model_name],\n",
2972
    "        llm=llm,\n",
2973
    "        evaluator=EVALUATOR,\n",
2974
    "        docs_dir=DOCS_DIR, \n",
2975
    "        experiments_dir=EXPERIMENTS_DIR, \n",
2976
    "        references_fp=REFERENCES_FILE_PATH,\n",
2977
    "        num_samples=NUM_SAMPLES)"
2978
   ]
2979
  },
2980
  {
2981
   "cell_type": "code",
2982
   "execution_count": null,
2983
   "id": "5155e8b6-5a52-4154-b05e-12af5a0e413d",
2984
   "metadata": {
2985
    "tags": []
2986
   },
2987
   "outputs": [
2988
    {
2989
     "name": "stdout",
2990
     "output_type": "stream",
2991
     "text": [
2992
      "chunk-size-100\n",
2993
      "  retrieval score: 0.4180790960451977\n",
2994
      "  quality score: 3.288135593220339\n",
2995
      "\n",
2996
      "chunk-size-300\n",
2997
      "  retrieval score: 0.5254237288135594\n",
2998
      "  quality score: 3.531073446327684\n",
2999
      "\n",
3000
      "chunk-size-500\n",
3001
      "  retrieval score: 0.5480225988700564\n",
3002
      "  quality score: 3.6271186440677967\n",
3003
      "\n",
3004
      "chunk-size-700\n",
3005
      "  retrieval score: 0.519774011299435\n",
3006
      "  quality score: 3.76271186440678\n",
3007
      "\n",
3008
      "chunk-size-900\n",
3009
      "  retrieval score: 0.5706214689265536\n",
3010
      "  quality score: 3.6779661016949152\n",
3011
      "\n"
3012
     ]
3013
    },
3014
    {
3015
     "data": {
3016
      "image/png": "",
3017
      "text/plain": [
3018
       "<Figure size 1000x300 with 1 Axes>"
3019
      ]
3020
     },
3021
     "metadata": {},
3022
     "output_type": "display_data"
3023
    }
3024
   ],
3025
   "source": [
3026
    "scores = {}\n",
3027
    "for experiment_name in experiment_names:\n",
3028
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)\n",
3029
    "plot_scores(scores=scores)"
3030
   ]
3031
  },
3032
  {
3033
   "cell_type": "markdown",
3034
   "id": "680569cb-d1b3-464e-b771-1f5dd3d0cc66",
3035
   "metadata": {},
3036
   "source": [
3037
    "It appears that larger chunk sizes do help but tapers off (too much context might be too noisy). Larger chunk sizes [aren’t always better](https://arxiv.org/abs/2307.03172).\n",
3038
    "\n",
3039
    "**Note**: If we were to use larger chunk sizes (ours is based on characters), keep in mind that [most](https://huggingface.co/spaces/mteb/leaderboard) open source embedding models have a maximum sequence length of 512 sub-word tokens. This means that if our chunk contains more than 512 sub-word tokens (4 chars ≈ 1 token), the embedding wouldn't account for it anyway (unless we finetune our embedding model to have longer sequence lengths)."
3040
   ]
3041
  },
3042
  {
3043
   "cell_type": "code",
3044
   "execution_count": null,
3045
   "id": "04df6ec4-7edf-4a27-93ae-7ee2b3ff7241",
3046
   "metadata": {
3047
    "tags": []
3048
   },
3049
   "outputs": [],
3050
   "source": [
3051
    "CHUNK_SIZE = 700\n",
3052
    "CHUNK_OVERLAP = 50"
3053
   ]
3054
  },
3055
  {
3056
   "cell_type": "markdown",
3057
   "id": "fc80dd05-ced6-49b4-a193-c52fcebc118e",
3058
   "metadata": {
3059
    "tags": []
3060
   },
3061
   "source": [
3062
    "### Number of chunks"
3063
   ]
3064
  },
3065
  {
3066
   "cell_type": "markdown",
3067
   "id": "84fbb854-b016-4c42-97bc-56c7eebfa3dd",
3068
   "metadata": {},
3069
   "source": [
3070
    "Next, we'll experiment with the number of chunks to use. More chunks will allow us to add more context but too many could potentially introduce a lot of noise.\n",
3071
    "\n",
3072
    "**Note**: The `chunk_size` we chose multiplied by the `num_chunks` below fits inside the LLM's context length. We're experimenting with the chunk size and number of chunks as if they were indepdent variables but they area heavily related. Especially since all of our LLMs have a finite maximum context length. So ideally, we would tune for a combination if `chunk_size` * `num_chunks`."
3073
   ]
3074
  },
3075
  {
3076
   "cell_type": "code",
3077
   "execution_count": null,
3078
   "id": "9d74ca9d-8f04-4e91-a927-87820acb1aed",
3079
   "metadata": {
3080
    "tags": []
3081
   },
3082
   "outputs": [],
3083
   "source": [
3084
    "experiment_names = []\n",
3085
    "num_chunks_list = [1, 3, 5, 7, 9, 11, 13, 15]"
3086
   ]
3087
  },
3088
  {
3089
   "cell_type": "code",
3090
   "execution_count": null,
3091
   "id": "0136db87-bb7b-4d7f-ad7a-549c1e6fd489",
3092
   "metadata": {
3093
    "tags": []
3094
   },
3095
   "outputs": [],
3096
   "source": [
3097
    "experiment_names = []\n",
3098
    "num_chunks_list = [14]"
3099
   ]
3100
  },
3101
  {
3102
   "cell_type": "code",
3103
   "execution_count": null,
3104
   "id": "3f09d999-3e02-46c7-a48f-eb9be5edc37a",
3105
   "metadata": {
3106
    "tags": []
3107
   },
3108
   "outputs": [
3109
    {
3110
     "data": {
3111
      "text/plain": [
3112
       "PosixPath('/home/ray/ray-assistant/notebooks/../experiments/references/gpt-4-turbo.json')"
3113
      ]
3114
     },
3115
     "execution_count": null,
3116
     "metadata": {},
3117
     "output_type": "execute_result"
3118
    }
3119
   ],
3120
   "source": [
3121
    "REFERENCES_FILE_PATH"
3122
   ]
3123
  },
3124
  {
3125
   "cell_type": "code",
3126
   "execution_count": null,
3127
   "id": "c514d9c7-e16d-44c4-88b7-72a49b5c4197",
3128
   "metadata": {
3129
    "tags": []
3130
   },
3131
   "outputs": [
3132
    {
3133
     "data": {
3134
      "application/json": {
3135
       "generated_answer": " To kill a specific serve replica in Ray, you can follow these steps:\n\n1. Get a list of actors in Ray, including the replicas, by running `ray list actors` in the Python interpreter.\n2. Identify the NAME of the replica you want to kill from the list.\n3. Get a handle to the replica by running `replica_handle = ray.get_actor(\"NAME\", namespace=\"serve\")` in the Python interpreter.\n4. Kill the replica by running `ray.kill(replica_handle, no_restart=True)` in the Python interpreter.\n\nFor example, if you have a replica with the NAME \"SERVE\\_REPLICA::SleepyPid#PHvdJW\", you can kill it by running:\n```makefile\n$ python\n>>> import ray\n>>> replica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#PHvdJW\", namespace=\"serve\")\n>>> ray.kill(replica_handle, no_restart=True)\n>>> exit()\n```",
3136
       "question": "how do I kill a specific serve replica",
3137
       "reasoning": "",
3138
       "reference_answer": "The context provided does not contain the necessary information to answer the query.",
3139
       "score": 4,
3140
       "sources": [
3141
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
3142
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
3143
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
3144
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#head-node-failure",
3145
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
3146
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
3147
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
3148
        "https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#autoscaling",
3149
        "https://docs.ray.io/en/master/serve/production-guide/kubernetes.html#next-steps",
3150
        "https://docs.ray.io/en/master/serve/production-guide/config.html#serve-config-files-serve-build",
3151
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
3152
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
3153
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#replica-health-checking",
3154
        "https://docs.ray.io/en/master/serve/architecture.html#ray-serve-autoscaling"
3155
       ]
3156
      },
3157
      "text/plain": [
3158
       "<IPython.core.display.JSON object>"
3159
      ]
3160
     },
3161
     "metadata": {
3162
      "application/json": {
3163
       "expanded": false,
3164
       "root": "root"
3165
      }
3166
     },
3167
     "output_type": "display_data"
3168
    },
3169
    {
3170
     "name": "stderr",
3171
     "output_type": "stream",
3172
     "text": [
3173
      "100%|██████████| 177/177 [21:21<00:00,  7.24s/it]\n"
3174
     ]
3175
    }
3176
   ],
3177
   "source": [
3178
    "for num_chunks in num_chunks_list:\n",
3179
    "    experiment_name = f\"num-chunks-{num_chunks}\"\n",
3180
    "    experiment_names.append(experiment_name)\n",
3181
    "    run_experiment(\n",
3182
    "        experiment_name=experiment_name,\n",
3183
    "        chunk_size=CHUNK_SIZE, \n",
3184
    "        chunk_overlap=CHUNK_OVERLAP, \n",
3185
    "        num_chunks=num_chunks,\n",
3186
    "        embedding_model_name=embedding_model_name,\n",
3187
    "        embedding_dim=EMBEDDING_DIMENSIONS[embedding_model_name],\n",
3188
    "        llm=llm,\n",
3189
    "        evaluator=EVALUATOR,\n",
3190
    "        docs_dir=DOCS_DIR, \n",
3191
    "        experiments_dir=EXPERIMENTS_DIR, \n",
3192
    "        references_fp=REFERENCES_FILE_PATH,\n",
3193
    "        num_samples=NUM_SAMPLES)"
3194
   ]
3195
  },
3196
  {
3197
   "cell_type": "code",
3198
   "execution_count": null,
3199
   "id": "d688ed6b-2e32-4c80-b83f-e5ebbca88ea6",
3200
   "metadata": {
3201
    "tags": []
3202
   },
3203
   "outputs": [
3204
    {
3205
     "name": "stdout",
3206
     "output_type": "stream",
3207
     "text": [
3208
      "num-chunks-1\n",
3209
      "  retrieval score: 0.2542372881355932\n",
3210
      "  quality score: 3.4237288135593222\n",
3211
      "\n",
3212
      "num-chunks-3\n",
3213
      "  retrieval score: 0.4689265536723164\n",
3214
      "  quality score: 3.6299435028248586\n",
3215
      "\n",
3216
      "num-chunks-5\n",
3217
      "  retrieval score: 0.519774011299435\n",
3218
      "  quality score: 3.709039548022599\n",
3219
      "\n",
3220
      "num-chunks-7\n",
3221
      "  retrieval score: 0.6271186440677966\n",
3222
      "  quality score: 3.7937853107344632\n",
3223
      "\n",
3224
      "num-chunks-9\n",
3225
      "  retrieval score: 0.6836158192090396\n",
3226
      "  quality score: 3.8983050847457625\n",
3227
      "\n",
3228
      "num-chunks-11\n",
3229
      "  retrieval score: 0.7175141242937854\n",
3230
      "  quality score: 3.9180790960451977\n",
3231
      "\n",
3232
      "num-chunks-13\n",
3233
      "  retrieval score: 0.7570621468926554\n",
3234
      "  quality score: 3.983050847457627\n",
3235
      "\n",
3236
      "num-chunks-15\n",
3237
      "  retrieval score: 0.7740112994350282\n",
3238
      "  quality score: 3.983050847457627\n",
3239
      "\n"
3240
     ]
3241
    },
3242
    {
3243
     "data": {
3244
      "image/png": "",
3245
      "text/plain": [
3246
       "<Figure size 1000x300 with 1 Axes>"
3247
      ]
3248
     },
3249
     "metadata": {},
3250
     "output_type": "display_data"
3251
    }
3252
   ],
3253
   "source": [
3254
    "scores = {}\n",
3255
    "for experiment_name in experiment_names:\n",
3256
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)\n",
3257
    "plot_scores(scores=scores)"
3258
   ]
3259
  },
3260
  {
3261
   "cell_type": "markdown",
3262
   "id": "ee986a80-8659-4344-bb22-71bb62b32946",
3263
   "metadata": {},
3264
   "source": [
3265
    "Increasing our number of chunks improves our retrieval and quality scores up to a certain point. However, for some models (ex. `llama-2`), the context length is much smaller so we won't be able to use as many chunks. For all of these models, it's worth investing in extending context size via RoPE scaling (rotary position embeddings), etc."
3266
   ]
3267
  },
3268
  {
3269
   "cell_type": "code",
3270
   "execution_count": null,
3271
   "id": "fbe4828f-9639-4957-898e-27dd0ce3ee32",
3272
   "metadata": {
3273
    "tags": []
3274
   },
3275
   "outputs": [],
3276
   "source": [
3277
    "NUM_CHUNKS = 13"
3278
   ]
3279
  },
3280
  {
3281
   "cell_type": "markdown",
3282
   "id": "1a04dff3-5323-419f-a290-849c96899292",
3283
   "metadata": {},
3284
   "source": [
3285
    "### Embedding models"
3286
   ]
3287
  },
3288
  {
3289
   "cell_type": "markdown",
3290
   "id": "df10c471-22b5-479c-bbbd-59ff3835d7b9",
3291
   "metadata": {},
3292
   "source": [
3293
    "So far, we've used [thenlper/gte-base](https://huggingface.co/thenlper/gte-base) as our embedding model because it's a relatively small (0.22 GB) and performant option. But now, let's explore other popular options such as the current leader on the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard), [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en) (1.34 GB), [thenlper/gte-large](https://huggingface.co/thenlper/gte-large) (a larger version of `gte-base`), and OpenAI's [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model)."
3294
   ]
3295
  },
3296
  {
3297
   "cell_type": "code",
3298
   "execution_count": null,
3299
   "id": "198ec597-8aaf-4c45-a275-2094211eebb4",
3300
   "metadata": {
3301
    "tags": []
3302
   },
3303
   "outputs": [],
3304
   "source": [
3305
    "experiment_names = []\n",
3306
    "embedding_model_names = [\"thenlper/gte-base\", \"thenlper/gte-large\", \"BAAI/bge-large-en\", \"text-embedding-ada-002\"]"
3307
   ]
3308
  },
3309
  {
3310
   "cell_type": "code",
3311
   "execution_count": null,
3312
   "id": "1913f50c-ef13-487d-beeb-77ee38f91067",
3313
   "metadata": {
3314
    "tags": []
3315
   },
3316
   "outputs": [
3317
    {
3318
     "data": {
3319
      "application/json": {
3320
       "generated_answer": "  To kill a specific serve replica, you can use the `ray kill` command with the actor ID of the replica. You can get the actor ID by running `ray list actors` and filtering the output by the class name of the replica. For example, if you want to kill a replica with the class name `ServeReplica:SleepyPid`, you can run `ray kill <actor_id>`.",
3321
       "question": "how do I kill a specific serve replica",
3322
       "reasoning": "The generated answer is detailed and provides a step-by-step guide on how to kill a specific serve replica, which is exactly what the question asked for. The reference answer does not provide any useful information, so the generated answer is much better in this case.",
3323
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
3324
       "score": 5,
3325
       "sources": [
3326
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
3327
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
3328
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
3329
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
3330
        "https://docs.ray.io/en/master/ray-observability/getting-started.html#serve-application-detail-page",
3331
        "https://docs.ray.io/en/master/cluster/kubernetes/troubleshooting/rayservice-troubleshooting.html#method-5-ray-state-cli",
3332
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
3333
        "https://docs.ray.io/en/master/serve/monitoring.html#ray-dashboard",
3334
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure"
3335
       ]
3336
      },
3337
      "text/plain": [
3338
       "<IPython.core.display.JSON object>"
3339
      ]
3340
     },
3341
     "metadata": {
3342
      "application/json": {
3343
       "expanded": false,
3344
       "root": "root"
3345
      }
3346
     },
3347
     "output_type": "display_data"
3348
    },
3349
    {
3350
     "name": "stderr",
3351
     "output_type": "stream",
3352
     "text": [
3353
      "100%|██████████| 177/177 [20:57<00:00,  7.10s/it]\n"
3354
     ]
3355
    }
3356
   ],
3357
   "source": [
3358
    "for embedding_model_name in embedding_model_names:\n",
3359
    "    experiment_name = f\"{embedding_model_name.split('/')[-1]}\"\n",
3360
    "    experiment_names.append(experiment_name)\n",
3361
    "    run_experiment(\n",
3362
    "        experiment_name=experiment_name, \n",
3363
    "        chunk_size=CHUNK_SIZE, \n",
3364
    "        chunk_overlap=CHUNK_OVERLAP, \n",
3365
    "        num_chunks=NUM_CHUNKS,\n",
3366
    "        embedding_model_name=embedding_model_name,\n",
3367
    "        embedding_dim=EMBEDDING_DIMENSIONS[embedding_model_name],\n",
3368
    "        llm=llm,\n",
3369
    "        evaluator=EVALUATOR,\n",
3370
    "        docs_dir=DOCS_DIR, \n",
3371
    "        experiments_dir=EXPERIMENTS_DIR, \n",
3372
    "        references_fp=REFERENCES_FILE_PATH,\n",
3373
    "        num_samples=NUM_SAMPLES)"
3374
   ]
3375
  },
3376
  {
3377
   "cell_type": "code",
3378
   "execution_count": null,
3379
   "id": "60ec4e25-b11c-4ee1-a9a9-ce29eb6dc81e",
3380
   "metadata": {
3381
    "tags": []
3382
   },
3383
   "outputs": [
3384
    {
3385
     "name": "stdout",
3386
     "output_type": "stream",
3387
     "text": [
3388
      "gte-base\n",
3389
      "  retrieval score: 0.7570621468926554\n",
3390
      "  quality score: 3.9322033898305087\n",
3391
      "\n",
3392
      "gte-large\n",
3393
      "  retrieval score: 0.7796610169491526\n",
3394
      "  quality score: 3.9350282485875705\n",
3395
      "\n",
3396
      "bge-large-en\n",
3397
      "  retrieval score: 0.4745762711864407\n",
3398
      "  quality score: 3.480225988700565\n",
3399
      "\n",
3400
      "text-embedding-ada-002\n",
3401
      "  retrieval score: 0.6497175141242938\n",
3402
      "  quality score: 3.5395480225988702\n",
3403
      "\n"
3404
     ]
3405
    },
3406
    {
3407
     "data": {
3408
      "image/png": "",
3409
      "text/plain": [
3410
       "<Figure size 1000x300 with 1 Axes>"
3411
      ]
3412
     },
3413
     "metadata": {},
3414
     "output_type": "display_data"
3415
    }
3416
   ],
3417
   "source": [
3418
    "scores = {}\n",
3419
    "for experiment_name in experiment_names:\n",
3420
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)\n",
3421
    "plot_scores(scores=scores)"
3422
   ]
3423
  },
3424
  {
3425
   "cell_type": "markdown",
3426
   "id": "7fd24b35-1db3-4326-ab0c-c4b484fb5aea",
3427
   "metadata": {},
3428
   "source": [
3429
    "This is an interesting outcome because the #1 (`BAAI/bge-large-en`) on the current [leaderboard](https://huggingface.co/spaces/mteb/leaderboard) isn't necessarily the best for our specific task. Using the smaller `thenlper/gte-large` produced the best retrieval and quality scores in our experiments."
3430
   ]
3431
  },
3432
  {
3433
   "cell_type": "code",
3434
   "execution_count": null,
3435
   "id": "845ad771-65e1-44cf-813f-3aa167c07e31",
3436
   "metadata": {
3437
    "tags": []
3438
   },
3439
   "outputs": [],
3440
   "source": [
3441
    "EMBEDDING_MODEL_NAME = \"thenlper/gte-large\""
3442
   ]
3443
  },
3444
  {
3445
   "cell_type": "markdown",
3446
   "id": "8b21b32f-bacb-4703-b16c-d4a7014779dc",
3447
   "metadata": {},
3448
   "source": [
3449
    "### OSS vs. closed LLMs"
3450
   ]
3451
  },
3452
  {
3453
   "cell_type": "markdown",
3454
   "id": "f393785f-17da-45eb-bf69-1483f74a370e",
3455
   "metadata": {},
3456
   "source": [
3457
    "We're now going to use the best configurations from above to evaluate different choices for the main LLM.\n",
3458
    "\n",
3459
    "**Note**:\n",
    "- We've been using one specific LLM so far to decide on the configuration, so that LLM's performance here will be a bit biased.\n",
3461
    "- This list is not exhaustive and even for the LLMs we use, there are versions with longer context windows available."
3462
   ]
3463
  },
3464
  {
3465
   "cell_type": "code",
3466
   "execution_count": null,
3467
   "id": "a91c87cb-ba0d-4044-9616-b2cbad239587",
3468
   "metadata": {
3469
    "tags": []
3470
   },
3471
   "outputs": [],
3472
   "source": [
3473
    "experiment_names = []\n",
3474
    "llms = [\"gpt-3.5-turbo\",\n",
3475
    "        \"gpt-4\",\n",
3476
    "        \"gpt-4-1106-preview\",\n",
3477
    "        \"meta-llama/Llama-2-7b-chat-hf\", \n",
3478
    "        \"meta-llama/Llama-2-13b-chat-hf\", \n",
3479
    "        \"meta-llama/Llama-2-70b-chat-hf\",\n",
3480
    "        \"codellama/CodeLlama-34b-Instruct-hf\",\n",
3481
    "        \"mistralai/Mistral-7B-Instruct-v0.1\",\n",
3482
    "        \"mistralai/Mixtral-8x7B-Instruct-v0.1\"]"
3483
   ]
3484
  },
3485
  {
3486
   "cell_type": "code",
3487
   "execution_count": null,
3488
   "id": "a60c0aed-86da-4529-96be-f10088260363",
3489
   "metadata": {},
3490
   "outputs": [],
3491
   "source": [
3492
    "# Some LLMs have much shorter context lengths (ex Llama-2 has 4096) so we'll restrict the number of chunks to a minimum.\n",
3493
    "minimum_num_chunks = 9"
3494
   ]
3495
  },
3496
  {
3497
   "cell_type": "code",
3498
   "execution_count": null,
3499
   "id": "072e8114-23ba-402e-a03d-594089e9b4d5",
3500
   "metadata": {
3501
    "tags": []
3502
   },
3503
   "outputs": [
3504
    {
3505
     "data": {
3506
      "application/json": {
3507
       "generated_answer": "To kill a specific serve replica in Ray Serve, you can follow these steps:\n\n1. Obtain a handle to the replica actor using `ray.get_actor()` with the replica's unique name and the namespace it belongs to.\n2. Use `ray.kill()` to kill the replica actor, passing the handle obtained in the previous step. You can set `no_restart=True` if you do not want the replica to be automatically restarted.\n\nHere is an example code snippet based on the context provided:\n\n```python\nimport ray\n\n# Get a handle to the specific serve replica actor\nreplica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#RlRptP\", namespace=\"serve\")\n\n# Kill the replica actor without restarting it\nray.kill(replica_handle, no_restart=True)\n```\n\nAfter running this code, the specific serve replica will be killed, and if `no_restart` is set to `True`, it will not be restarted. Other replicas will continue processing requests.",
3508
       "question": "how do I kill a specific serve replica",
3509
       "reasoning": "The generated answer provides a detailed and technically accurate response to the query, including a step-by-step guide and a code snippet. However, it does not perfectly match the reference answer, which states that the context does not provide information on how to kill a specific serve replica. Despite this, the generated answer is still highly informative and useful.",
3510
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
3511
       "score": 4,
3512
       "sources": [
3513
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
3514
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
3515
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
3516
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-deployments",
3517
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
3518
        "https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#autoscaling",
3519
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#worker-node-failure",
3520
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
3521
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-applications"
3522
       ]
3523
      },
3524
      "text/plain": [
3525
       "<IPython.core.display.JSON object>"
3526
      ]
3527
     },
3528
     "metadata": {
3529
      "application/json": {
3530
       "expanded": false,
3531
       "root": "root"
3532
      }
3533
     },
3534
     "output_type": "display_data"
3535
    },
3536
    {
3537
     "name": "stderr",
3538
     "output_type": "stream",
3539
     "text": [
3540
      "100%|██████████| 177/177 [13:18<00:00,  4.51s/it]\n"
3541
     ]
3542
    }
3543
   ],
3544
   "source": [
3545
    "for llm in llms:\n",
3546
    "    experiment_name = f\"{llm.split('/')[-1].lower()}\"\n",
3547
    "    experiment_names.append(experiment_name)\n",
3548
    "    run_experiment(\n",
3549
    "        experiment_name=experiment_name, \n",
3550
    "        chunk_size=CHUNK_SIZE, \n",
3551
    "        chunk_overlap=CHUNK_OVERLAP, \n",
3552
    "        num_chunks=minimum_num_chunks,\n",
3553
    "        embedding_model_name=EMBEDDING_MODEL_NAME,\n",
3554
    "        embedding_dim=EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME],\n",
3555
    "        llm=llm,\n",
3556
    "        evaluator=EVALUATOR,\n",
3557
    "        docs_dir=DOCS_DIR, \n",
3558
    "        experiments_dir=EXPERIMENTS_DIR, \n",
3559
    "        references_fp=REFERENCES_FILE_PATH,\n",
3560
    "        num_samples=NUM_SAMPLES)"
3561
   ]
3562
  },
3563
  {
3564
   "cell_type": "code",
3565
   "execution_count": null,
3566
   "id": "a8afbfe1-b129-408a-8ce0-8f321c585174",
3567
   "metadata": {
3568
    "tags": []
3569
   },
3570
   "outputs": [
3571
    {
3572
     "name": "stdout",
3573
     "output_type": "stream",
3574
     "text": [
3575
      "gpt-3.5-turbo\n",
3576
      "  retrieval score: 0.7288135593220338\n",
3577
      "  quality score: 3.559322033898305\n",
3578
      "\n",
3579
      "gpt-4\n",
3580
      "  retrieval score: 0.7288135593220338\n",
3581
      "  quality score: 3.8728813559322033\n",
3582
      "\n",
3583
      "gpt-4-1106-preview\n",
3584
      "  retrieval score: 0.7288135593220338\n",
3585
      "  quality score: 4.209039548022599\n",
3586
      "\n",
3587
      "llama-2-7b-chat-hf\n",
3588
      "  retrieval score: 0.7288135593220338\n",
3589
      "  quality score: 3.2966101694915255\n",
3590
      "\n",
3591
      "llama-2-13b-chat-hf\n",
3592
      "  retrieval score: 0.7288135593220338\n",
3593
      "  quality score: 3.4152542372881354\n",
3594
      "\n",
3595
      "llama-2-70b-chat-hf\n",
3596
      "  retrieval score: 0.7288135593220338\n",
3597
      "  quality score: 3.598870056497175\n",
3598
      "\n",
3599
      "codellama-34b-instruct-hf\n",
3600
      "  retrieval score: 0.7288135593220338\n",
3601
      "  quality score: 3.593220338983051\n",
3602
      "\n",
3603
      "mistral-7b-instruct-v0.1\n",
3604
      "  retrieval score: 0.7288135593220338\n",
3605
      "  quality score: 3.440677966101695\n",
3606
      "\n",
3607
      "mixtral-8x7b-instruct-v0.1\n",
3608
      "  retrieval score: 0.7288135593220338\n",
3609
      "  quality score: 3.943502824858757\n",
3610
      "\n"
3611
     ]
3612
    }
3613
   ],
3614
   "source": [
3615
    "scores = {}\n",
3616
    "for experiment_name in experiment_names:\n",
3617
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)"
3618
   ]
3619
  },
3620
  {
3621
   "cell_type": "markdown",
3622
   "id": "9add6de4-bc41-4f94-ba5a-edc3d81309a4",
3623
   "metadata": {
3624
    "tags": []
3625
   },
3626
   "source": [
3627
    "**Sanity check**: the retrieval scores are all the same because the LLM we choose doesn’t impact that part of our application."
3628
   ]
3629
  },
3630
  {
3631
   "cell_type": "markdown",
3632
   "id": "97d1cd68-f77a-4f13-9454-3add1fc65158",
3633
   "metadata": {},
3634
   "source": [
3635
    "`mixtral-8x7b-instruct-v0.1` outperforms the other OSS LLMs and even the current `gpt-4` (currently 0613) and not too far behind `gpt-4-turbo` (currently 1106-preview)."
3636
   ]
3637
  },
3638
  {
3639
   "cell_type": "code",
3640
   "execution_count": null,
3641
   "id": "0275cb71-6876-404a-bdbe-f79347162696",
3642
   "metadata": {
3643
    "tags": []
3644
   },
3645
   "outputs": [],
3646
   "source": [
3647
    "LLM = \"mistralai/Mixtral-8x7B-Instruct-v0.1\""
3648
   ]
3649
  },
3650
  {
3651
   "cell_type": "markdown",
3652
   "id": "5ab5aca0-212b-4e82-97af-1ed95130ebc6",
3653
   "metadata": {},
3654
   "source": [
3655
    "**Note**: Some of our LLMs have much larger context lengths, ex. `gpt-4` is 8192 tokens and `gpt-3.5-turbo-16k` is 16,384 tokens. We could increase the number of chunks that we use for these since we saw that increasing `num_chunks` continued to improve the retrieval and quality scores. However, we will keep this value fixed for now since the performance started to taper off anyway and so we can compare these performances under the exact same configurations. "
3656
   ]
3657
  },
3658
  {
3659
   "cell_type": "markdown",
3660
   "id": "e4b0a5cc-15e7-486b-86d2-4165841da03a",
3661
   "metadata": {},
3662
   "source": [
3663
    "### MoEs without context"
3664
   ]
3665
  },
3666
  {
3667
   "cell_type": "markdown",
3668
   "id": "09ff5ffe-dce4-4bc6-93db-f4c924c23100",
3669
   "metadata": {
3670
    "tags": []
3671
   },
3672
   "source": [
3673
    "Curious how well these mixture of experts (MoE) fare without any context."
3674
   ]
3675
  },
3676
  {
3677
   "cell_type": "code",
3678
   "execution_count": null,
3679
   "id": "7d0b694d-fd77-4b2b-8c6d-55f46307aea3",
3680
   "metadata": {
3681
    "tags": []
3682
   },
3683
   "outputs": [],
3684
   "source": [
3685
    "experiment_names = []\n",
3686
    "moes = [\"gpt-4\",\n",
3687
    "        \"gpt-4-1106-preview\",\n",
3688
    "        \"mistralai/Mixtral-8x7B-Instruct-v0.1\"]"
3689
   ]
3690
  },
3691
  {
3692
   "cell_type": "code",
3693
   "execution_count": null,
3694
   "id": "6bcf4b91-c5a3-4465-b427-5b2720f72f69",
3695
   "metadata": {
3696
    "tags": []
3697
   },
3698
   "outputs": [
3699
    {
3700
     "data": {
3701
      "application/json": {
3702
       "generated_answer": " I'm sorry for the confusion, but I need more context to provide a helpful answer. However, if you're asking about killing a specific replica of a service in a general sense (not tied to a specific technology), you might want to follow these steps:\n\n1. Identify the replica you want to kill, based on its process ID (PID) or some other unique identifier.\n2. Use the `kill` command followed by the PID or identifier to stop the replica. For example, `kill -9 PID` will forcefully terminate the process.\n\nPlease note that this is a general approach, and the specific steps might vary depending on the technology or environment you're using. If you could provide more context, I'd be happy to give a more tailored answer.",
3703
       "question": "how do I kill a specific serve replica",
3704
       "reasoning": "The generated answer provides a general approach to killing a specific service replica, which is useful even though the context does not provide specific information. However, it does not perfectly match the reference answer, hence the score is not a perfect 5.",
3705
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
3706
       "score": 4,
3707
       "sources": []
3708
      },
3709
      "text/plain": [
3710
       "<IPython.core.display.JSON object>"
3711
      ]
3712
     },
3713
     "metadata": {
3714
      "application/json": {
3715
       "expanded": false,
3716
       "root": "root"
3717
      }
3718
     },
3719
     "output_type": "display_data"
3720
    },
3721
    {
3722
     "name": "stderr",
3723
     "output_type": "stream",
3724
     "text": [
3725
      "100%|██████████| 177/177 [12:13<00:00,  4.14s/it]\n"
3726
     ]
3727
    }
3728
   ],
3729
   "source": [
3730
    "for moe in moes:\n",
    "    experiment_name = f\"without-context-{moe.split('/')[-1].lower()}\"\n",
    "    experiment_names.append(experiment_name)\n",
    "    run_experiment(\n",
    "        experiment_name=experiment_name,\n",
    "        chunk_size=CHUNK_SIZE,\n",
    "        chunk_overlap=CHUNK_OVERLAP,\n",
    "        num_chunks=0,  # no retrieved context\n",
    "        embedding_model_name=EMBEDDING_MODEL_NAME,\n",
    "        embedding_dim=EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME],\n",
    "        llm=moe,\n",
    "        evaluator=EVALUATOR,\n",
    "        docs_dir=DOCS_DIR,\n",
    "        experiments_dir=EXPERIMENTS_DIR,\n",
    "        references_fp=REFERENCES_FILE_PATH,\n",
    "        num_samples=NUM_SAMPLES)"
3746
   ]
3747
  },
3748
  {
3749
   "cell_type": "code",
3750
   "execution_count": null,
3751
   "id": "955176f4-dde5-4a2b-ad28-b538e30de118",
3752
   "metadata": {
3753
    "tags": []
3754
   },
3755
   "outputs": [
3756
    {
3757
     "name": "stdout",
3758
     "output_type": "stream",
3759
     "text": [
3760
      "without-context-gpt-4\n",
3761
      "  retrieval score: 0.0\n",
3762
      "  quality score: 1.4887005649717515\n",
3763
      "\n",
3764
      "without-context-gpt-4-1106-preview\n",
3765
      "  retrieval score: 0.0\n",
3766
      "  quality score: 3.8163841807909606\n",
3767
      "\n",
3768
      "without-context-mixtral-8x7b-instruct-v0.1\n",
3769
      "  retrieval score: 0.0\n",
3770
      "  quality score: 3.189265536723164\n",
3771
      "\n"
3772
     ]
3773
    }
3774
   ],
3775
   "source": [
3776
    "scores = {}\n",
3777
    "for experiment_name in experiment_names:\n",
3778
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)"
3779
   ]
3780
  },
3781
  {
3782
   "cell_type": "markdown",
3783
   "id": "16420a7d-445a-48da-a4c1-5238c4061f6e",
3784
   "metadata": {
3785
    "tags": []
3786
   },
3787
   "source": [
3788
    "# Fine-tuning"
3789
   ]
3790
  },
3791
  {
3792
   "cell_type": "markdown",
3793
   "id": "285c9d07-1658-4779-9e13-801c0c77186a",
3794
   "metadata": {},
3795
   "source": [
3796
    "Everything we have explored so far involves optimizing how our data is preprocessed while using our models (embedding, LLM, etc.) as is. However, it's also worth exploring fine-tuning our models with data unique to our use case. This could help us better represent our data and ultimately increase our retrieval and quality scores. In this section, we're going to fine-tune our embedding model. The intuition here is that it may be worth learning a more contextual representation of our tokens than the default embedding models provide. This can be especially impactful if we have a lot of:\n",
    "- new tokens that the default tokenization process splits into subtokens, which loses the significance of the original token\n",
    "- existing tokens that have contextually different meanings in our use case\n",
    "\n",
    "<img width=\"800\" src=\"https://images.ctfassets.net/xjan103pcp94/4G5324lsDZwq0jES7uBH0l/a715cd50af7061e1b3c57ec3e8038f05/rag-based-llm-applications-finetune-embeddings.png\">\n",
    "\n",
    "When it comes to fine-tuning our embedding model, we will be exploring two approaches:\n",
    "- **full parameter**: including the embedding layer and all subsequent encoder layers (transformer blocks)\n",
    "- **embedding layer**: to better represent our unique subtokens and avoid overfitting (a version of a linear adapter)\n",
    "\n",
    "**Note**: we will not be exploring fine-tuning our LLM in this section because our previous [experiments](https://www.anyscale.com/blog/fine-tuning-llama-2-a-comprehensive-case-study-for-tailoring-models-to-unique-applications) ([LoRA vs. full parameter](https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2)) have shown that fine-tuning helps tremendously with [form, not facts](https://www.anyscale.com/blog/fine-tuning-is-for-form-not-facts), which in our case won't help much (compared to, for example, SQL generation). However, your use cases might benefit from fine-tuning, so be sure to check out [Anyscale Endpoints fine-tuning](https://www.anyscale.com/endpoints) to easily tune and serve models (fully hosted or private on your cloud)."
3807
   ]
3808
  },
3809
  {
3810
   "cell_type": "markdown",
3811
   "id": "56747523-2ef7-489e-bba3-089a29ca25c5",
3812
   "metadata": {},
3813
   "source": [
3814
    "## Synthetic dataset"
3815
   ]
3816
  },
3817
  {
3818
   "cell_type": "markdown",
3819
   "id": "e3f35bef-5f68-4483-86d9-413ec4d5ea8f",
3820
   "metadata": {},
3821
   "source": [
3822
    "Our first step will be to create a dataset to fine-tune our embedding model on. Our current embedding models have been trained via self-supervised learning (word2vec, GloVe, next/masked token prediction, etc.) and so we will continue fine-tuning with a self-supervised workflow. We're going to reuse an approach very similar to the one in our cold start QA dataset section earlier so that we can map sections in our data to questions. The fine-tuning task here will be for the model to determine which sections in our dataset map best to the input query. This optimization task will allow our embedding model to learn better representations of tokens in our dataset.\n",
3823
    "\n",
3824
    "**Note**: While we could create a dataset mapping section titles with section text, we are creating a synthetic Q&A dataset because it will be most representative of the types of data we want to learn how to embed."
3825
   ]
3826
  },
3827
  {
3828
   "cell_type": "markdown",
3829
   "id": "5bb8c5c1-3bab-4539-9b79-cb9931a57c15",
3830
   "metadata": {},
3831
   "source": [
3832
    "Our prompt is going to be a bit different because we want to generate a variety of questions, and we're going to use `llama-70b` (`meta-llama/Llama-2-70b-chat-hf`) here so that we can scale this QA generation process (and avoid any rate limits). To be thorough, we're going to generate one question from every section in our dataset so that we can try to capture as many unique tokens as possible."
3833
   ]
3834
  },
3835
  {
3836
   "cell_type": "code",
3837
   "execution_count": null,
3838
   "id": "4de030ee-22bc-45ba-a6ad-2f82f6e04434",
3839
   "metadata": {
3840
    "tags": []
3841
   },
3842
   "outputs": [],
3843
   "source": [
3844
    "system_content = f\"\"\"\n",
3845
    "Create one question using only the context provided starting with \"What\", \"How\" or \"Why\".\n",
3846
    "Only respond with the question, don't say anything else (unnecessary starting words, hints, etc.)\n",
3847
    "\"\"\""
3848
   ]
3849
  },
3850
  {
3851
   "cell_type": "code",
3852
   "execution_count": null,
3853
   "id": "0224a7ef-25ab-4d81-a6e8-a1532822e0d7",
3854
   "metadata": {
3855
    "tags": []
3856
   },
3857
   "outputs": [],
3858
   "source": [
3859
    "# Generate questions\n",
3860
    "embedding_qa = []\n",
3861
    "sections = sections_ds.take_all()\n",
3862
    "max_context_length = int(0.5*MAX_CONTEXT_LENGTHS[LLM]-get_num_tokens(system_content))\n",
3863
    "for section in tqdm(sections):\n",
3864
    "    user_content = trim(\n",
3865
    "        text=f\"context: {section['text']}\", \n",
3866
    "        max_context_length=max_context_length)\n",
3867
    "    response = generate_response(\n",
3868
    "        llm=\"meta-llama/Llama-2-70b-chat-hf\",\n",
3869
    "        temperature=0.0,\n",
3870
    "        stream=False,\n",
3871
    "        system_content=system_content,\n",
3872
    "        user_content=user_content,\n",
3873
    "        max_retries=1)\n",
3874
    "    if response:\n",
3875
    "        embedding_qa.append({\"question\": response, \"source\": section[\"source\"]})\n",
3876
    "print (len(embedding_qa))"
3877
   ]
3878
  },
3879
  {
3880
   "cell_type": "code",
3881
   "execution_count": null,
3882
   "id": "65ed5a1c-a55f-426d-a839-04385bd9268d",
3883
   "metadata": {
3884
    "tags": []
3885
   },
3886
   "outputs": [],
3887
   "source": [
3888
    "# Path\n",
3889
    "EMBEDDING_QA_FILE_PATH = Path(ROOT_DIR, \"datasets\", \"embedding_qa.json\")\n",
3890
    "EMBEDDING_QA_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)"
3891
   ]
3892
  },
3893
  {
3894
   "cell_type": "code",
3895
   "execution_count": null,
3896
   "id": "98298e4e-b681-45eb-8ea4-37d66cdf8685",
3897
   "metadata": {
3898
    "tags": []
3899
   },
3900
   "outputs": [],
3901
   "source": [
3902
    "# Save to file\n",
3903
    "with open(EMBEDDING_QA_FILE_PATH, \"w\") as fp:\n",
3904
    "    json.dump(embedding_qa, fp, indent=4)"
3905
   ]
3906
  },
3907
  {
3908
   "cell_type": "markdown",
3909
   "id": "62ecc4c4-79da-4b9c-a9c3-7cdfb53615c3",
3910
   "metadata": {},
3911
   "source": [
3912
    "## Training data"
3913
   ]
3914
  },
3915
  {
3916
   "cell_type": "markdown",
3917
   "id": "3826c940-a40b-42ea-a50e-0297c78b6929",
3918
   "metadata": {},
3919
   "source": [
3920
    "We're now going to split our dataset into training and validation splits."
3921
   ]
3922
  },
3923
  {
3924
   "cell_type": "code",
3925
   "execution_count": null,
3926
   "id": "2c4a7e67-6f17-4607-ae27-ab3253c7ae33",
3927
   "metadata": {
3928
    "tags": []
3929
   },
3930
   "outputs": [],
3931
   "source": [
3932
    "from sentence_transformers import InputExample"
3933
   ]
3934
  },
3935
  {
3936
   "cell_type": "code",
3937
   "execution_count": null,
3938
   "id": "ba0b81ee-20eb-4721-af49-64544c36dc26",
3939
   "metadata": {
3940
    "tags": []
3941
   },
3942
   "outputs": [],
3943
   "source": [
3944
    "# Load from file\n",
3945
    "with open(EMBEDDING_QA_FILE_PATH, \"r\") as fp:\n",
3946
    "    embedding_qa = json.load(fp)"
3947
   ]
3948
  },
3949
  {
3950
   "cell_type": "code",
3951
   "execution_count": null,
3952
   "id": "c7fadbbf-28ba-4fcc-baa2-e9df78af3c67",
3953
   "metadata": {
3954
    "tags": []
3955
   },
3956
   "outputs": [],
3957
   "source": [
3958
    "# Split counts\n",
3959
    "num_train_samples = int(len(embedding_qa)*0.8)\n",
3960
    "emb_qa_train = embedding_qa[:num_train_samples]\n",
3961
    "emb_qa_val = embedding_qa[num_train_samples:]"
3962
   ]
3963
  },
3964
  {
3965
   "cell_type": "code",
3966
   "execution_count": null,
3967
   "id": "a4c6cad8-8859-4710-a01f-8aff5e8b923b",
3968
   "metadata": {
3969
    "tags": []
3970
   },
3971
   "outputs": [
3972
    {
3973
     "name": "stderr",
3974
     "output_type": "stream",
3975
     "text": [
3976
      "100%|██████████| 4581/4581 [03:44<00:00, 20.40it/s]\n"
3977
     ]
3978
    }
3979
   ],
3980
   "source": [
3981
    "# Training dataset\n",
3982
    "train_dataset = []\n",
3983
    "for item in tqdm(emb_qa_train):\n",
3984
    "    query = item[\"question\"]\n",
3985
    "    source_text = fetch_text(item[\"source\"])\n",
3986
    "    example = InputExample(texts=[query, source_text])\n",
3987
    "    train_dataset.append(example)"
3988
   ]
3989
  },
3990
  {
3991
   "cell_type": "markdown",
3992
   "id": "cfebe1c5-8a76-41df-9c16-aa9757e8e553",
3993
   "metadata": {},
3994
   "source": [
3995
    "## Validation"
3996
   ]
3997
  },
3998
  {
3999
   "cell_type": "markdown",
4000
   "id": "4cf33abf-6031-4829-8121-8dbde6b84570",
4001
   "metadata": {},
4002
   "source": [
4003
    "Our validation relies on an information retrieval (IR) evaluator that retrieves the top-k most similar documents from the corpus for each query. The [InformationRetrievalEvaluator](https://www.sbert.net/docs/package_reference/evaluation.html#sentence_transformers.evaluation.InformationRetrievalEvaluator) requires the following inputs:\n",
4004
    "\n",
4005
    "- queries: `Dict[str, str]`  #  qid => query\n",
4006
    "- corpus: `Dict[str, str]`  #  cid => doc\n",
4007
    "- relevant_docs: `Dict[str, Set[str]]`  #  qid => Set[cid]\n",
4008
    "\n",
4009
    "**Note**: While our dataset may have multiple valid sections for a particular query, we will treat every section other than the one used to generate the query as a negative sample. This isn't ideal, but the noise introduced is minimal, especially since we are using this to tune a representation layer (and not for a classification task)."
4010
   ]
4011
  },
4012
  {
4013
   "cell_type": "code",
4014
   "execution_count": null,
4015
   "id": "fccb9bc9-bd51-4289-9576-c822a7d85ca6",
4016
   "metadata": {
4017
    "tags": []
4018
   },
4019
   "outputs": [],
4020
   "source": [
4021
    "from sentence_transformers.evaluation import InformationRetrievalEvaluator"
4022
   ]
4023
  },
4024
  {
4025
   "cell_type": "code",
4026
   "execution_count": null,
4027
   "id": "678540b5-4f1c-441b-8a64-fb1860fd6c55",
4028
   "metadata": {
4029
    "tags": []
4030
   },
4031
   "outputs": [
4032
    {
4033
     "name": "stderr",
4034
     "output_type": "stream",
4035
     "text": [
4036
      "100%|██████████| 1146/1146 [01:31<00:00, 12.47it/s]\n"
4037
     ]
4038
    }
4039
   ],
4040
   "source": [
4041
    "# Validation dataset\n",
4042
    "queries, corpus, relevant_docs = {}, {}, {}\n",
4043
    "for i, item in tqdm(enumerate(emb_qa_val), total=len(emb_qa_val)):\n",
4044
    "    queries[f\"qid_{i}\"] = item[\"question\"]\n",
4045
    "    corpus[f\"cid_{i}\"] = fetch_text(item[\"source\"])\n",
4046
    "    relevant_docs[f\"qid_{i}\"] = {f\"cid_{i}\"}\n",
4047
    "evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)"
4048
   ]
4049
  },
4050
  {
4051
   "cell_type": "markdown",
4052
   "id": "3e0f8ada-a80d-480a-be2f-90a54e4a2f37",
4053
   "metadata": {},
4054
   "source": [
4055
    "We'll be using [MultipleNegativesRankingLoss](https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss) as our loss function. It will use the data points (`InputExample(texts=[query, source_text])`) in our training data as positive pairs and all other in-batch combinations as negative pairs. The objective is to increase the cosine similarity (the default `similarity_fct`) for each positive pair and decrease it for the other pairs."
4056
   ]
4057
  },
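  {
   "cell_type": "markdown",
   "id": "mnrl-sketch-note",
   "metadata": {},
   "source": [
    "To make the in-batch negatives idea concrete, the next cell is a minimal NumPy sketch of this kind of objective. It is an illustration under stated assumptions, not the `sentence-transformers` implementation: the helper name `mnr_loss_sketch` is hypothetical, the random arrays are stand-ins for real query and section embeddings, and the `scale` of 20 reflects our understanding of the library's default. Each row of the similarity matrix is treated as a classification problem where the matching section (the diagonal entry) is the correct class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mnrl-sketch-code",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def mnr_loss_sketch(query_emb, passage_emb, scale=20.0):\n",
    "    # Hypothetical helper for illustration only (not the library API)\n",
    "    # Normalize rows so that dot products equal cosine similarities\n",
    "    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)\n",
    "    p = passage_emb / np.linalg.norm(passage_emb, axis=1, keepdims=True)\n",
    "    scores = scale * (q @ p.T)  # (batch, batch): query i vs. passage j\n",
    "    # Row-wise log-softmax (subtract the row max for numerical stability)\n",
    "    scores = scores - scores.max(axis=1, keepdims=True)\n",
    "    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))\n",
    "    # The positive for query i is passage i (the diagonal); the rest are negatives\n",
    "    return float(-np.mean(np.diag(log_probs)))\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "queries = rng.normal(size=(4, 8))  # stand-in query embeddings\n",
    "aligned = queries + 0.01 * rng.normal(size=(4, 8))  # passages close to their queries\n",
    "shuffled = aligned[::-1]  # passages paired with the wrong queries\n",
    "# Correctly paired data should yield the lower loss\n",
    "print (mnr_loss_sketch(queries, aligned) < mnr_loss_sketch(queries, shuffled))"
   ]
  },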
4058
  {
4059
   "cell_type": "code",
4060
   "execution_count": null,
4061
   "id": "9810f112-b3f5-4767-ba8b-d4b556b51db8",
4062
   "metadata": {
4063
    "tags": []
4064
   },
4065
   "outputs": [],
4066
   "source": [
4067
    "# Custom callback to view validation performance\n",
4068
    "def val_callback(score, epoch, steps):\n",
4069
    "    print (f\"EPOCH: {epoch}, VAL SCORE:{score:.4f}\\n\")"
4070
   ]
4071
  },
4072
  {
4073
   "cell_type": "markdown",
4074
   "id": "57a1d156-c6b6-4488-afd7-ec7fe6add16b",
4075
   "metadata": {
4076
    "tags": []
4077
   },
4078
   "source": [
4079
    "## Embedding model"
4080
   ]
4081
  },
4082
  {
4083
   "cell_type": "markdown",
4084
   "id": "dfcf2c7d-20b0-4363-b8e5-a0b314a9c95f",
4085
   "metadata": {},
4086
   "source": [
4087
    "Now we're ready to initialize our embedding model for fine-tuning."
4088
   ]
4089
  },
4090
  {
4091
   "cell_type": "code",
4092
   "execution_count": null,
4093
   "id": "0141b57e-72e7-42f4-8ab9-9ebd1c401de2",
4094
   "metadata": {
4095
    "tags": []
4096
   },
4097
   "outputs": [],
4098
   "source": [
4099
    "from sentence_transformers import SentenceTransformer"
4100
   ]
4101
  },
4102
  {
4103
   "cell_type": "code",
4104
   "execution_count": null,
4105
   "id": "3291d4e9-aed7-44f3-9742-6270682e9025",
4106
   "metadata": {
4107
    "tags": []
4108
   },
4109
   "outputs": [
4110
    {
4111
     "data": {
4112
      "application/vnd.jupyter.widget-view+json": {
4113
       "model_id": "32dd444af3f1402dbcf6daea5e405744",
4114
       "version_major": 2,
4115
       "version_minor": 0
4116
      },
4117
      "text/plain": [
4118
       "Downloading .gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]"
4119
      ]
4120
     },
4121
     "metadata": {},
4122
     "output_type": "display_data"
4123
    },
4124
    {
4125
     "data": {
4126
      "application/vnd.jupyter.widget-view+json": {
4127
       "model_id": "9d0efc4247e741fb8effa7073c8bae21",
4128
       "version_major": 2,
4129
       "version_minor": 0
4130
      },
4131
      "text/plain": [
4132
       "Downloading 1_Pooling/config.json:   0%|          | 0.00/191 [00:00<?, ?B/s]"
4133
      ]
4134
     },
4135
     "metadata": {},
4136
     "output_type": "display_data"
4137
    },
4138
    {
4139
     "data": {
4140
      "application/vnd.jupyter.widget-view+json": {
4141
       "model_id": "6d2153ba02c94f0eaaaf9034ea394568",
4142
       "version_major": 2,
4143
       "version_minor": 0
4144
      },
4145
      "text/plain": [
4146
       "Downloading README.md:   0%|          | 0.00/67.9k [00:00<?, ?B/s]"
4147
      ]
4148
     },
4149
     "metadata": {},
4150
     "output_type": "display_data"
4151
    },
4152
    {
4153
     "data": {
4154
      "application/vnd.jupyter.widget-view+json": {
4155
       "model_id": "ee9b41bbfc4343c48ebdd05e51b6cb57",
4156
       "version_major": 2,
4157
       "version_minor": 0
4158
      },
4159
      "text/plain": [
4160
       "Downloading config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]"
4161
      ]
4162
     },
4163
     "metadata": {},
4164
     "output_type": "display_data"
4165
    },
4166
    {
4167
     "data": {
4168
      "application/vnd.jupyter.widget-view+json": {
4169
       "model_id": "1bb6ee29368f47148f91343d9ac4185a",
4170
       "version_major": 2,
4171
       "version_minor": 0
4172
      },
4173
      "text/plain": [
4174
       "Downloading model.safetensors:   0%|          | 0.00/670M [00:00<?, ?B/s]"
4175
      ]
4176
     },
4177
     "metadata": {},
4178
     "output_type": "display_data"
4179
    },
4180
    {
4181
     "data": {
4182
      "application/vnd.jupyter.widget-view+json": {
4183
       "model_id": "66467ecf19ea42bea6224261fd2d6845",
4184
       "version_major": 2,
4185
       "version_minor": 0
4186
      },
4187
      "text/plain": [
4188
       "Downloading onnx/config.json:   0%|          | 0.00/632 [00:00<?, ?B/s]"
4189
      ]
4190
     },
4191
     "metadata": {},
4192
     "output_type": "display_data"
4193
    },
4194
    {
4195
     "data": {
4196
      "application/vnd.jupyter.widget-view+json": {
4197
       "model_id": "4c188d54748a4713b4ee5f8807819aaa",
4198
       "version_major": 2,
4199
       "version_minor": 0
4200
      },
4201
      "text/plain": [
4202
       "Downloading model.onnx:   0%|          | 0.00/1.34G [00:00<?, ?B/s]"
4203
      ]
4204
     },
4205
     "metadata": {},
4206
     "output_type": "display_data"
4207
    },
4208
    {
4209
     "data": {
4210
      "application/vnd.jupyter.widget-view+json": {
4211
       "model_id": "ce84db913b0c469ea00412d7c9305913",
4212
       "version_major": 2,
4213
       "version_minor": 0
4214
      },
4215
      "text/plain": [
4216
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]"
4217
      ]
4218
     },
4219
     "metadata": {},
4220
     "output_type": "display_data"
4221
    },
4222
    {
4223
     "data": {
4224
      "application/vnd.jupyter.widget-view+json": {
4225
       "model_id": "f09ef9b2be7f44c38f0b3488e4b42cbf",
4226
       "version_major": 2,
4227
       "version_minor": 0
4228
      },
4229
      "text/plain": [
4230
       "Downloading onnx/tokenizer.json:   0%|          | 0.00/712k [00:00<?, ?B/s]"
4231
      ]
4232
     },
4233
     "metadata": {},
4234
     "output_type": "display_data"
4235
    },
4236
    {
4237
     "data": {
4238
      "application/vnd.jupyter.widget-view+json": {
4239
       "model_id": "212f898a4db2400b8dab7e50a801a478",
4240
       "version_major": 2,
4241
       "version_minor": 0
4242
      },
4243
      "text/plain": [
4244
       "Downloading (…)okenizer_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]"
4245
      ]
4246
     },
4247
     "metadata": {},
4248
     "output_type": "display_data"
4249
    },
4250
    {
4251
     "data": {
4252
      "application/vnd.jupyter.widget-view+json": {
4253
       "model_id": "27044bdd83d9413c9b15a0bb3297f0d5",
4254
       "version_major": 2,
4255
       "version_minor": 0
4256
      },
4257
      "text/plain": [
4258
       "Downloading onnx/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
4259
      ]
4260
     },
4261
     "metadata": {},
4262
     "output_type": "display_data"
4263
    },
4264
    {
4265
     "data": {
4266
      "application/vnd.jupyter.widget-view+json": {
4267
       "model_id": "807c16faf8c9458490527ebeddc94375",
4268
       "version_major": 2,
4269
       "version_minor": 0
4270
      },
4271
      "text/plain": [
4272
       "Downloading pytorch_model.bin:   0%|          | 0.00/670M [00:00<?, ?B/s]"
4273
      ]
4274
     },
4275
     "metadata": {},
4276
     "output_type": "display_data"
4277
    },
4278
    {
4279
     "data": {
4280
      "application/vnd.jupyter.widget-view+json": {
4281
       "model_id": "b04454bb93bf48aaa3ed7ffcb688cb6b",
4282
       "version_major": 2,
4283
       "version_minor": 0
4284
      },
4285
      "text/plain": [
4286
       "Downloading (…)nce_bert_config.json:   0%|          | 0.00/57.0 [00:00<?, ?B/s]"
4287
      ]
4288
     },
4289
     "metadata": {},
4290
     "output_type": "display_data"
4291
    },
4292
    {
4293
     "data": {
4294
      "application/vnd.jupyter.widget-view+json": {
4295
       "model_id": "07a3eda52cfc44e88b2cdd2b960a81a8",
4296
       "version_major": 2,
4297
       "version_minor": 0
4298
      },
4299
      "text/plain": [
4300
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]"
4301
      ]
4302
     },
4303
     "metadata": {},
4304
     "output_type": "display_data"
4305
    },
4306
    {
4307
     "data": {
4308
      "application/vnd.jupyter.widget-view+json": {
4309
       "model_id": "0c82adb6d6f14ce8a63b9f7d342f0ea7",
4310
       "version_major": 2,
4311
       "version_minor": 0
4312
      },
4313
      "text/plain": [
4314
       "Downloading tokenizer.json:   0%|          | 0.00/712k [00:00<?, ?B/s]"
4315
      ]
4316
     },
4317
     "metadata": {},
4318
     "output_type": "display_data"
4319
    },
4320
    {
4321
     "data": {
4322
      "application/vnd.jupyter.widget-view+json": {
4323
       "model_id": "f0bb5b2ef2ca43f6b0e0bb8dc715c0eb",
4324
       "version_major": 2,
4325
       "version_minor": 0
4326
      },
4327
      "text/plain": [
4328
       "Downloading tokenizer_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]"
4329
      ]
4330
     },
4331
     "metadata": {},
4332
     "output_type": "display_data"
4333
    },
4334
    {
4335
     "data": {
4336
      "application/vnd.jupyter.widget-view+json": {
4337
       "model_id": "67b7842417684fd5bb6e609b0df3faa5",
4338
       "version_major": 2,
4339
       "version_minor": 0
4340
      },
4341
      "text/plain": [
4342
       "Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
4343
      ]
4344
     },
4345
     "metadata": {},
4346
     "output_type": "display_data"
4347
    },
4348
    {
4349
     "data": {
4350
      "application/vnd.jupyter.widget-view+json": {
4351
       "model_id": "ce53f891700e496f954de33851869bcc",
4352
       "version_major": 2,
4353
       "version_minor": 0
4354
      },
4355
      "text/plain": [
4356
       "Downloading modules.json:   0%|          | 0.00/385 [00:00<?, ?B/s]"
4357
      ]
4358
     },
4359
     "metadata": {},
4360
     "output_type": "display_data"
4361
    },
4362
    {
4363
     "data": {
4364
      "text/plain": [
4365
       "SentenceTransformer(\n",
4366
       "  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel \n",
4367
       "  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
4368
       "  (2): Normalize()\n",
4369
       ")"
4370
      ]
4371
     },
4372
     "execution_count": null,
4373
     "metadata": {},
4374
     "output_type": "execute_result"
4375
    }
4376
   ],
4377
   "source": [
4378
    "embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)\n",
4379
    "embedding_model"
4380
   ]
4381
  },
4382
  {
4383
   "cell_type": "markdown",
4384
   "id": "4a64dc32-6dbc-4f1e-b950-d4ffa797912c",
4385
   "metadata": {
4386
    "tags": []
4387
   },
4388
   "source": [
4389
    "## Resize tokenizer"
4390
   ]
4391
  },
4392
  {
4393
   "cell_type": "markdown",
4394
   "id": "871a9c3a-c60f-4ff0-9d1a-3025fdad56be",
4395
   "metadata": {},
4396
   "source": [
4397
    "While our tokenizer can represent new words using subtokens that are already part of the vocabulary, it might be very helpful to explicitly add new tokens to the tokenizer of our base model (BertModel, in our case). We can then use [resize_token_embeddings](https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings) to adjust the model's embedding layer prior to fine-tuning. This can be very useful for contextual use cases, especially if many tokens are new or existing tokens have a very different meaning in our context."
4398
   ]
4399
  },
4400
  {
4401
   "cell_type": "code",
4402
   "execution_count": null,
4403
   "id": "fb431961-ee8f-46ca-9325-96eca41d0da0",
4404
   "metadata": {
4405
    "tags": []
4406
   },
4407
   "outputs": [],
4408
   "source": [
4409
    "import re"
4410
   ]
4411
  },
4412
  {
4413
   "cell_type": "code",
4414
   "execution_count": null,
4415
   "id": "806dbfcf-bb69-4d38-9fc6-afa6c3f62694",
4416
   "metadata": {
4417
    "tags": []
4418
   },
4419
   "outputs": [],
4420
   "source": [
4421
    "def get_unique_words(texts):\n",
4422
    "    all_text = \" \".join(texts)  # join all texts\n",
4423
    "    all_text = all_text.replace(\"_\", \" \")  # replace underscores (ex. variable names)\n",
4424
    "    words = re.findall(r'\\b[a-zA-Z]+\\b', all_text)  # only letters\n",
4425
    "    words = [word.lower() for word in words]  # lower\n",
4426
    "    return set(words)"
4427
   ]
4428
  },
4429
  {
4430
   "cell_type": "code",
4431
   "execution_count": null,
4432
   "id": "f85c5818-3526-4171-b4f8-4e529b81cc9b",
4433
   "metadata": {
4434
    "tags": []
4435
   },
4436
   "outputs": [
4437
    {
4438
     "name": "stderr",
4439
     "output_type": "stream",
4440
     "text": [
4441
      "100%|██████████| 11204/11204 [00:00<00:00, 1936098.47it/s]\n"
4442
     ]
4443
    }
4444
   ],
4445
   "source": [
4446
    "# Get tokens that are OOV (out of vocabulary)\n",
4447
    "new_words = []\n",
4448
    "vocab = embedding_model.tokenizer.get_vocab().keys()\n",
4449
    "texts = [section[\"text\"] for section in sections_ds.take_all()]\n",
4450
    "unique_words = get_unique_words(texts=texts)\n",
4451
    "for word in tqdm(unique_words):\n",
4452
    "    if word not in vocab:\n",
4453
    "        new_words.append(word)"
4454
   ]
4455
  },
4456
  {
4457
   "cell_type": "code",
4458
   "execution_count": null,
4459
   "id": "a0ee4cfe-7989-4957-9fbe-6fd1360ab504",
4460
   "metadata": {
4461
    "tags": []
4462
   },
4463
   "outputs": [
4464
    {
4465
     "name": "stdout",
4466
     "output_type": "stream",
4467
     "text": [
4468
      "5790\n",
4469
      "['estimator', 'txt', 'replicacontext', 'metricsexportport', 'memoryefficientattentionflashattentionop', 'voc', 'disjoint', 'custompredictor', 'ulimit', 'allvalues']\n"
4470
     ]
4471
    }
4472
   ],
4473
   "source": [
4474
    "# Inspect\n",
4475
    "print (len(new_words))\n",
4476
    "print (new_words[:10])"
4477
   ]
4478
  },
4479
  {
4480
   "cell_type": "code",
4481
   "execution_count": null,
4482
   "id": "30fc2d0c-1ecd-499f-aa06-918464a45b1e",
4483
   "metadata": {
4484
    "tags": []
4485
   },
4486
   "outputs": [
4487
    {
4488
     "name": "stdout",
4489
     "output_type": "stream",
4490
     "text": [
4491
      "30522\n",
4492
      "36312\n"
4493
     ]
4494
    }
4495
   ],
4496
   "source": [
4497
    "# Add new words to tokenizer\n",
4498
    "print (len(embedding_model.tokenizer))\n",
4499
    "embedding_model.tokenizer.add_tokens(new_words)\n",
4500
    "print (len(embedding_model.tokenizer))"
4501
   ]
4502
  },
4503
  {
4504
   "cell_type": "code",
4505
   "execution_count": null,
4506
   "id": "97dd942a-74d2-4650-a87a-f102b75ed206",
4507
   "metadata": {
4508
    "tags": []
4509
   },
4510
   "outputs": [
4511
    {
4512
     "name": "stdout",
4513
     "output_type": "stream",
4514
     "text": [
4515
      "Embedding(30522, 1024, padding_idx=0)\n",
4516
      "Embedding(36312, 1024, padding_idx=0)\n"
4517
     ]
4518
    }
4519
   ],
4520
   "source": [
4521
    "# Resize the embedding layer to match the tokenizer\n",
    "auto_model = embedding_model[0].auto_model  # underlying Hugging Face BertModel\n",
    "print (auto_model.embeddings.word_embeddings)\n",
    "auto_model.resize_token_embeddings(len(embedding_model.tokenizer))\n",
    "auto_model.embeddings.word_embeddings.padding_idx = 0\n",
    "print (auto_model.embeddings.word_embeddings)"
4526
   ]
4527
  },
4528
  {
4529
   "cell_type": "markdown",
4530
   "id": "cf482bee-b507-4166-bf53-60a50984c5d0",
4531
   "metadata": {},
4532
   "source": [
4533
    "## Full parameter"
4534
   ]
4535
  },
4536
  {
4537
   "cell_type": "markdown",
4538
   "id": "e47066d0-e7e2-47d8-8e55-6f295ece02a0",
4539
   "metadata": {},
4540
   "source": [
4541
    "Our full parameter fine-tuning approach will tune all of the following weights:"
4542
   ]
4543
  },
4544
  {
4545
   "cell_type": "code",
4546
   "execution_count": null,
4547
   "id": "aaa5da86-626c-4387-9a83-c73e7c51419d",
4548
   "metadata": {
4549
    "tags": []
4550
   },
4551
   "outputs": [
4552
    {
4553
     "data": {
4554
      "text/plain": [
4555
       "OrderedDict([('auto_model',\n",
4556
       "              BertModel(\n",
4557
       "                (embeddings): BertEmbeddings(\n",
4558
       "                  (word_embeddings): Embedding(36312, 1024, padding_idx=0)\n",
4559
       "                  (position_embeddings): Embedding(512, 1024)\n",
4560
       "                  (token_type_embeddings): Embedding(2, 1024)\n",
4561
       "                  (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
4562
       "                  (dropout): Dropout(p=0.1, inplace=False)\n",
4563
       "                )\n",
4564
       "                (encoder): BertEncoder(\n",
4565
       "                  (layer): ModuleList(\n",
4566
       "                    (0-23): 24 x BertLayer(\n",
4567
       "                      (attention): BertAttention(\n",
4568
       "                        (self): BertSelfAttention(\n",
4569
       "                          (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
4570
       "                          (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
4571
       "                          (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
4572
       "                          (dropout): Dropout(p=0.1, inplace=False)\n",
4573
       "                        )\n",
4574
       "                        (output): BertSelfOutput(\n",
4575
       "                          (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
4576
       "                          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
4577
       "                          (dropout): Dropout(p=0.1, inplace=False)\n",
4578
       "                        )\n",
4579
       "                      )\n",
4580
       "                      (intermediate): BertIntermediate(\n",
4581
       "                        (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
4582
       "                        (intermediate_act_fn): GELUActivation()\n",
4583
       "                      )\n",
4584
       "                      (output): BertOutput(\n",
4585
       "                        (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
4586
       "                        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
4587
       "                        (dropout): Dropout(p=0.1, inplace=False)\n",
4588
       "                      )\n",
4589
       "                    )\n",
4590
       "                  )\n",
4591
       "                )\n",
4592
       "                (pooler): BertPooler(\n",
4593
       "                  (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
4594
       "                  (activation): Tanh()\n",
4595
       "                )\n",
4596
       "              ))])"
4597
      ]
4598
     },
4599
     "execution_count": null,
4600
     "metadata": {},
4601
     "output_type": "execute_result"
4602
    }
4603
   ],
4604
   "source": [
4605
    "embedding_model._modules[\"0\"]._modules"
4606
   ]
4607
  },
4608
  {
4609
   "cell_type": "code",
4610
   "execution_count": null,
4611
   "id": "ea3b63f0-428b-4642-a0db-6e23939db67f",
4612
   "metadata": {
4613
    "tags": []
4614
   },
4615
   "outputs": [],
4616
   "source": [
4617
    "from sentence_transformers.losses import MultipleNegativesRankingLoss\n",
4618
    "from torch.utils.data import DataLoader"
4619
   ]
4620
  },
4621
  {
4622
   "cell_type": "code",
4623
   "execution_count": null,
4624
   "id": "dcdcf10f-71cd-4502-bd6d-d979e01691a9",
4625
   "metadata": {
4626
    "tags": []
4627
   },
4628
   "outputs": [],
4629
   "source": [
4630
    "# Training setup\n",
4631
    "num_epochs = 2\n",
4632
    "batch_size = 4\n",
4633
    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size)\n",
4634
    "loss = MultipleNegativesRankingLoss(embedding_model) # MNR Loss\n",
4635
    "warmup_steps = int(0.1 * num_epochs * len(train_dataloader))  # computed for reference only; fit() below is called with warmup_steps=0"
4636
   ]
4637
  },
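  {
   "cell_type": "markdown",
   "id": "3c9e1a52-7b04-4d1e-9f6a-2a8d5e0b7c11",
   "metadata": {},
   "source": [
    "`MultipleNegativesRankingLoss` treats every other positive in the batch as a negative for a given query. The sketch below illustrates the idea in plain Python, assuming `train_dataset` yields (query, positive) pairs and using the library's documented defaults (cosine similarity, `scale=20.0`). `cosine` and `mnr_loss` are illustrative helpers, not part of sentence-transformers, and the toy vectors are made up.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def cosine(u, v):\n",
    "    # Cosine similarity between two equal-length vectors.\n",
    "    dot = sum(a * b for a, b in zip(u, v))\n",
    "    norm_u = math.sqrt(sum(a * a for a in u))\n",
    "    norm_v = math.sqrt(sum(b * b for b in v))\n",
    "    return dot / (norm_u * norm_v)\n",
    "\n",
    "def mnr_loss(queries, positives, scale=20.0):\n",
    "    # Mean cross-entropy where the correct class for query i is positive i;\n",
    "    # all other positives in the batch act as negatives.\n",
    "    total = 0.0\n",
    "    for i, query in enumerate(queries):\n",
    "        logits = [scale * cosine(query, positive) for positive in positives]\n",
    "        max_logit = max(logits)\n",
    "        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))\n",
    "        total += log_sum_exp - logits[i]\n",
    "    return total / len(queries)\n",
    "\n",
    "# Toy 2-D embeddings: each query points roughly at its own positive.\n",
    "queries = [[1.0, 0.0], [0.0, 1.0]]\n",
    "aligned = [[0.9, 0.1], [0.1, 0.9]]\n",
    "swapped = [[0.1, 0.9], [0.9, 0.1]]\n",
    "loss_aligned = mnr_loss(queries, aligned)\n",
    "loss_swapped = mnr_loss(queries, swapped)\n",
    "print(loss_aligned < loss_swapped)\n",
    "```\n",
    "\n",
    "With `batch_size = 4`, each query is contrasted against only three in-batch negatives, so larger batches generally give this loss a stronger training signal."
   ]
  },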
4638
  {
4639
   "cell_type": "code",
4640
   "execution_count": null,
4641
   "id": "ecb202a6-d8ac-47d6-a505-7f6ed74d4d44",
4642
   "metadata": {
4643
    "tags": []
4644
   },
4645
   "outputs": [
4646
    {
4647
     "data": {
4648
      "application/vnd.jupyter.widget-view+json": {
4649
       "model_id": "0a1c1f00d4424acca9cebe8698797386",
4650
       "version_major": 2,
4651
       "version_minor": 0
4652
      },
4653
      "text/plain": [
4654
       "Epoch:   0%|          | 0/2 [00:00<?, ?it/s]"
4655
      ]
4656
     },
4657
     "metadata": {},
4658
     "output_type": "display_data"
4659
    },
4660
    {
4661
     "data": {
4662
      "application/vnd.jupyter.widget-view+json": {
4663
       "model_id": "112633295a3b467ba8dff62d0e04a1d9",
4664
       "version_major": 2,
4665
       "version_minor": 0
4666
      },
4667
      "text/plain": [
4668
       "Iteration:   0%|          | 0/1146 [00:00<?, ?it/s]"
4669
      ]
4670
     },
4671
     "metadata": {},
4672
     "output_type": "display_data"
4673
    },
4674
    {
4675
     "name": "stdout",
4676
     "output_type": "stream",
4677
     "text": [
4678
      "EPOCH: 0, VAL SCORE:0.5425\n",
4679
      "\n"
4680
     ]
4681
    },
4682
    {
4683
     "data": {
4684
      "application/vnd.jupyter.widget-view+json": {
4685
       "model_id": "861a6dd2b06641a19d1cecaac246e050",
4686
       "version_major": 2,
4687
       "version_minor": 0
4688
      },
4689
      "text/plain": [
4690
       "Iteration:   0%|          | 0/1146 [00:00<?, ?it/s]"
4691
      ]
4692
     },
4693
     "metadata": {},
4694
     "output_type": "display_data"
4695
    },
4696
    {
4697
     "name": "stdout",
4698
     "output_type": "stream",
4699
     "text": [
4700
      "EPOCH: 1, VAL SCORE:0.5420\n",
4701
      "\n"
4702
     ]
4703
    }
4704
   ],
4705
   "source": [
4706
    "# Train\n",
4707
    "experiment_name = \"gte-large-fine-tuned-fp\"\n",
4708
    "embedding_model_path = str(Path(EFS_DIR, experiment_name))\n",
4709
    "embedding_model.fit(\n",
4710
    "    train_objectives=[(train_dataloader, loss)],\n",
4711
    "    epochs=num_epochs,\n",
4712
    "    warmup_steps=0,\n",
4713
    "    optimizer_params={\"lr\": 1e-8},\n",
4714
    "    weight_decay=0,\n",
4715
    "    output_path=embedding_model_path,\n",
4716
    "    show_progress_bar=True,\n",
4717
    "    evaluator=evaluator,\n",
4718
    "    callback=val_callback)"
4719
   ]
4720
  },
4721
  {
4722
   "cell_type": "code",
4723
   "execution_count": null,
4724
   "id": "df298a82-42ba-413c-bc7e-3b8050bf1c3e",
4725
   "metadata": {
4726
    "tags": []
4727
   },
4728
   "outputs": [],
4729
   "source": [
4730
    "# Clean up\n",
4731
    "import gc\n",
4732
    "import torch\n",
4733
    "del embedding_model, loss\n",
4734
    "gc.collect()\n",
4735
    "torch.cuda.empty_cache()"
4736
   ]
4737
  },
4738
  {
4739
   "cell_type": "markdown",
4740
   "id": "660c77bf-23a6-4690-b7b4-17dda0532cde",
4741
   "metadata": {},
4742
   "source": [
4743
    "Now we're ready to apply this fine-tuned embedding model to our test evaluation dataset. We can simply pass our model artifact directory as the `embedding_model_name` because [HuggingFaceEmbeddings](https://api.python.langchain.com/en/latest/embeddings/langchain.embeddings.huggingface.HuggingFaceEmbeddings.html) accepts a string that can be either a directory or a model name. If the string matches a local directory, the model is loaded from that location first; otherwise it is looked up on the Hugging Face Hub."
4744
   ]
4745
  },
4746
  {
4747
   "cell_type": "code",
4748
   "execution_count": null,
4749
   "id": "2fecbe01-3b88-4124-9128-1fa0a04ec2cb",
4750
   "metadata": {
4751
    "tags": []
4752
   },
4753
   "outputs": [
4754
    {
4755
     "data": {
4756
      "application/json": {
4757
       "generated_answer": " To kill a specific serve replica, you can follow these steps:\n\n1. List the available actors using the command `ray summary actors` or `ray list actors --filter \"class_name=ServeReplica:SleepyPid\"` in the Ray terminal.\n2. Identify the NAME of the replica you want to kill from the list.\n3. Get a handle to the replica using the command `replica_handle = ray.get_actor(\"NAME\", namespace=\"serve\")` in a Python Ray terminal.\n4. Kill the replica using the command `ray.kill(replica_handle, no_restart=True)` in the Python Ray terminal.\n\nFor example, if your replica's NAME is `SERVE_REPLICA::SleepyPid#RlRptP`, you can kill it using the following commands:\n```python\nreplica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#RlRptP\", namespace=\"serve\")\nray.kill(replica_handle, no_restart=True)\n```",
4758
       "question": "how do I kill a specific serve replica",
4759
       "reasoning": "The generated answer is detailed, accurate, and provides a step-by-step guide on how to kill a specific serve replica. It goes above and beyond the reference answer, which does not provide any information on the query. The generated answer is also technically correct and would be helpful to someone looking for this information.",
4760
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
4761
       "score": 5,
4762
       "sources": [
4763
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
4764
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
4765
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
4766
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-deployments",
4767
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
4768
        "https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#autoscaling",
4769
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#worker-node-failure",
4770
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
4771
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-applications",
4772
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#replica-health-checking",
4773
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
4774
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
4775
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#head-node-failure"
4776
       ]
4777
      },
4778
      "text/plain": [
4779
       "<IPython.core.display.JSON object>"
4780
      ]
4781
     },
4782
     "metadata": {
4783
      "application/json": {
4784
       "expanded": false,
4785
       "root": "root"
4786
      }
4787
     },
4788
     "output_type": "display_data"
4789
    },
4790
    {
4791
     "name": "stderr",
4792
     "output_type": "stream",
4793
     "text": [
4794
      "100%|██████████| 177/177 [22:47<00:00,  7.73s/it]\n"
4795
     ]
4796
    }
4797
   ],
4798
   "source": [
4799
    "sql_dump_fp = Path(EFS_DIR, \"sql_dumps\", f\"{experiment_name}_{CHUNK_SIZE}_{CHUNK_OVERLAP}.sql\")\n",
4800
    "run_experiment(\n",
4801
    "    experiment_name=experiment_name, \n",
4802
    "    chunk_size=CHUNK_SIZE, \n",
4803
    "    chunk_overlap=CHUNK_OVERLAP, \n",
4804
    "    num_chunks=NUM_CHUNKS,\n",
4805
    "    embedding_model_name=embedding_model_path,\n",
4806
    "    embedding_dim=EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME],\n",
4807
    "    llm=llm,  # ensure same model as we did for embedding model experiments\n",
4808
    "    evaluator=EVALUATOR,\n",
4809
    "    docs_dir=DOCS_DIR, \n",
4810
    "    experiments_dir=EXPERIMENTS_DIR, \n",
4811
    "    references_fp=REFERENCES_FILE_PATH,\n",
4812
    "    num_samples=NUM_SAMPLES,\n",
4813
    "    sql_dump_fp=sql_dump_fp)"
4814
   ]
4815
  },
4816
  {
4817
   "cell_type": "code",
4818
   "execution_count": null,
4819
   "id": "271f0c48-1962-4ef8-adaf-ce0fe9dc4525",
4820
   "metadata": {
4821
    "tags": []
4822
   },
4823
   "outputs": [],
4824
   "source": [
4825
    "embedding_model_names.append(experiment_name)\n",
4826
    "experiment_names = []\n",
4827
    "for embedding_model_name in embedding_model_names:\n",
4828
    "    experiment_names.append(f\"{embedding_model_name.split('/')[-1]}\")"
4829
   ]
4830
  },
4831
  {
4832
   "cell_type": "code",
4833
   "execution_count": null,
4834
   "id": "3b4a3565-264c-4047-9cee-37dbc27ace8d",
4835
   "metadata": {
4836
    "tags": []
4837
   },
4838
   "outputs": [
4839
    {
4840
     "name": "stdout",
4841
     "output_type": "stream",
4842
     "text": [
4843
      "gte-base\n",
4844
      "  retrieval score: 0.7570621468926554\n",
4845
      "  quality score: 3.9322033898305087\n",
4846
      "\n",
4847
      "gte-large\n",
4848
      "  retrieval score: 0.7796610169491526\n",
4849
      "  quality score: 3.9350282485875705\n",
4850
      "\n",
4851
      "bge-large-en\n",
4852
      "  retrieval score: 0.4745762711864407\n",
4853
      "  quality score: 3.480225988700565\n",
4854
      "\n",
4855
      "text-embedding-ada-002\n",
4856
      "  retrieval score: 0.6497175141242938\n",
4857
      "  quality score: 3.5395480225988702\n",
4858
      "\n",
4859
      "gte-large-fine-tuned-fp\n",
4860
      "  retrieval score: 0.5141242937853108\n",
4861
      "  quality score: 3.446327683615819\n",
4862
      "\n"
4863
     ]
4864
    },
4865
    {
4866
     "data": {
4867
      "image/png": "",
4868
      "text/plain": [
4869
       "<Figure size 1000x300 with 1 Axes>"
4870
      ]
4871
     },
4872
     "metadata": {},
4873
     "output_type": "display_data"
4874
    }
4875
   ],
4876
   "source": [
4877
    "scores = {}\n",
4878
    "for experiment_name in experiment_names:\n",
4879
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)\n",
4880
    "plot_scores(scores=scores)"
4881
   ]
4882
  },
4883
  {
4884
   "cell_type": "markdown",
4885
   "id": "77dcce5c-953f-4dca-a589-862bb476e98f",
4886
   "metadata": {},
4887
   "source": [
4888
    "Full-parameter fine-tuning did not help: the retrieval score dropped from 0.78 (base gte-large) to 0.51 and the quality score from 3.94 to 3.45. This doesn't necessarily mean that fine-tuning isn't useful, only that it may not always be worth the effort. Some likely reasons:\n",
    "- The synthetic data is not exactly like the questions that users ask (it might be worth creating a dataset of more realistic queries, or tuning the prompt so that the synthetic data is more representative of user queries).\n",
    "- Fine-tuning the entire embedding model on our small embedding dataset might be causing **overfitting**.\n",
    "- Our experiment's evaluation dataset is small, so slightly tuning the embeddings via MNR may not increase retrieval recall much, if at all."
4892
   ]
4893
  },
4894
  {
4895
   "cell_type": "markdown",
4896
   "id": "baad7054-5af6-4f4b-8257-af9f9172bd26",
4897
   "metadata": {},
4898
   "source": [
4899
    "## Embedding layer"
4900
   ]
4901
  },
4902
  {
4903
   "cell_type": "markdown",
4904
   "id": "ca2b2107-a086-4b48-86a6-94728ddbd822",
4905
   "metadata": {},
4906
   "source": [
4907
    "To help mitigate the overfitting, we can avoid retraining the entire embedding model and freeze all layers except for the embedding layer (word/subtoken embedding only, not the positional or token type layers). **Note**: this approach is somewhat similar to training a separate linear adapter (which we have evaluation results for), except that it's larger and requires rebuilding the index."
4908
   ]
4909
  },
4910
  {
4911
   "cell_type": "code",
4912
   "execution_count": null,
4913
   "id": "e4ae9f3b-366e-42de-9ad7-32510df31770",
4914
   "metadata": {
4915
    "tags": []
4916
   },
4917
   "outputs": [],
4918
   "source": [
4919
    "import torch\n",
4920
    "import torch.nn as nn"
4921
   ]
4922
  },
4923
  {
4924
   "cell_type": "code",
4925
   "execution_count": null,
4926
   "id": "9d7fa520-2bf1-4f42-a008-359047913d42",
4927
   "metadata": {
4928
    "tags": []
4929
   },
4930
   "outputs": [
4931
    {
4932
     "data": {
4933
      "text/plain": [
4934
       "BertEmbeddings(\n",
4935
       "  (word_embeddings): Embedding(30522, 1024, padding_idx=0)\n",
4936
       "  (position_embeddings): Embedding(512, 1024)\n",
4937
       "  (token_type_embeddings): Embedding(2, 1024)\n",
4938
       "  (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
4939
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
4940
       ")"
4941
      ]
4942
     },
4943
     "execution_count": null,
4944
     "metadata": {},
4945
     "output_type": "execute_result"
4946
    }
4947
   ],
4948
   "source": [
4949
    "# Reinitialize base embedding model\n",
4950
    "embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)\n",
4951
    "embedding_model._modules[\"0\"]._modules[\"auto_model\"]._modules[\"embeddings\"]"
4952
   ]
4953
  },
4954
  {
4955
   "cell_type": "code",
4956
   "execution_count": null,
4957
   "id": "6b03f5ed-b938-47d0-95fb-67ee08aaa324",
4958
   "metadata": {
4959
    "tags": []
4960
   },
4961
   "outputs": [],
4962
   "source": [
4963
    "# Make sure the embedding layers are trainable (this covers the word, position and token type embeddings plus LayerNorm)\n",
4964
    "for param in embedding_model._modules[\"0\"]._modules[\"auto_model\"]._modules[\"embeddings\"].parameters():\n",
4965
    "    param.requires_grad = True"
4966
   ]
4967
  },
4968
  {
4969
   "cell_type": "code",
4970
   "execution_count": null,
4971
   "id": "b6157d3d-103e-4287-bf70-cca745e797bb",
4972
   "metadata": {
4973
    "tags": []
4974
   },
4975
   "outputs": [],
4976
   "source": [
4977
    "# Freeze Bert encoder layers\n",
4978
    "for param in embedding_model._modules[\"0\"]._modules[\"auto_model\"]._modules[\"encoder\"].parameters():\n",
4979
    "    param.requires_grad = False"
4980
   ]
4981
  },
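  {
   "cell_type": "markdown",
   "id": "8d2f6b7e-41c3-4f8a-b0e5-6c1a9d3e7f22",
   "metadata": {},
   "source": [
    "The two cells above set `requires_grad` on the whole `embeddings` module and on the encoder. To restrict training to the word embeddings alone, one option is to filter by parameter name. This is a minimal sketch, not the code that produced the results in this notebook: `freeze_except` is a hypothetical helper, and the parameter names are assumed to mirror the module tree printed above. It uses stand-in objects so that it runs without loading a model.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "def freeze_except(named_params, trainable_substring):\n",
    "    # Keep requires_grad=True only for parameters whose name contains the substring.\n",
    "    trainable = []\n",
    "    for name, param in named_params:\n",
    "        param.requires_grad = trainable_substring in name\n",
    "        if param.requires_grad:\n",
    "            trainable.append(name)\n",
    "    return trainable\n",
    "\n",
    "# Stand-in (name, parameter) pairs with assumed names.\n",
    "toy_params = [\n",
    "    ('0.auto_model.embeddings.word_embeddings.weight', SimpleNamespace(requires_grad=True)),\n",
    "    ('0.auto_model.embeddings.position_embeddings.weight', SimpleNamespace(requires_grad=True)),\n",
    "    ('0.auto_model.encoder.layer.0.attention.self.query.weight', SimpleNamespace(requires_grad=True)),\n",
    "]\n",
    "trainable_names = freeze_except(toy_params, 'embeddings.word_embeddings')\n",
    "print(trainable_names)\n",
    "```\n",
    "\n",
    "With a real model, the same helper could be called as `freeze_except(embedding_model.named_parameters(), 'embeddings.word_embeddings')`, since `named_parameters()` yields (name, parameter) pairs."
   ]
  },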
4982
  {
4983
   "cell_type": "code",
4984
   "execution_count": null,
4985
   "id": "9506f22c-f192-4384-83db-0aa108494f7b",
4986
   "metadata": {
4987
    "tags": []
4988
   },
4989
   "outputs": [],
4990
   "source": [
4991
    "# Training setup\n",
4992
    "num_epochs = 2\n",
4993
    "batch_size = 4\n",
4994
    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size)\n",
4995
    "loss = MultipleNegativesRankingLoss(embedding_model)\n",
4996
    "warmup_steps = int(0.1 * num_epochs * len(train_dataloader))  # computed for reference only; fit() below is called with warmup_steps=0"
4997
   ]
4998
  },
4999
  {
5000
   "cell_type": "code",
5001
   "execution_count": null,
5002
   "id": "ff787441-13fc-4d18-8eca-d76f691f0940",
5003
   "metadata": {
5004
    "tags": []
5005
   },
5006
   "outputs": [
5007
    {
5008
     "data": {
5009
      "application/vnd.jupyter.widget-view+json": {
5010
       "model_id": "1ff0d8c3cab1466ea41a0e368aeafbcb",
5011
       "version_major": 2,
5012
       "version_minor": 0
5013
      },
5014
      "text/plain": [
5015
       "Epoch:   0%|          | 0/2 [00:00<?, ?it/s]"
5016
      ]
5017
     },
5018
     "metadata": {},
5019
     "output_type": "display_data"
5020
    },
5021
    {
5022
     "data": {
5023
      "application/vnd.jupyter.widget-view+json": {
5024
       "model_id": "25278952f4414d658e36a4a3774495b4",
5025
       "version_major": 2,
5026
       "version_minor": 0
5027
      },
5028
      "text/plain": [
5029
       "Iteration:   0%|          | 0/1146 [00:00<?, ?it/s]"
5030
      ]
5031
     },
5032
     "metadata": {},
5033
     "output_type": "display_data"
5034
    },
5035
    {
5036
     "name": "stdout",
5037
     "output_type": "stream",
5038
     "text": [
5039
      "EPOCH: 0, VAL SCORE:0.7930\n",
5040
      "\n"
5041
     ]
5042
    },
5043
    {
5044
     "data": {
5045
      "application/vnd.jupyter.widget-view+json": {
5046
       "model_id": "ca65307eb6a345079c88ce3bd2c7afdb",
5047
       "version_major": 2,
5048
       "version_minor": 0
5049
      },
5050
      "text/plain": [
5051
       "Iteration:   0%|          | 0/1146 [00:00<?, ?it/s]"
5052
      ]
5053
     },
5054
     "metadata": {},
5055
     "output_type": "display_data"
5056
    },
5057
    {
5058
     "name": "stdout",
5059
     "output_type": "stream",
5060
     "text": [
5061
      "EPOCH: 1, VAL SCORE:0.7963\n",
5062
      "\n"
5063
     ]
5064
    }
5065
   ],
5066
   "source": [
5067
    "# Train\n",
5068
    "experiment_name = \"gte-large-fine-tuned-el\"\n",
5069
    "embedding_model_path = str(Path(EFS_DIR, experiment_name))\n",
5070
    "embedding_model.fit(\n",
5071
    "    train_objectives=[(train_dataloader, loss)],\n",
5072
    "    epochs=num_epochs,\n",
5073
    "    warmup_steps=0,\n",
5074
    "    optimizer_params={\"lr\": 1e-5},\n",
5075
    "    weight_decay=0,\n",
5076
    "    output_path=embedding_model_path,\n",
5077
    "    show_progress_bar=True,\n",
5078
    "    evaluator=evaluator,\n",
5079
    "    callback=val_callback)"
5080
   ]
5081
  },
5082
  {
5083
   "cell_type": "code",
5084
   "execution_count": null,
5085
   "id": "15778bc3-8ff0-409e-8396-fc2acb6d7e9b",
5086
   "metadata": {
5087
    "tags": []
5088
   },
5089
   "outputs": [
5090
    {
5091
     "data": {
5092
      "application/json": {
5093
       "generated_answer": " To kill a specific serve replica, you can follow these steps:\n\n1. Get a list of all actors by running `ray summary actors` or `ray list actors` command.\n2. Identify the NAME of the specific replica you want to kill from the list.\n3. Get a handle to the replica by running `ray.get_actor(\"NAME\", namespace=\"serve\")` in a Python interpreter.\n4. Kill the replica by running `ray.kill(replica_handle, no_restart=True)` in the Python interpreter.\n\nFor example, if you want to kill the replica with NAME \"SERVE_REPLICA::SleepyPid#RlRptP\", you can run:\n```python\nimport ray\nreplica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#RlRptP\", namespace=\"serve\")\nray.kill(replica_handle, no_restart=True)\n```\nThis will kill the specific replica without restarting it, and you can confirm its death by checking the list of actors again.",
5094
       "question": "how do I kill a specific serve replica",
5095
       "reasoning": "The generated answer is highly detailed and provides a step-by-step guide on how to kill a specific serve replica. It even includes a practical example. The reference answer does not provide any useful information, making the generated answer superior.",
5096
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
5097
       "score": 5,
5098
       "sources": [
5099
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5100
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5101
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-deployments",
5102
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
5103
        "https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#autoscaling",
5104
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5105
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
5106
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#worker-node-failure",
5107
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-applications",
5108
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
5109
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#replica-health-checking",
5110
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
5111
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#head-node-failure"
5112
       ]
5113
      },
5114
      "text/plain": [
5115
       "<IPython.core.display.JSON object>"
5116
      ]
5117
     },
5118
     "metadata": {
5119
      "application/json": {
5120
       "expanded": false,
5121
       "root": "root"
5122
      }
5123
     },
5124
     "output_type": "display_data"
5125
    },
5126
    {
5127
     "name": "stderr",
5128
     "output_type": "stream",
5129
     "text": [
5130
      "100%|██████████| 177/177 [21:29<00:00,  7.28s/it]\n"
5131
     ]
5132
    }
5133
   ],
5134
   "source": [
5135
    "# Experiment\n",
5136
    "sql_dump_fp = Path(EFS_DIR, \"sql_dumps\", f\"{experiment_name}_{CHUNK_SIZE}_{CHUNK_OVERLAP}.sql\")\n",
5137
    "run_experiment(\n",
5138
    "    experiment_name=experiment_name, \n",
5139
    "    chunk_size=CHUNK_SIZE,\n",
5140
    "    chunk_overlap=CHUNK_OVERLAP, \n",
5141
    "    num_chunks=NUM_CHUNKS,\n",
5142
    "    embedding_model_name=embedding_model_path,\n",
5143
    "    embedding_dim=EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME],\n",
5144
    "    llm=llm,  # ensure same model as we did for embedding model experiments\n",
5145
    "    evaluator=EVALUATOR,\n",
5146
    "    docs_dir=DOCS_DIR, \n",
5147
    "    experiments_dir=EXPERIMENTS_DIR, \n",
5148
    "    references_fp=REFERENCES_FILE_PATH,\n",
5149
    "    num_samples=NUM_SAMPLES,\n",
5150
    "    sql_dump_fp=sql_dump_fp)"
5151
   ]
5152
  },
5153
  {
5154
   "cell_type": "code",
5155
   "execution_count": null,
5156
   "id": "3fcd9a29-2e15-46ac-9443-3e87632888a6",
5157
   "metadata": {
5158
    "tags": []
5159
   },
5160
   "outputs": [],
5161
   "source": [
5162
    "embedding_model_names.append(experiment_name)\n",
5163
    "experiment_names = []\n",
5164
    "for embedding_model_name in embedding_model_names:\n",
5165
    "    experiment_names.append(f\"{embedding_model_name.split('/')[-1]}\")"
5166
   ]
5167
  },
5168
  {
5169
   "cell_type": "code",
5170
   "execution_count": null,
5171
   "id": "4eb87087-437f-4a76-9890-e47f457d81be",
5172
   "metadata": {
5173
    "tags": []
5174
   },
5175
   "outputs": [
5176
    {
5177
     "name": "stdout",
5178
     "output_type": "stream",
5179
     "text": [
5180
      "gte-base\n",
5181
      "  retrieval score: 0.7570621468926554\n",
5182
      "  quality score: 3.9322033898305087\n",
5183
      "\n",
5184
      "gte-large\n",
5185
      "  retrieval score: 0.7796610169491526\n",
5186
      "  quality score: 3.9350282485875705\n",
5187
      "\n",
5188
      "bge-large-en\n",
5189
      "  retrieval score: 0.4745762711864407\n",
5190
      "  quality score: 3.480225988700565\n",
5191
      "\n",
5192
      "text-embedding-ada-002\n",
5193
      "  retrieval score: 0.6497175141242938\n",
5194
      "  quality score: 3.5395480225988702\n",
5195
      "\n",
5196
      "gte-large-fine-tuned-fp\n",
5197
      "  retrieval score: 0.5141242937853108\n",
5198
      "  quality score: 3.446327683615819\n",
5199
      "\n",
5200
      "gte-large-fine-tuned-el\n",
5201
      "  retrieval score: 0.7909604519774012\n",
5202
      "  quality score: 3.8728813559322033\n",
5203
      "\n"
5204
     ]
5205
    },
5206
    {
5207
     "data": {
5208
      "image/png": "",
5209
      "text/plain": [
5210
       "<Figure size 1000x300 with 1 Axes>"
5211
      ]
5212
     },
5213
     "metadata": {},
5214
     "output_type": "display_data"
5215
    }
5216
   ],
5217
   "source": [
5218
    "scores = {}\n",
5219
    "for experiment_name in experiment_names:\n",
5220
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)\n",
5221
    "plot_scores(scores=scores)"
5222
   ]
5223
  },
5224
  {
5225
   "cell_type": "markdown",
5226
   "id": "044e2e42-f6eb-494e-89db-a617bc72da8b",
5227
   "metadata": {},
5228
   "source": [
5229
    "The validation scores are much better than with full-parameter fine-tuning, and the retrieval score (0.79) edges past base gte-large (0.78). However, the quality score is slightly lower (3.87 vs. 3.94), so it's not worth the effort compared to using our base gte-large embedding model. This could again be improved with larger or higher-quality datasets, and perhaps a larger test dataset to capture small improvements in our retrieval scores. If we did want to use this fine-tuned embedding model, we could set these values:\n",
5230
    "\n",
5231
    "```python\n",
5232
    "experiment_name = \"gte-large-fine-tuned-el\"\n",
5233
    "EMBEDDING_MODEL_PATH = str(Path(EFS_DIR, experiment_name))  # can pass this in directly for embedding_model_name\n",
5234
    "SQL_DUMP_FP = Path(EFS_DIR, \"sql_dumps\", f\"{experiment_name}_{CHUNK_SIZE}_{CHUNK_OVERLAP}.sql\")\n",
5235
    "run_experiment(embedding_model_name=EMBEDDING_MODEL_PATH, sql_dump_fp=SQL_DUMP_FP, ...)\n",
5236
    "```"
5237
   ]
5238
  },
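  {
   "cell_type": "markdown",
   "id": "5e7a9c14-2d6b-4b3f-a8c0-9f1e4d2b6a33",
   "metadata": {},
   "source": [
    "To make the comparison with the base model explicit, the short sketch below computes each fine-tuned model's change relative to gte-large. The scores are copied from the printed results above; `deltas_vs_baseline` is a hypothetical helper, not part of the notebook's utilities.\n",
    "\n",
    "```python\n",
    "# Scores copied from the printed experiment results above.\n",
    "scores = {\n",
    "    'gte-large': {'retrieval_score': 0.7796610169491526, 'quality_score': 3.9350282485875705},\n",
    "    'gte-large-fine-tuned-fp': {'retrieval_score': 0.5141242937853108, 'quality_score': 3.446327683615819},\n",
    "    'gte-large-fine-tuned-el': {'retrieval_score': 0.7909604519774012, 'quality_score': 3.8728813559322033},\n",
    "}\n",
    "\n",
    "def deltas_vs_baseline(scores, baseline):\n",
    "    # Per-metric difference between each experiment and the baseline experiment.\n",
    "    base = scores[baseline]\n",
    "    return {\n",
    "        name: {metric: value - base[metric] for metric, value in result.items()}\n",
    "        for name, result in scores.items()\n",
    "        if name != baseline\n",
    "    }\n",
    "\n",
    "deltas = deltas_vs_baseline(scores, 'gte-large')\n",
    "for name, delta in deltas.items():\n",
    "    print(name, {metric: round(change, 4) for metric, change in delta.items()})\n",
    "```\n",
    "\n",
    "A positive retrieval delta paired with a negative quality delta, as for the embedding-layer variant, suggests that the change is within the noise of a small evaluation set."
   ]
  },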
5239
  {
5240
   "cell_type": "markdown",
5241
   "id": "b5e41783-6f45-4bdc-8a1b-fe3329f7b101",
5242
   "metadata": {
5243
    "tags": []
5244
   },
5245
   "source": [
5246
    "# Prompt engineering"
5247
   ]
5248
  },
5249
  {
5250
   "cell_type": "markdown",
5251
   "id": "dbb495d1-557e-41eb-8c27-fe8cea02e7b5",
5252
   "metadata": {
5253
    "tags": []
5254
   },
5255
   "source": [
5256
    "There's a lot we can do when it comes to engineering the prompt (x-of-thought, multimodal, self-refine, query decomposition, etc.), so we're going to try out just a few interesting ideas. Here, we allow the LLM to ignore any contexts that aren't relevant to the query. The idea is to show how quickly we can go from prompt engineering to an evaluation report.\n",
5257
    "\n",
5258
    "<img width=\"800\" src=\"https://images.ctfassets.net/xjan103pcp94/6nMOu5sm3jploFUXeKxog2/1e4405924686798b243eb80ff5b8d549/Screenshot_2023-11-01_at_2.40.56_PM.png\">"
5259
   ]
5260
  },
5261
  {
5262
   "cell_type": "code",
5263
   "execution_count": null,
5264
   "id": "1de8e133-52fc-42d8-ac3d-51e8557ed563",
5265
   "metadata": {
5266
    "tags": []
5267
   },
5268
   "outputs": [],
5269
   "source": [
5270
    "system_content = \"Answer the query using the context provided. Be succinct. \" \\\n",
5271
    "    \"Contexts are organized in a list of dictionaries [{'text': <context>}, {'text': <context>}, ...]. \" \\\n",
5272
    "    \"Feel free to ignore any contexts in the list that don't seem relevant to the query. \""
5273
   ]
5274
  },
5275
  {
5276
   "cell_type": "code",
5277
   "execution_count": null,
5278
   "id": "f618682c-33fb-45d2-8247-883e3136f829",
5279
   "metadata": {
5280
    "tags": []
5281
   },
5282
   "outputs": [
5283
    {
5284
     "data": {
5285
      "application/json": {
5286
       "generated_answer": " To kill a specific serve replica, you can use the `ray.kill()` function with the replica handle. First, get the replica handle by calling `ray.get_actor()` and providing the NAME of the replica from the output of `ray list actors`. Here's an example:\n\n```python\nreplica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#RlRptP\", namespace=\"serve\")\nray.kill(replica_handle, no_restart=True)\n```\n\nThis will kill the specified replica without restarting it.",
5287
       "question": "how do I kill a specific serve replica",
5288
       "reasoning": "The generated answer provides a detailed and technically accurate method to kill a specific serve replica using the `ray.kill()` function. However, it does not perfectly match the reference answer, which states that there is no information available on how to perform the task. Despite this, the generated answer is still useful and informative.",
5289
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
5290
       "score": 4,
5291
       "sources": [
5292
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5293
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5294
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
5295
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-deployments",
5296
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5297
        "https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#autoscaling",
5298
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#worker-node-failure",
5299
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
5300
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-applications",
5301
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#replica-health-checking",
5302
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
5303
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
5304
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#head-node-failure"
5305
       ]
5306
      },
5307
      "text/plain": [
5308
       "<IPython.core.display.JSON object>"
5309
      ]
5310
     },
5311
     "metadata": {
5312
      "application/json": {
5313
       "expanded": false,
5314
       "root": "root"
5315
      }
5316
     },
5317
     "output_type": "display_data"
5318
    },
5319
    {
5320
     "name": "stderr",
5321
     "output_type": "stream",
5322
     "text": [
5323
      "100%|██████████| 177/177 [19:56<00:00,  6.76s/it]\n"
5324
     ]
5325
    }
5326
   ],
5327
   "source": [
5328
    "# Evaluate\n",
5329
    "experiment_name = \"prompt-ignore-contexts\"\n",
5330
    "run_experiment(\n",
5331
    "    experiment_name=experiment_name, \n",
5332
    "    chunk_size=CHUNK_SIZE, \n",
5333
    "    chunk_overlap=CHUNK_OVERLAP, \n",
5334
    "    num_chunks=NUM_CHUNKS,\n",
5335
    "    embedding_model_name=EMBEDDING_MODEL_NAME,\n",
5336
    "    embedding_dim=EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME],\n",
5337
    "    llm=LLM,\n",
5338
    "    evaluator=EVALUATOR,\n",
5339
    "    docs_dir=DOCS_DIR, \n",
5340
    "    experiments_dir=EXPERIMENTS_DIR, \n",
5341
    "    references_fp=REFERENCES_FILE_PATH,\n",
5342
    "    system_content=system_content,  # new prompt\n",
5343
    "    num_samples=NUM_SAMPLES)"
5344
   ]
5345
  },
5346
  {
5347
   "cell_type": "code",
5348
   "execution_count": null,
5349
   "id": "8913f69e-25d1-4354-aaa0-13fb46778c83",
5350
   "metadata": {
5351
    "tags": []
5352
   },
5353
   "outputs": [
5354
    {
5355
     "name": "stdout",
5356
     "output_type": "stream",
5357
     "text": [
5358
      "prompt-ignore-contexts\n",
5359
      "  retrieval score: 0.7796610169491526\n",
5360
      "  quality score: 3.8954802259887007\n",
5361
      "\n"
5362
     ]
5363
    },
5364
    {
5365
     "data": {
5366
      "text/plain": [
5367
       "{'retrieval_score': 0.7796610169491526, 'quality_score': 3.8954802259887007}"
5368
      ]
5369
     },
5370
     "execution_count": null,
5371
     "metadata": {},
5372
     "output_type": "execute_result"
5373
    }
5374
   ],
5375
   "source": [
5376
    "# Results\n",
5377
    "print_experiment(experiment_name, EXPERIMENTS_DIR)"
5378
   ]
5379
  },
5380
  {
5381
   "cell_type": "markdown",
5382
   "id": "2674a08a-1447-42e6-84bf-72777a660392",
5383
   "metadata": {},
5384
   "source": [
5385
    "It seems this specific prompt engineering effort didn't help improve the quality of our system (knowing which context is relevant requires domain knowledge of Ray, which the model may not have developed). As we mentioned earlier, there are many other ways we can engineer our prompt and we encourage you to explore more. What’s important here is that we have a **clean and simple way to evaluate anything** that we want to experiment with. That said, we have empirically found that improving the quality of our retrieval system and the data flywheel (where we fix our documentation itself) has had a much larger impact on the overall quality of our system."
5386
   ]
5387
  },
5388
  {
5389
   "cell_type": "code",
5390
   "execution_count": null,
5391
   "id": "9cd22d32-3eba-4a1e-a79c-e5f85cdd5114",
5392
   "metadata": {
5393
    "tags": []
5394
   },
5395
   "outputs": [],
5396
   "source": [
5397
    "SYSTEM_CONTENT = \"Answer the query using the context provided. Be succinct.\""
5398
   ]
5399
  },
5400
  {
5401
   "cell_type": "markdown",
5402
   "id": "6f3fcd85-01ac-4cd4-8dc1-cf864bb646ef",
5403
   "metadata": {
5404
    "tags": []
5405
   },
5406
   "source": [
5407
    "# Lexical search"
5408
   ]
5409
  },
5410
  {
5411
   "cell_type": "markdown",
5412
   "id": "63edccae-c7b5-4452-b97c-5d3e53b6a1cb",
5413
   "metadata": {},
5414
   "source": [
5415
    "We're now going to supplement our vector embedding based search with traditional lexical search, which searches for exact token matches between our query and document chunks. Our intuition here is that lexical search can help identify chunks with exact keyword matches that the semantic representation may fail to capture, especially for tokens that are out-of-vocabulary for our embedding model (and so represented via subtokens). But our embeddings based approach is still very advantageous for capturing implicit meaning, and so we're going to combine several retrieved chunks from both vector embeddings based search and lexical search.\n",
5416
    "\n",
5417
    "<img width=\"800\" src=\"https://images.ctfassets.net/xjan103pcp94/9eBIE4iw7SmTtVvANbkAq/8913fcbd10fc66fd8b59278642155609/rag-based-llm-applications-lexical-search.png\">"
5418
   ]
5419
  },
5420
  {
5421
   "cell_type": "code",
5422
   "execution_count": null,
5423
   "id": "b29782e2-786a-4665-aa13-d9339ef3dad4",
5424
   "metadata": {
5425
    "tags": []
5426
   },
5427
   "outputs": [],
5428
   "source": [
5429
    "# Env vars\n",
5430
    "os.environ[\"EMBEDDING_DIM\"] = f\"{EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME]}\"\n",
5431
    "os.environ[\"SQL_DUMP_FP\"] = str(Path(EFS_DIR, \"sql_dumps\", f\"{EMBEDDING_MODEL_NAME.split('/')[-1]}_{CHUNK_SIZE}_{CHUNK_OVERLAP}.sql\"))"
5432
   ]
5433
  },
5434
  {
5435
   "cell_type": "code",
5436
   "execution_count": null,
5437
   "id": "7b85823f-f6fa-480a-8b24-0da188a55a2f",
5438
   "metadata": {
5439
    "tags": []
5440
   },
5441
   "outputs": [
5442
    {
5443
     "name": "stdout",
5444
     "output_type": "stream",
5445
     "text": [
5446
      "DROP TABLE\n",
5447
      "CREATE TABLE\n",
5448
      "SET\n",
5449
      "SET\n",
5450
      "SET\n",
5451
      "SET\n",
5452
      "SET\n",
5453
      " set_config \n",
5454
      "------------\n",
5455
      " \n",
5456
      "(1 row)\n",
5457
      "\n",
5458
      "SET\n",
5459
      "SET\n",
5460
      "SET\n",
5461
      "SET\n",
5462
      "ALTER TABLE\n",
5463
      "ALTER TABLE\n",
5464
      "ALTER TABLE\n",
5465
      "ALTER TABLE\n",
5466
      "DROP SEQUENCE\n",
5467
      "DROP TABLE\n",
5468
      "DROP SEQUENCE\n",
5469
      "DROP TABLE\n",
5470
      "DROP EXTENSION\n",
5471
      "CREATE EXTENSION\n",
5472
      "COMMENT\n",
5473
      "SET\n",
5474
      "SET\n",
5475
      "CREATE TABLE\n",
5476
      "ALTER TABLE\n",
5477
      "CREATE SEQUENCE\n",
5478
      "ALTER SEQUENCE\n",
5479
      "ALTER SEQUENCE\n",
5480
      "CREATE TABLE\n",
5481
      "ALTER TABLE\n",
5482
      "CREATE SEQUENCE\n",
5483
      "ALTER SEQUENCE\n",
5484
      "ALTER SEQUENCE\n",
5485
      "ALTER TABLE\n",
5486
      "ALTER TABLE\n",
5487
      "COPY 40433\n",
5488
      "COPY 14774\n",
5489
      " setval \n",
5490
      "--------\n",
5491
      "  40433\n",
5492
      "(1 row)\n",
5493
      "\n",
5494
      " setval \n",
5495
      "--------\n",
5496
      "  14774\n",
5497
      "(1 row)\n",
5498
      "\n",
5499
      "ALTER TABLE\n",
5500
      "ALTER TABLE\n"
5501
     ]
5502
    }
5503
   ],
5504
   "source": [
5505
    "%%bash\n",
5506
    "# Ensure the right index is built in-memory\n",
5507
    "psql \"$DB_CONNECTION_STRING\" -c \"DROP TABLE IF EXISTS document;\"  # drop\n",
5508
    "sudo -u postgres psql -f ../migrations/vector-${EMBEDDING_DIM}.sql  # set up\n",
5509
    "psql \"$DB_CONNECTION_STRING\" -f $SQL_DUMP_FP  # load"
5510
   ]
5511
  },
5512
  {
5513
   "cell_type": "code",
5514
   "execution_count": null,
5515
   "id": "c07fb08c-b395-4a77-9968-581d63aacca0",
5516
   "metadata": {
5517
    "tags": []
5518
   },
5519
   "outputs": [],
5520
   "source": [
5521
    "# Get chunks\n",
5522
    "with psycopg.connect(os.environ[\"DB_CONNECTION_STRING\"]) as conn:\n",
5523
    "    register_vector(conn)\n",
5524
    "    with conn.cursor() as cur:\n",
5525
    "        cur.execute(\"SELECT id, text, source FROM document\")\n",
5526
    "        chunks = cur.fetchall()"
5527
   ]
5528
  },
5529
  {
5530
   "cell_type": "markdown",
5531
   "id": "19e62598-3116-466e-8a60-fc571ce698df",
5532
   "metadata": {},
5533
   "source": [
5534
    "## BM25"
5535
   ]
5536
  },
5537
  {
5538
   "cell_type": "markdown",
5539
   "id": "65023148-7fd1-4383-b28c-b417103e716d",
5540
   "metadata": {},
5541
   "source": [
5542
    "Let's apply lexical search using [BM25](https://en.wikipedia.org/wiki/Okapi_BM25), which is a ranking algorithm that rewards unique token matches between our query and contexts."
5543
   ]
5544
  },
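  {
   "cell_type": "markdown",
   "id": "sketch-bm25-scoring",
   "metadata": {},
   "source": [
    "To make the ranking behind BM25 more concrete, here is a minimal pure-Python sketch of an Okapi BM25 scorer. It is an illustration only, not the `rank_bm25` implementation: the function name `bm25_scores`, the toy corpus, the defaults `k1=1.5` and `b=0.75`, and the non-negative IDF variant are all assumptions made for this example.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "def bm25_scores(query_tokens, corpus, k1=1.5, b=0.75):\n",
    "    # corpus: list of tokenized documents (each a list of tokens)\n",
    "    n_docs = len(corpus)\n",
    "    avg_len = sum(len(doc) for doc in corpus) / n_docs\n",
    "    # Document frequency: number of documents containing each token\n",
    "    doc_freq = Counter(token for doc in corpus for token in set(doc))\n",
    "    scores = []\n",
    "    for doc in corpus:\n",
    "        term_freq = Counter(doc)\n",
    "        score = 0.0\n",
    "        for token in query_tokens:\n",
    "            if token not in term_freq:\n",
    "                continue\n",
    "            # Rare tokens get a larger inverse document frequency weight\n",
    "            idf = math.log((n_docs - doc_freq[token] + 0.5) / (doc_freq[token] + 0.5) + 1)\n",
    "            tf = term_freq[token]\n",
    "            # Term frequency saturates (k1) and is normalized by document length (b)\n",
    "            score += idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len(doc) / avg_len))\n",
    "        scores.append(score)\n",
    "    return scores\n",
    "\n",
    "toy_corpus = [\n",
    "    'the default batch size is 4096'.split(),\n",
    "    'ray serve scales replicas automatically'.split(),\n",
    "    'set batch size explicitly on gpus'.split(),\n",
    "]\n",
    "toy_scores = bm25_scores('default batch size'.split(), toy_corpus)\n",
    "best = max(range(len(toy_scores)), key=lambda i: toy_scores[i])\n",
    "```\n",
    "\n",
    "In this toy corpus the first document matches all three query tokens, including the rarer token `default`, so it should receive the highest score, while the second document shares no tokens with the query and scores zero."
   ]
  },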
5545
  {
5546
   "cell_type": "code",
5547
   "execution_count": null,
5548
   "id": "206e3523-f651-425d-817a-ee359700cac9",
5549
   "metadata": {
5550
    "tags": []
5551
   },
5552
   "outputs": [],
5553
   "source": [
5554
    "import re\n",
5555
    "from rank_bm25 import BM25Okapi"
5556
   ]
5557
  },
5558
  {
5559
   "cell_type": "code",
5560
   "execution_count": null,
5561
   "id": "f05f3014-710c-4edd-b6c9-d94299a7d33e",
5562
   "metadata": {
5563
    "tags": []
5564
   },
5565
   "outputs": [],
5566
   "source": [
5567
    "# BM25 index\n",
5568
    "texts = [re.sub(r\"[^a-zA-Z0-9]\", \" \", chunk[1]).lower().split() for chunk in chunks]\n",
5569
    "lexical_index = BM25Okapi(texts)"
5570
   ]
5571
  },
5572
  {
5573
   "cell_type": "markdown",
5574
   "id": "fc5470e0-8430-4cf2-8b59-7f6c2674c25b",
5575
   "metadata": {},
5576
   "source": [
5577
    "Similar to our `semantic_search` function to retrieve the relevant context, we can implement a search function to use our lexical index to retrieve relevant context."
5578
   ]
5579
  },
5580
  {
5581
   "cell_type": "code",
5582
   "execution_count": null,
5583
   "id": "698222a6-2b4e-4b8d-bb64-551ae09595bb",
5584
   "metadata": {
5585
    "tags": []
5586
   },
5587
   "outputs": [],
5588
   "source": [
5589
    "def lexical_search(index, query, chunks, k):\n",
5590
    "    query_tokens = query.lower().split()  # preprocess query\n",
5591
    "    scores = index.get_scores(query_tokens)  # get best matching (BM) scores\n",
5592
    "    indices = sorted(range(len(scores)), key=lambda i: -scores[i])[:k]  # sort and get top k\n",
5593
    "    lexical_context = [{\n",
5594
    "            \"id\": chunks[i][0], \n",
5595
    "            \"text\": chunks[i][1], \n",
5596
    "            \"source\": chunks[i][2], \n",
5597
    "            \"score\": scores[i]} for i in indices]\n",
5598
    "    return lexical_context"
5599
   ]
5600
  },
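  {
   "cell_type": "markdown",
   "id": "sketch-shared-tokenizer",
   "metadata": {},
   "source": [
    "One detail worth noticing: the BM25 index above replaces non-alphanumeric characters with spaces before splitting, while `lexical_search` only lowercases and splits the query. The sketch below shows how the two tokenizations differ for our example query. The helper name `tokenize` is hypothetical, and we have not measured whether sharing it between the index and the query changes the experiment results.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "def tokenize(text):\n",
    "    # Same normalization as the BM25 index cell: non-alphanumerics become spaces\n",
    "    return re.sub(r'[^a-zA-Z0-9]', ' ', text).lower().split()\n",
    "\n",
    "query = 'What is the default batch size for map_batches?'\n",
    "naive_tokens = query.lower().split()  # how lexical_search preprocesses the query\n",
    "shared_tokens = tokenize(query)  # how the indexed chunks were tokenized\n",
    "```\n",
    "\n",
    "With the naive split, the last query token is `map_batches?`, which cannot match the `map` and `batches` tokens stored in the index. Using one helper for both sides keeps the query and the index in the same token space."
   ]
  },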
5601
  {
5602
   "cell_type": "code",
5603
   "execution_count": null,
5604
   "id": "ca7734e4-ea44-483c-93ed-e7e6f646caf8",
5605
   "metadata": {
5606
    "tags": []
5607
   },
5608
   "outputs": [
5609
    {
5610
     "name": "stdout",
5611
     "output_type": "stream",
5612
     "text": [
5613
      "https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-size\n",
5614
      "Configuring batch size#\n",
5615
      "Increasing batch_size improves the performance of vectorized transformations like\n",
5616
      "NumPy functions and model inference. However, if your batch size is too large, your\n",
5617
      "program might run out of memory. If you encounter an out-of-memory error, decrease your\n",
5618
      "batch_size.\n",
5619
      "Note\n",
5620
      "The default batch size depends on your resource type. If you’re using CPUs,\n",
5621
      "the default batch size is 4096. If you’re using GPUs, you must specify an explicit\n",
5622
      "batch size.\n",
5623
      "\n",
5624
      "https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.map_batches.html#ray-data-dataset-map-batches\n",
5625
      "fn – The function or generator to apply to a record batch, or a class type\n",
5626
      "that can be instantiated to create such a callable. Callable classes are\n",
5627
      "only supported for the actor compute strategy. Note fn must be\n",
5628
      "pickle-able.\n",
5629
      "batch_size – The desired number of rows in each batch, or None to use\n",
5630
      "entire blocks as batches (blocks may contain different numbers of rows).\n",
5631
      "The actual size of the batch provided to fn may be smaller than\n",
5632
      "batch_size if batch_size doesn’t evenly divide the block(s) sent\n",
5633
      "to a given map task. Default batch_size is 4096 with “default”.\n",
5634
      "compute – Either “tasks” (default) to use Ray Tasks or an\n",
5635
      "ActorPoolStrategy to use an autoscaling actor pool.\n",
5636
      "\n",
5637
      "https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.take_batch.html#ray-data-dataset-take-batch\n",
5638
      "ray.data.Dataset.take_batch#\n",
5639
      "Dataset.take_batch(batch_size: int = 20, *, batch_format: Optional[str] = 'default') → Union[pyarrow.Table, pandas.DataFrame, Dict[str, numpy.ndarray]][source]#\n",
5640
      "Return up to batch_size rows from the Dataset in a batch.\n",
5641
      "Ray Data represents batches as NumPy arrays or pandas DataFrames. You can\n",
5642
      "configure the batch type by specifying batch_format.\n",
5643
      "This method is useful for inspecting inputs to map_batches().\n",
5644
      "\n",
5645
      "Warning\n",
5646
      "take_batch() moves up to batch_size rows to the caller’s\n",
5647
      "machine. If batch_size is large, this method can cause an `\n",
5648
      "OutOfMemory error on the caller.\n",
5649
      "\n",
5650
      "\n",
5651
      "Note\n",
5652
      "This operation will trigger execution of the lazy transformations performed on this dataset.\n",
5653
      "\n"
5654
     ]
5655
    }
5656
   ],
5657
   "source": [
5658
    "# Retrieve top-k docs\n",
5659
    "k = 3\n",
5660
    "query = \"What is the default batch size for map_batches?\"\n",
5661
    "top_docs = lexical_search(lexical_index, query, chunks, k=k)\n",
5662
    "for item in top_docs:\n",
5663
    "    print (item[\"source\"])\n",
5664
    "    print (item[\"text\"])\n",
5665
    "    print ()"
5666
   ]
5667
  },
5668
  {
5669
   "cell_type": "markdown",
5670
   "id": "6c884bbc-76e5-40ed-b7d5-2a5733fc9c86",
5671
   "metadata": {},
5672
   "source": [
5673
    "## Semantic"
5674
   ]
5675
  },
5676
  {
5677
   "cell_type": "markdown",
5678
   "id": "d9785a5f-7e62-4727-9dda-d038d640f105",
5679
   "metadata": {},
5680
   "source": [
5681
    "Comparing these results with the sources retrieved by our existing vector embedding based search shows that the two approaches, while different, both retrieved relevant sources. So, we're going to combine the results of both approaches and feed them into the context for our LLM for generation."
5682
   ]
5683
  },
5684
  {
5685
   "cell_type": "code",
5686
   "execution_count": null,
5687
   "id": "6711aa7b-7e2a-43d2-9043-e03dc7a32ba0",
5688
   "metadata": {
5689
    "tags": []
5690
   },
5691
   "outputs": [
5692
    {
5693
     "data": {
5694
      "text/plain": [
5695
       "1024"
5696
      ]
5697
     },
5698
     "execution_count": null,
5699
     "metadata": {},
5700
     "output_type": "execute_result"
5701
    }
5702
   ],
5703
   "source": [
5704
    "# Embed query\n",
5705
    "embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)\n",
5706
    "embedding = np.array(embedding_model.embed_query(query))\n",
5707
    "len(embedding)"
5708
   ]
5709
  },
5710
  {
5711
   "cell_type": "code",
5712
   "execution_count": null,
5713
   "id": "7afafafe-5df4-4087-8f28-bd79d6d931e2",
5714
   "metadata": {
5715
    "tags": []
5716
   },
5717
   "outputs": [],
5718
   "source": [
5719
    "# Get context\n",
5720
    "with psycopg.connect(os.environ[\"DB_CONNECTION_STRING\"]) as conn:\n",
5721
    "    register_vector(conn)\n",
5722
    "    with conn.cursor() as cur:\n",
5723
    "        cur.execute(\"SELECT * FROM document ORDER BY embedding <=> %s LIMIT %s\", (embedding, k))\n",
5724
    "        rows = cur.fetchall()\n",
5725
    "        context = [{\"text\": row[1]} for row in rows]\n",
5726
    "        sources = [row[2] for row in rows]"
5727
   ]
5728
  },
5729
  {
5730
   "cell_type": "code",
5731
   "execution_count": null,
5732
   "id": "1698b140-f3a3-4026-ad4a-c378192dfc47",
5733
   "metadata": {
5734
    "tags": []
5735
   },
5736
   "outputs": [
5737
    {
5738
     "name": "stdout",
5739
     "output_type": "stream",
5740
     "text": [
5741
      "https://docs.ray.io/en/master/data/batch_inference.html#configuring-batch-size\n",
5742
      "# Specify that each input batch should be of size 2.\n",
5743
      "ds.map_batches(assert_batch, batch_size=2)\n",
5744
      "Caution\n",
5745
      "The default batch_size of 4096 may be too large for datasets with large rows\n",
5746
      "(for example, tables with many columns or a collection of large images).\n",
5747
      "\n",
5748
      "https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-size\n",
5749
      "Configuring batch size#\n",
5750
      "Increasing batch_size improves the performance of vectorized transformations like\n",
5751
      "NumPy functions and model inference. However, if your batch size is too large, your\n",
5752
      "program might run out of memory. If you encounter an out-of-memory error, decrease your\n",
5753
      "batch_size.\n",
5754
      "Note\n",
5755
      "The default batch size depends on your resource type. If you’re using CPUs,\n",
5756
      "the default batch size is 4096. If you’re using GPUs, you must specify an explicit\n",
5757
      "batch size.\n",
5758
      "\n",
5759
      "https://docs.ray.io/en/master/data/examples/pytorch_resnet_batch_prediction.html#model-inference\n",
5760
      "}\n",
5761
      "Then we use the map_batches() API to apply the model to the whole dataset.\n",
5762
      "The first parameter of map_batches is the user-defined function (UDF), which can either be a function or a class. Since we are using a class in this case, the UDF will run as long-running Ray actors. For class-based UDFs, we use the compute argument to specify ActorPoolStrategy with the number of parallel actors.\n",
5763
      "The batch_size argument indicates the number of images in each batch. See the Ray dashboard\n",
5764
      "for GPU memory usage to experiment with the batch_size when using your own model and dataset.\n",
5765
      "You should aim to max out the batch size without running out of GPU memory.\n",
5766
      "\n"
5767
     ]
5768
    }
5769
   ],
5770
   "source": [
5771
    "for i, item in enumerate(context):\n",
5772
    "    print (sources[i])\n",
5773
    "    print (item[\"text\"])\n",
5774
    "    print ()"
5775
   ]
5776
  },
5777
  {
5778
   "cell_type": "markdown",
5779
   "id": "d100290c-fdaa-4d57-877f-01f621698433",
5780
   "metadata": {},
5781
   "source": [
5782
    "## Experiment"
5783
   ]
5784
  },
5785
  {
5786
   "cell_type": "markdown",
5787
   "id": "c76aebae-f7ae-469c-bfcf-319b588d78ae",
5788
   "metadata": {},
5789
   "source": [
5790
    "Now let's incorporate this into our retrieval workflow by adding it to our `generate.py/QueryAgent` class. The main change will be to include the additional sources from lexical search:\n",
5791
    "```python\n",
5792
    "class QueryAgent:\n",
5793
    "    def __call__(self, query, lexical_search_k=1, **kwargs):\n",
5794
    "        # Add lexical search results\n",
5795
    "        if self.lexical_index:\n",
5796
    "            lexical_context = lexical_search(\n",
5797
    "                index=self.lexical_index, query=query, chunks=self.chunks, k=lexical_search_k)\n",
5798
    "            # Insert after <lexical_search_k> worth of semantic results\n",
5799
    "            context_results[lexical_search_k:lexical_search_k] = lexical_context\n",
5800
    "```"
5801
   ]
5802
  },
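  {
   "cell_type": "markdown",
   "id": "sketch-merge-results",
   "metadata": {},
   "source": [
    "The slice assignment in the snippet above can be easy to misread, so here is a small standalone sketch of what it does. The function name `merge_results` and the placeholder strings are hypothetical. Unlike the snippet above, this version copies the list first so the input is not mutated.\n",
    "\n",
    "```python\n",
    "def merge_results(semantic_results, lexical_results, lexical_search_k):\n",
    "    merged = list(semantic_results)  # copy so the caller's list is untouched\n",
    "    # Assigning to an empty slice inserts items without overwriting any\n",
    "    merged[lexical_search_k:lexical_search_k] = lexical_results\n",
    "    return merged\n",
    "\n",
    "semantic = ['s1', 's2', 's3']\n",
    "lexical = ['l1']\n",
    "merged = merge_results(semantic, lexical, lexical_search_k=1)\n",
    "```\n",
    "\n",
    "Here `merged` is `['s1', 'l1', 's2', 's3']`: the lexical result lands right after the first `lexical_search_k` semantic results. Note that this sketch does not deduplicate, so a chunk returned by both searches would appear twice."
   ]
  },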
5803
  {
5804
   "cell_type": "code",
5805
   "execution_count": null,
5806
   "id": "95e2cdf8-4c53-414c-b302-990fc79b25e1",
5807
   "metadata": {
5808
    "tags": []
5809
   },
5810
   "outputs": [],
5811
   "source": [
5812
    "experiment_names = []\n",
5813
    "lexical_search_k_list = [1, 3, 5]"
5814
   ]
5815
  },
5816
  {
5817
   "cell_type": "code",
5818
   "execution_count": null,
5819
   "id": "f8ef2b7f-abf5-4604-9bf3-faa1d6e3d8d8",
5820
   "metadata": {
5821
    "tags": []
5822
   },
5823
   "outputs": [
5824
    {
5825
     "data": {
5826
      "application/json": {
5827
       "generated_answer": " To kill a specific serve replica, you can use the following steps:\n\n1. Get the actor handle for the replica using `ray.get_actor()` with the NAME of the replica from the output of `ray list actors`.\n2. Use `ray.kill()` to kill the replica, passing the replica handle as the first argument and setting `no_restart=True`.\n\nHere is an example:\n\n```python\nimport ray\n\n# Get the actor handle for the replica\nreplica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#RlRptP\", namespace=\"serve\")\n\n# Kill the replica\nray.kill(replica_handle, no_restart=True)\n```\n\nThis will kill the replica and prevent it from being restarted. Other replicas will continue processing requests, and the killed replica will eventually be restarted and resume serving requests.",
5828
       "question": "how do I kill a specific serve replica",
5829
       "reasoning": "The generated answer is detailed, accurate, and provides a step-by-step guide on how to kill a specific serve replica. It even includes a Python code example for better understanding. Despite the reference answer stating that there's no information available, the generated answer provides a comprehensive response to the query.",
5830
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
5831
       "score": 5,
5832
       "sources": [
5833
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5834
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5835
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
5836
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-deployments",
5837
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5838
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
5839
        "https://docs.ray.io/en/master/tune/faq.html#ray-tune-faq",
5840
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#replica-health-checking",
5841
        "https://docs.ray.io/en/master/tune/faq.html#ray-tune-faq",
5842
        "https://docs.ray.io/en/master/serve/api/doc/ray.serve.get_replica_context.html#ray-serve-get-replica-context",
5843
        "https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#autoscaling",
5844
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#worker-node-failure",
5845
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
5846
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-applications",
5847
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#replica-health-checking",
5848
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
5849
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
5850
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#head-node-failure"
5851
       ]
5852
      },
5853
      "text/plain": [
5854
       "<IPython.core.display.JSON object>"
5855
      ]
5856
     },
5857
     "metadata": {
5858
      "application/json": {
5859
       "expanded": false,
5860
       "root": "root"
5861
      }
5862
     },
5863
     "output_type": "display_data"
5864
    },
5865
    {
5866
     "name": "stderr",
5867
     "output_type": "stream",
5868
     "text": [
5869
      "100%|██████████| 177/177 [19:50<00:00,  6.72s/it]\n"
5870
     ]
5871
    }
5872
   ],
5873
   "source": [
5874
    "# Experiment\n",
5875
    "use_lexical_search = True\n",
5876
    "for lexical_search_k in lexical_search_k_list:\n",
5877
    "    experiment_name = f\"lexical-search-bm25-{lexical_search_k}\"\n",
5878
    "    experiment_names.append(experiment_name)\n",
5879
    "    run_experiment(\n",
5880
    "        experiment_name=experiment_name, \n",
5881
    "        chunk_size=CHUNK_SIZE, \n",
5882
    "        chunk_overlap=CHUNK_OVERLAP, \n",
5883
    "        num_chunks=NUM_CHUNKS,\n",
5884
    "        embedding_model_name=EMBEDDING_MODEL_NAME,\n",
5885
    "        embedding_dim=EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME],\n",
5886
    "        llm=LLM,\n",
5887
    "        evaluator=EVALUATOR,\n",
5888
    "        docs_dir=DOCS_DIR, \n",
5889
    "        experiments_dir=EXPERIMENTS_DIR, \n",
5890
    "        references_fp=REFERENCES_FILE_PATH,\n",
5891
    "        system_content=SYSTEM_CONTENT,\n",
5892
    "        use_lexical_search=use_lexical_search,\n",
5893
    "        lexical_search_k=lexical_search_k,\n",
5894
    "        num_samples=NUM_SAMPLES)"
5895
   ]
5896
  },
5897
  {
5898
   "cell_type": "code",
5899
   "execution_count": null,
5900
   "id": "aecb240f-dc95-48ab-bec9-cfea285467c2",
5901
   "metadata": {
5902
    "tags": []
5903
   },
5904
   "outputs": [
5905
    {
5906
     "name": "stdout",
5907
     "output_type": "stream",
5908
     "text": [
5909
      "lexical-search-bm25-1\n",
5910
      "  retrieval score: 0.7853107344632768\n",
5911
      "  quality score: 4.019774011299435\n",
5912
      "\n",
5913
      "lexical-search-bm25-3\n",
5914
      "  retrieval score: 0.7966101694915254\n",
5915
      "  quality score: 3.9322033898305087\n",
5916
      "\n",
5917
      "lexical-search-bm25-5\n",
5918
      "  retrieval score: 0.8022598870056498\n",
5919
      "  quality score: 3.8954802259887007\n",
5920
      "\n"
5921
     ]
5922
    },
5923
    {
5924
     "data": {
5925
      "image/png": "",
5926
      "text/plain": [
5927
       "<Figure size 1000x300 with 1 Axes>"
5928
      ]
5929
     },
5930
     "metadata": {},
5931
     "output_type": "display_data"
5932
    }
5933
   ],
5934
   "source": [
5935
    "scores = {}\n",
5936
    "for experiment_name in experiment_names:\n",
5937
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)\n",
5938
    "plot_scores(scores=scores)"
5939
   ]
5940
  },
5941
  {
5942
   "cell_type": "code",
5943
   "execution_count": null,
5944
   "id": "f2998e54-4bd2-432d-8948-ac57abe31640",
5945
   "metadata": {
5946
    "tags": []
5947
   },
5948
   "outputs": [],
5949
   "source": [
5950
    "USE_LEXICAL_SEARCH = True\n",
5951
    "LEXICAL_SEARCH_K = 1"
5952
   ]
5953
  },
5954
  {
5955
   "cell_type": "markdown",
5956
   "id": "743b6854-6a92-4201-884f-7bf83ac6bfe6",
5957
   "metadata": {},
5958
   "source": [
5959
    "It looks like lexical search improved our retrieval scores (from 0.785 at `lexical_search_k=1` to 0.802 at `lexical_search_k=5`), which was expected since more chunks are added into the context. But it didn't really help improve our quality scores, which dropped from 4.02 to 3.90 over the same range, perhaps because the additional chunks also introduced quite a bit of noise.\n",
5960
    "\n",
5961
    "**Note**: This was just one aspect (keyword matching) of lexical search that we explored but there are many other useful features such as filtering, counts, etc. It's also worth exploring how we combine the lexical search results with semantic search results."
5962
   ]
5963
  },
5964
  {
5965
   "cell_type": "markdown",
5966
   "id": "140e2b7f-f8a1-4d9c-a479-f9eefbd21451",
5967
   "metadata": {
5968
    "tags": []
5969
   },
5970
   "source": [
5971
    "# Reranking"
5972
   ]
5973
  },
5974
  {
5975
   "cell_type": "markdown",
5976
   "id": "bd99d5da-b9ee-40e2-99b8-3b43f2820618",
5977
   "metadata": {},
5978
   "source": [
5979
    "So far with all of our approaches, we've used an embedding model (+ lexical search) to identify the top k relevant chunks in our dataset. The number of chunks (k) has been a small number because we found that adding too many chunks did not help and our LLMs have restricted context lengths. However, this was all under the assumption that the top k retrieved chunks were truly the most relevant chunks and that their order was correct as well. What if increasing the number of chunks didn't help because some relevant chunks were much lower in the ordered list? After all, semantic representations, while very rich, were not trained for this specific task.\n",
5980
    "\n",
5981
    "In this section, we'll implement reranking so that we can use our semantic and lexical search methods to cast a much wider net over our dataset (retrieve many chunks) and then rerank the order based on the user's query. The intuition here is that we can account for gaps in our semantic representations with ranking specific to our use case. We'll train a supervised model that predicts which part of our [documentation](https://docs.ray.io/) is most relevant for a given user's query. We'll use this prediction to then rerank the relevant chunks so that chunks from this part of our documentation are moved to the top of the list.\n",
5982
    "\n",
5983
    "**Note**: We also experimented with [cross-encoders](https://www.sbert.net/docs/pretrained_cross-encoders.html), which process the query and the relevant contexts together with the same model. This allows for a more contextual representation compared to cosine distance but it's also more computationally expensive. So we followed a similar approach of first using the similarity distance to extract many chunks and then using the cross-encoder to rerank and choose the top k chunks. Unfortunately, this approach didn't improve our quality the way the technique below does, but it may be worth fine-tuning a cross-encoder using our synthetic embedding QA dataset.\n",
5984
    "\n",
5985
    "<img width=\"800\" src=\"https://images.ctfassets.net/xjan103pcp94/4bmoRNSzxtOyfToCtl68xq/d9727c41a3d435d1821eea5ab67c1e97/rag-based-llm-applications-reranking.png\">"
5986
   ]
5987
  },
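  {
   "cell_type": "markdown",
   "id": "sketch-rerank-by-tag",
   "metadata": {},
   "source": [
    "As a rough sketch of the reranking idea described above, the snippet below moves sources whose documentation tag matches a predicted tag to the top of the list while keeping the original order otherwise. The names `rerank_by_tag` and `toy_get_tag` are hypothetical, and the predicted tag is hard-coded here as a stand-in for the output of the supervised model.\n",
    "\n",
    "```python\n",
    "def toy_get_tag(url):\n",
    "    # First path component after /en/master/, e.g. 'serve', 'tune', 'data'\n",
    "    return url.split('/en/master/')[1].split('/')[0]\n",
    "\n",
    "def rerank_by_tag(sources, predicted_tag, get_tag):\n",
    "    # Stable partition: matching sources first, original order kept within each group\n",
    "    matching = [s for s in sources if get_tag(s) == predicted_tag]\n",
    "    rest = [s for s in sources if get_tag(s) != predicted_tag]\n",
    "    return matching + rest\n",
    "\n",
    "retrieved = [\n",
    "    'https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#head-node-failure',\n",
    "    'https://docs.ray.io/en/master/tune/faq.html#ray-tune-faq',\n",
    "    'https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-size',\n",
    "]\n",
    "reranked = rerank_by_tag(retrieved, predicted_tag='data', get_tag=toy_get_tag)\n",
    "```\n",
    "\n",
    "With a predicted tag of `data`, the Ray Data source moves to the front and the remaining sources keep their retrieval order."
   ]
  },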
5988
  {
5989
   "cell_type": "markdown",
5990
   "id": "da391621-c587-4453-b365-6f25a049b3e3",
5991
   "metadata": {
5992
    "tags": []
5993
   },
5994
   "source": [
5995
    "## Dataset"
5996
   ]
5997
  },
5998
  {
5999
   "cell_type": "markdown",
6000
   "id": "1eba3ee9-fe06-4a2d-998f-d51cdd075cd7",
6001
   "metadata": {},
6002
   "source": [
6003
    "We're going to reuse the QA dataset we created in our fine-tuning section because that dataset has questions that map to specific sections."
6004
   ]
6005
  },
6006
  {
6007
   "cell_type": "code",
6008
   "execution_count": null,
6009
   "id": "83c3d888-1f71-4561-8d9c-787ae03a5e3b",
6010
   "metadata": {
6011
    "tags": []
6012
   },
6013
   "outputs": [],
6014
   "source": [
6015
    "from collections import Counter\n",
6016
    "import pandas as pd\n",
6017
    "from sklearn.model_selection import train_test_split\n",
6018
    "from tqdm import tqdm"
6019
   ]
6020
  },
6021
  {
6022
   "cell_type": "code",
6023
   "execution_count": null,
6024
   "id": "964f649e-8ba2-4866-b6a5-af5df2d1e17f",
6025
   "metadata": {
6026
    "tags": []
6027
   },
6028
   "outputs": [],
6029
   "source": [
6030
    "def get_tag(url):\n",
6031
    "    return re.findall(r\"docs\\.ray\\.io/en/master/([^/]+)\", url)[0].split(\"#\")[0]"
6032
   ]
6033
  },
6034
  {
6035
   "cell_type": "code",
6036
   "execution_count": null,
6037
   "id": "9f82ad86-e310-4ba4-b351-c513cf78c0ab",
6038
   "metadata": {
6039
    "tags": []
6040
   },
6041
   "outputs": [
6042
    {
6043
     "data": {
6044
      "text/html": [
6045
       "<div>\n",
6046
       "<style scoped>\n",
6047
       "    .dataframe tbody tr th:only-of-type {\n",
6048
       "        vertical-align: middle;\n",
6049
       "    }\n",
6050
       "\n",
6051
       "    .dataframe tbody tr th {\n",
6052
       "        vertical-align: top;\n",
6053
       "    }\n",
6054
       "\n",
6055
       "    .dataframe thead th {\n",
6056
       "        text-align: right;\n",
6057
       "    }\n",
6058
       "</style>\n",
6059
       "<table border=\"1\" class=\"dataframe\">\n",
6060
       "  <thead>\n",
6061
       "    <tr style=\"text-align: right;\">\n",
6062
       "      <th></th>\n",
6063
       "      <th>question</th>\n",
6064
       "      <th>source</th>\n",
6065
       "      <th>tag</th>\n",
6066
       "      <th>section</th>\n",
6067
       "      <th>text</th>\n",
6068
       "    </tr>\n",
6069
       "  </thead>\n",
6070
       "  <tbody>\n",
6071
       "    <tr>\n",
6072
       "      <th>3170</th>\n",
6073
       "      <td>What is the purpose of the RayTrainReportCall...</td>\n",
6074
       "      <td>https://docs.ray.io/en/master/train/api/doc/ra...</td>\n",
6075
       "      <td>train</td>\n",
6076
       "      <td>ray.train.lightning.RayTrainReportCallback.tea...</td>\n",
6077
       "      <td>ray.train.lightning.RayTrainReportCallback.tea...</td>\n",
6078
       "    </tr>\n",
6079
       "    <tr>\n",
6080
       "      <th>3099</th>\n",
6081
       "      <td>What is the approximate number of rows that w...</td>\n",
6082
       "      <td>https://docs.ray.io/en/master/data/api/doc/ray...</td>\n",
6083
       "      <td>data</td>\n",
6084
       "      <td>ray.data.Dataset.random_sample.html#ray-data-d...</td>\n",
6085
       "      <td>ray.data.Dataset.random_sample.html#ray-data-d...</td>\n",
6086
       "    </tr>\n",
6087
       "    <tr>\n",
6088
       "      <th>1216</th>\n",
6089
       "      <td>What is the purpose of the 'value' parameter ...</td>\n",
6090
       "      <td>https://docs.ray.io/en/master/ray-core/api/doc...</td>\n",
6091
       "      <td>ray-core</td>\n",
6092
       "      <td>ray.runtime_env.RuntimeEnvConfig.fromkeys.html...</td>\n",
6093
       "      <td>ray.runtime_env.RuntimeEnvConfig.fromkeys.html...</td>\n",
6094
       "    </tr>\n",
6095
       "    <tr>\n",
6096
       "      <th>5215</th>\n",
6097
       "      <td>What is the purpose of the 'returns_to_go' va...</td>\n",
6098
       "      <td>https://docs.ray.io/en/master/rllib/package_re...</td>\n",
6099
       "      <td>rllib</td>\n",
6100
       "      <td>ray.rllib.policy.sample_batch.SampleBatch.RETU...</td>\n",
6101
       "      <td>ray.rllib.policy.sample_batch.SampleBatch.RETU...</td>\n",
6102
       "    </tr>\n",
6103
       "    <tr>\n",
6104
       "      <th>5107</th>\n",
6105
       "      <td>What is the purpose of the \"full_fetch\" param...</td>\n",
6106
       "      <td>https://docs.ray.io/en/master/rllib/package_re...</td>\n",
6107
       "      <td>rllib</td>\n",
6108
       "      <td>ray.rllib.algorithms.algorithm.Algorithm.compu...</td>\n",
6109
       "      <td>ray.rllib.algorithms.algorithm.Algorithm.compu...</td>\n",
6110
       "    </tr>\n",
6111
       "  </tbody>\n",
6112
       "</table>\n",
6113
       "</div>"
6114
      ],
6115
      "text/plain": [
6116
       "                                               question  \\\n",
6117
       "3170   What is the purpose of the RayTrainReportCall...   \n",
6118
       "3099   What is the approximate number of rows that w...   \n",
6119
       "1216   What is the purpose of the 'value' parameter ...   \n",
6120
       "5215   What is the purpose of the 'returns_to_go' va...   \n",
6121
       "5107   What is the purpose of the \"full_fetch\" param...   \n",
6122
       "\n",
6123
       "                                                 source       tag  \\\n",
6124
       "3170  https://docs.ray.io/en/master/train/api/doc/ra...     train   \n",
6125
       "3099  https://docs.ray.io/en/master/data/api/doc/ray...      data   \n",
6126
       "1216  https://docs.ray.io/en/master/ray-core/api/doc...  ray-core   \n",
6127
       "5215  https://docs.ray.io/en/master/rllib/package_re...     rllib   \n",
6128
       "5107  https://docs.ray.io/en/master/rllib/package_re...     rllib   \n",
6129
       "\n",
6130
       "                                                section  \\\n",
6131
       "3170  ray.train.lightning.RayTrainReportCallback.tea...   \n",
6132
       "3099  ray.data.Dataset.random_sample.html#ray-data-d...   \n",
6133
       "1216  ray.runtime_env.RuntimeEnvConfig.fromkeys.html...   \n",
6134
       "5215  ray.rllib.policy.sample_batch.SampleBatch.RETU...   \n",
6135
       "5107  ray.rllib.algorithms.algorithm.Algorithm.compu...   \n",
6136
       "\n",
6137
       "                                                   text  \n",
6138
       "3170  ray.train.lightning.RayTrainReportCallback.tea...  \n",
6139
       "3099  ray.data.Dataset.random_sample.html#ray-data-d...  \n",
6140
       "1216  ray.runtime_env.RuntimeEnvConfig.fromkeys.html...  \n",
6141
       "5215  ray.rllib.policy.sample_batch.SampleBatch.RETU...  \n",
6142
       "5107  ray.rllib.algorithms.algorithm.Algorithm.compu...  "
6143
      ]
6144
     },
6145
     "execution_count": null,
6146
     "metadata": {},
6147
     "output_type": "execute_result"
6148
    }
6149
   ],
6150
   "source": [
6151
    "# Load data\n",
6152
    "from pathlib import Path\n",
6153
    "df = pd.read_json(Path(ROOT_DIR, \"datasets\", \"embedding_qa.json\"))\n",
6154
    "df[\"tag\"] = df.source.map(get_tag)\n",
6155
    "df[\"section\"] = df.source.map(lambda source: source.split(\"/\")[-1])\n",
6156
    "df[\"text\"] = df[\"section\"] + \" \" + df[\"question\"]\n",
6157
    "df.sample(n=5)"
6158
   ]
6159
  },
6160
  {
6161
   "cell_type": "code",
6162
   "execution_count": null,
6163
   "id": "d264d8c6-83df-4eb2-90b8-aaec0b9c71e0",
6164
   "metadata": {
6165
    "tags": []
6166
   },
6167
   "outputs": [
6168
    {
6169
     "data": {
6170
      "text/plain": [
6171
       "Counter({'rllib': 1269,\n",
6172
       "         'tune': 979,\n",
6173
       "         'train': 697,\n",
6174
       "         'cluster': 690,\n",
6175
       "         'data': 652,\n",
6176
       "         'ray-core': 557,\n",
6177
       "         'serve': 302,\n",
6178
       "         'ray-observability': 175,\n",
6179
       "         'ray-contribute': 95,\n",
6180
       "         'workflows': 82,\n",
6181
       "         'ray-air': 74,\n",
6182
       "         'ray-more-libs': 66,\n",
6183
       "         'ray-overview': 46,\n",
6184
       "         'rllib-env.html': 17,\n",
6185
       "         'installation.html': 16,\n",
6186
       "         'tune.html': 5,\n",
6187
       "         'joblib.html': 3,\n",
6188
       "         'ray-references': 2})"
6189
      ]
6190
     },
6191
     "execution_count": null,
6192
     "metadata": {},
6193
     "output_type": "execute_result"
6194
    }
6195
   ],
6196
   "source": [
6197
    "Counter(df.tag)"
6198
   ]
6199
  },
6200
  {
6201
   "cell_type": "code",
6202
   "execution_count": null,
6203
   "id": "4bca87cb-dbea-47f7-9d0a-e2c84d1ff75b",
6204
   "metadata": {
6205
    "tags": []
6206
   },
6207
   "outputs": [
6208
    {
6209
     "data": {
6210
      "text/plain": [
6211
       "Counter({'rllib': 1269,\n",
6212
       "         'tune': 979,\n",
6213
       "         'train': 697,\n",
6214
       "         'cluster': 690,\n",
6215
       "         'data': 652,\n",
6216
       "         'ray-core': 557,\n",
6217
       "         'other': 406,\n",
6218
       "         'serve': 302,\n",
6219
       "         'ray-observability': 175})"
6220
      ]
6221
     },
6222
     "execution_count": null,
6223
     "metadata": {},
6224
     "output_type": "execute_result"
6225
    }
6226
   ],
6227
   "source": [
6228
    "# Map only what we want to keep\n",
6229
    "tags_to_keep = [\"rllib\", \"tune\", \"train\", \"cluster\", \"ray-core\", \"data\", \"serve\", \"ray-observability\"]\n",
6230
    "df[\"tag\"] = df.tag.apply(lambda x: x if x in tags_to_keep else \"other\")\n",
6231
    "Counter(df.tag)"
6232
   ]
6233
  },
6234
  {
6235
   "cell_type": "code",
6236
   "execution_count": null,
6237
   "id": "f31817a0-7bc7-42fe-8923-b3649d965d96",
6238
   "metadata": {
6239
    "tags": []
6240
   },
6241
   "outputs": [],
6242
   "source": [
6243
    "# Train and test data splits\n",
6244
    "test_size = 0.2\n",
6245
    "train_df, test_df = train_test_split(df, stratify=df.tag, test_size=test_size, random_state=1234)"
6246
   ]
6247
  },
6248
  {
6249
   "cell_type": "markdown",
6250
   "id": "da73dbc5-bb15-414a-9df0-57d25745e960",
6251
   "metadata": {},
6252
   "source": [
6253
    "## Preprocessing"
6254
   ]
6255
  },
6256
  {
6257
   "cell_type": "markdown",
6258
   "id": "1dbbbf23-9172-4865-8785-e248e523405a",
6259
   "metadata": {},
6260
   "source": [
6261
    "We'll start by creating some preprocessing functions to better represent our data. For example, our documentation has many variables that are camel cased (ex. `RayDeepSpeedStrategy`). When a tokenizer is used on this, we often lose the individual tokens that we know to be useful and, instead, random subtokens are created.\n",
6262
    "\n",
6263
    "**Note**: we didn't omnisciently know to create these unique preprocessing functions! This is all a result of methodical iteration. We train a model → view incorrect data points → view how the data was represented (ex. subtokenization) → update preprocessing → iterate ↺"
6264
   ]
6265
  },
6266
  {
6267
   "cell_type": "code",
6268
   "execution_count": null,
6269
   "id": "22c56a71-fad6-4393-bee8-fe44f835cbf1",
6270
   "metadata": {
6271
    "tags": []
6272
   },
6273
   "outputs": [],
6274
   "source": [
6275
    "import re\n",
6276
    "from transformers import BertTokenizer"
6277
   ]
6278
  },
6279
  {
6280
   "cell_type": "code",
6281
   "execution_count": null,
6282
   "id": "6fe5a48a-4a1a-43c6-9a69-a72fa03f357f",
6283
   "metadata": {
6284
    "tags": []
6285
   },
6286
   "outputs": [],
6287
   "source": [
6288
    "def split_camel_case_in_sentences(sentences):\n",
6289
    "    def split_camel_case_word(word):\n",
6290
    "        return re.sub(\"([a-z0-9])([A-Z])\", r\"\\1 \\2\", word)\n",
6291
    "    processed_sentences = []\n",
6292
    "    for sentence in sentences:\n",
6293
    "        processed_words = []   \n",
6294
    "        for word in sentence.split():\n",
6295
    "            processed_words.extend(split_camel_case_word(word).split())\n",
6296
    "        processed_sentences.append(\" \".join(processed_words))\n",
6297
    "    return processed_sentences"
6298
   ]
6299
  },
6300
  {
6301
   "cell_type": "code",
6302
   "execution_count": null,
6303
   "id": "a3485909-e074-4796-bf4b-ed388f56c61f",
6304
   "metadata": {
6305
    "tags": []
6306
   },
6307
   "outputs": [],
6308
   "source": [
6309
    "# Tokenizer\n",
6310
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")"
6311
   ]
6312
  },
6313
  {
6314
   "cell_type": "code",
6315
   "execution_count": null,
6316
   "id": "9ddbe8bf-a974-4d84-b452-d15ba39e762f",
6317
   "metadata": {
6318
    "tags": []
6319
   },
6320
   "outputs": [],
6321
   "source": [
    "def preprocess(texts):\n",
    "    # Put a space before ?.,! when it follows a word character and is not followed by whitespace\n",
    "    texts = [re.sub(r'(?<=\\w)([?.,!])(?!\\s)', r' \\1', text) for text in texts]\n",
    "    # Replace separators (_ - # .) with spaces and drop the .html suffix\n",
    "    texts = [text.replace(\"_\", \" \").replace(\"-\", \" \").replace(\"#\", \" \").replace(\".html\", \"\").replace(\".\", \" \") for text in texts]\n",
    "    texts = split_camel_case_in_sentences(texts)  # camelcase\n",
    "    texts = [tokenizer.tokenize(text) for text in texts]  # subtokens\n",
    "    # Join the subtokens back into one string per input text\n",
    "    texts = [\" \".join(text) for text in texts]\n",
    "    return texts"
6329
   ]
6330
  },
6331
  {
6332
   "cell_type": "code",
6333
   "execution_count": null,
6334
   "id": "0d610206-5376-4cf4-af50-e5121a14c45a",
6335
   "metadata": {
6336
    "tags": []
6337
   },
6338
   "outputs": [
6339
    {
6340
     "name": "stdout",
6341
     "output_type": "stream",
6342
     "text": [
6343
      "['ray deep speed strategy']\n",
6344
      "['what is the default batch size for map batch ##es ?']\n"
6345
     ]
6346
    }
6347
   ],
6348
   "source": [
6349
    "print (preprocess([\"RayDeepSpeedStrategy\"]))\n",
6350
    "print (preprocess([\"What is the default batch_size for map_batches?\"]))"
6351
   ]
6352
  },
6353
  {
6354
   "cell_type": "markdown",
6355
   "id": "ea43347c-f025-4933-b358-6ad797e3d8e6",
6356
   "metadata": {
6357
    "tags": []
6358
   },
6359
   "source": [
6360
    "## Model"
6361
   ]
6362
  },
6363
  {
6364
   "cell_type": "code",
6365
   "execution_count": null,
6366
   "id": "d9b6b234-4868-4b76-8f30-483513c59137",
6367
   "metadata": {
6368
    "tags": []
6369
   },
6370
   "outputs": [],
6371
   "source": [
6372
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
6373
    "from sklearn.linear_model import LogisticRegression\n",
6374
    "from sklearn.pipeline import Pipeline\n",
6375
    "from sklearn.preprocessing import FunctionTransformer"
6376
   ]
6377
  },
6378
  {
6379
   "cell_type": "code",
6380
   "execution_count": null,
6381
   "id": "be4f8194-5018-427d-86ed-bd1536cfcb58",
6382
   "metadata": {
6383
    "tags": []
6384
   },
6385
   "outputs": [
6386
    {
6387
     "data": {
6388
      "text/html": [
6389
       "<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>Pipeline(steps=[(&#x27;preprocess&#x27;,\n",
6390
       "                 FunctionTransformer(func=&lt;function preprocess at 0x7fa070971bd0&gt;)),\n",
6391
       "                (&#x27;vectorizer&#x27;, TfidfVectorizer()),\n",
6392
       "                (&#x27;classifier&#x27;, LogisticRegression(multi_class=&#x27;multinomial&#x27;))])</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">Pipeline</label><div class=\"sk-toggleable__content\"><pre>Pipeline(steps=[(&#x27;preprocess&#x27;,\n",
6393
       "                 FunctionTransformer(func=&lt;function preprocess at 0x7fa070971bd0&gt;)),\n",
6394
       "                (&#x27;vectorizer&#x27;, TfidfVectorizer()),\n",
6395
       "                (&#x27;classifier&#x27;, LogisticRegression(multi_class=&#x27;multinomial&#x27;))])</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">FunctionTransformer</label><div class=\"sk-toggleable__content\"><pre>FunctionTransformer(func=&lt;function preprocess at 0x7fa070971bd0&gt;)</pre></div></div></div><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">TfidfVectorizer</label><div class=\"sk-toggleable__content\"><pre>TfidfVectorizer()</pre></div></div></div><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(multi_class=&#x27;multinomial&#x27;)</pre></div></div></div></div></div></div></div>"
6396
      ],
6397
      "text/plain": [
6398
       "Pipeline(steps=[('preprocess',\n",
6399
       "                 FunctionTransformer(func=<function preprocess at 0x7fa070971bd0>)),\n",
6400
       "                ('vectorizer', TfidfVectorizer()),\n",
6401
       "                ('classifier', LogisticRegression(multi_class='multinomial'))])"
6402
      ]
6403
     },
6404
     "execution_count": null,
6405
     "metadata": {},
6406
     "output_type": "execute_result"
6407
    }
6408
   ],
6409
   "source": [
6410
    "# Train classifier\n",
6411
    "from rag.rerank import preprocess  # for pickle\n",
6412
    "reranker = Pipeline([\n",
6413
    "    (\"preprocess\", FunctionTransformer(preprocess)),\n",
6414
    "    (\"vectorizer\", TfidfVectorizer(lowercase=True)),\n",
6415
    "    (\"classifier\", LogisticRegression(multi_class=\"multinomial\", solver=\"lbfgs\"))\n",
6416
    "])\n",
6417
    "reranker.fit(train_df[\"text\"].tolist(), train_df[\"tag\"].tolist())"
6418
   ]
6419
  },
6420
  {
6421
   "cell_type": "markdown",
6422
   "id": "47825a4a-6b76-42ef-ad07-403710d25f97",
6423
   "metadata": {},
6424
   "source": [
    "**Note**: we also trained a BERT classifier and, while its performance was better than our logistic classifier's, these large networks suffer from [overconfidence](https://arxiv.org/abs/1706.04599), so we can't use a threshold-based approach as we do below. Without the threshold approach (where we only rerank when the reranker is truly confident), the quality score of our application does not improve."
6426
   ]
6427
  },
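  {
   "cell_type": "markdown",
   "id": "added-softmax-overconfidence-sketch",
   "metadata": {},
   "source": [
    "To make the overconfidence point concrete, here is a minimal sketch (not the BERT classifier we trained) that uses made-up logits. Scaling the same logits up, which is roughly what an overconfident network does, pushes the top softmax probability toward 1, so a fixed threshold stops filtering anything. The `softmax` helper, the logits, the scale factor and the threshold are all illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def softmax(logits):\n",
    "    # Numerically stable softmax over a list of floats.\n",
    "    m = max(logits)\n",
    "    exps = [math.exp(x - m) for x in logits]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "# Hypothetical logits for an ambiguous question over three tags.\n",
    "logits = [1.2, 1.0, 0.4]\n",
    "calibrated = softmax(logits)\n",
    "# Same ranking, but with logits scaled 10x (an overconfident model).\n",
    "overconfident = softmax([10 * x for x in logits])\n",
    "\n",
    "threshold = 0.7  # illustrative value only\n",
    "print(f'calibrated top prob:    {max(calibrated):.2f}')\n",
    "print(f'overconfident top prob: {max(overconfident):.2f}')\n",
    "print('calibrated passes threshold:   ', max(calibrated) >= threshold)\n",
    "print('overconfident passes threshold:', max(overconfident) >= threshold)\n",
    "```\n",
    "\n",
    "Both models rank the tags identically, but only the overconfident one clears the threshold on this ambiguous input, which is why thresholding cannot separate its sure predictions from its unsure ones."
   ]
  },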
6428
  {
6429
   "cell_type": "markdown",
6430
   "id": "fd6ce6ca-ea56-4a27-af9d-758175170ebd",
6431
   "metadata": {},
6432
   "source": [
6433
    "## Inference"
6434
   ]
6435
  },
6436
  {
6437
   "cell_type": "code",
6438
   "execution_count": null,
6439
   "id": "bb1139d3-e43f-4bf4-8bac-43d3ee1ec30e",
6440
   "metadata": {
6441
    "tags": []
6442
   },
6443
   "outputs": [],
6444
   "source": [
6445
    "import pickle"
6446
   ]
6447
  },
6448
  {
6449
   "cell_type": "code",
6450
   "execution_count": null,
6451
   "id": "a70ca619-0bf3-4254-9a94-77701ed2c771",
6452
   "metadata": {
6453
    "tags": []
6454
   },
6455
   "outputs": [],
6456
   "source": [
6457
    "# Save\n",
6458
    "reranker_fp = Path(EFS_DIR, \"reranker.pkl\")\n",
6459
    "with open(reranker_fp, \"wb\") as file:\n",
6460
    "    pickle.dump(reranker, file)"
6461
   ]
6462
  },
6463
  {
6464
   "cell_type": "code",
6465
   "execution_count": null,
6466
   "id": "98d2450d-e035-4389-8236-b3553594972e",
6467
   "metadata": {
6468
    "tags": []
6469
   },
6470
   "outputs": [],
6471
   "source": [
6472
    "# Load\n",
6473
    "reranker_fp = Path(EFS_DIR, \"reranker.pkl\")\n",
6474
    "with open(reranker_fp, \"rb\") as file:\n",
6475
    "    reranker = pickle.load(file)"
6476
   ]
6477
  },
6478
  {
6479
   "cell_type": "code",
6480
   "execution_count": null,
6481
   "id": "f8d1d74e-3ee7-445c-ab7d-386766f18b18",
6482
   "metadata": {
6483
    "tags": []
6484
   },
6485
   "outputs": [],
6486
   "source": [
6487
    "def custom_predict(inputs, classifier, threshold=0.2, other_label=\"other\"):\n",
6488
    "    y_pred = []\n",
6489
    "    for item in classifier.predict_proba(inputs):\n",
6490
    "        prob = max(item)\n",
6491
    "        index = item.argmax()\n",
6492
    "        if prob >= threshold:\n",
6493
    "            pred = classifier.classes_[index]\n",
6494
    "        else:\n",
6495
    "            pred = other_label\n",
6496
    "        y_pred.append(pred)\n",
6497
    "    return y_pred"
6498
   ]
6499
  },
6500
  {
6501
   "cell_type": "code",
6502
   "execution_count": null,
6503
   "id": "081e53d4-ba98-4323-85df-d5f03f1807f0",
6504
   "metadata": {
6505
    "tags": []
6506
   },
6507
   "outputs": [
6508
    {
6509
     "data": {
6510
      "text/plain": [
6511
       "'train'"
6512
      ]
6513
     },
6514
     "execution_count": null,
6515
     "metadata": {},
6516
     "output_type": "execute_result"
6517
    }
6518
   ],
6519
   "source": [
6520
    "# Test inference\n",
6521
    "question = \"traning with deepspeed\"\n",
6522
    "custom_predict([question], classifier=reranker)[0]"
6523
   ]
6524
  },
6525
  {
6526
   "cell_type": "markdown",
6527
   "id": "54a19983-1caf-4b11-8093-5e2a54cf7bbf",
6528
   "metadata": {},
6529
   "source": [
6530
    "## Evaluation"
6531
   ]
6532
  },
6533
  {
6534
   "cell_type": "code",
6535
   "execution_count": null,
6536
   "id": "987ee635-cdf0-490a-89bc-1795660b7cc2",
6537
   "metadata": {
6538
    "tags": []
6539
   },
6540
   "outputs": [],
6541
   "source": [
6542
    "import json\n",
6543
    "import matplotlib.pyplot as plt\n",
6544
    "import numpy as np\n",
6545
    "import seaborn as sns\n",
6546
    "from sklearn.metrics import confusion_matrix, precision_recall_fscore_support"
6547
   ]
6548
  },
6549
  {
6550
   "cell_type": "code",
6551
   "execution_count": null,
6552
   "id": "4baa30a5-2f44-4b6b-90a6-87b149b88bdc",
6553
   "metadata": {
6554
    "tags": []
6555
   },
6556
   "outputs": [],
6557
   "source": [
6558
    "# Evaluation\n",
6559
    "metrics = {}\n",
6560
    "y_test = test_df[\"tag\"]\n",
6561
    "y_pred = custom_predict(inputs=test_df[\"text\"], classifier=reranker)"
6562
   ]
6563
  },
6564
  {
6565
   "cell_type": "code",
6566
   "execution_count": null,
6567
   "id": "96b9bf23-2091-4b0a-bd0d-8dc13c03eb62",
6568
   "metadata": {
6569
    "tags": []
6570
   },
6571
   "outputs": [
6572
    {
6573
     "data": {
6574
      "image/png": "",
6575
      "text/plain": [
6576
       "<Figure size 600x400 with 2 Axes>"
6577
      ]
6578
     },
6579
     "metadata": {},
6580
     "output_type": "display_data"
6581
    }
6582
   ],
6583
   "source": [
6584
    "# Confusion matrix\n",
    "cm = confusion_matrix(y_test, y_pred, labels=reranker.classes_)\n",
6586
    "plt.figure(figsize=(6,4))\n",
6587
    "sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", \n",
6588
    "            xticklabels=reranker.classes_, yticklabels=reranker.classes_)\n",
6589
    "plt.ylabel(\"Actual\")\n",
6590
    "plt.xlabel(\"Predicted\")\n",
6591
    "plt.title(\"Confusion Matrix\")\n",
6592
    "plt.xticks(rotation=65)\n",
6593
    "plt.show()"
6594
   ]
6595
  },
6596
  {
6597
   "cell_type": "code",
6598
   "execution_count": null,
6599
   "id": "03057b05-ee59-47d8-b6ba-7b7f0b2c76d9",
6600
   "metadata": {
6601
    "tags": []
6602
   },
6603
   "outputs": [
6604
    {
6605
     "name": "stdout",
6606
     "output_type": "stream",
6607
     "text": [
6608
      "{\n",
6609
      "    \"precision\": 0.9168129573272782,\n",
6610
      "    \"recall\": 0.9171029668411868,\n",
6611
      "    \"f1\": 0.9154520876579969,\n",
6612
      "    \"num_samples\": 1146.0\n",
6613
      "}\n"
6614
     ]
6615
    }
6616
   ],
6617
   "source": [
6618
    "# Metrics\n",
6619
    "overall_metrics = precision_recall_fscore_support(y_test, y_pred, average=\"weighted\")\n",
6620
    "metrics[\"precision\"] = overall_metrics[0]\n",
6621
    "metrics[\"recall\"] = overall_metrics[1]\n",
6622
    "metrics[\"f1\"] = overall_metrics[2]\n",
6623
    "metrics[\"num_samples\"] = np.float64(len(y_test))\n",
6624
    "print (json.dumps(metrics, indent=4))"
6625
   ]
6626
  },
6627
  {
6628
   "cell_type": "markdown",
6629
   "id": "40364870-7f40-457e-9bbd-bd51f8619022",
6630
   "metadata": {},
6631
   "source": [
6632
    "Precision will be the most important metric here because we will only apply reranking if the softmax probability of our predicted tag is above the threshold value."
6633
   ]
6634
  },
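  {
   "cell_type": "markdown",
   "id": "added-gated-precision-sketch",
   "metadata": {},
   "source": [
    "One way to see why precision matters with a threshold is to measure it only on the predictions that clear the threshold, along with how many predictions that is (coverage). The sketch below uses a handful of made-up `(probability, predicted_tag, true_tag)` triples rather than our test set, and `gated_precision` is a hypothetical helper, not part of `rerank.py`.\n",
    "\n",
    "```python\n",
    "def gated_precision(records, threshold):\n",
    "    # Keep only predictions whose top probability clears the threshold;\n",
    "    # these are the only ones that would trigger reranking.\n",
    "    kept = [(pred, true) for prob, pred, true in records if prob >= threshold]\n",
    "    if not kept:\n",
    "        return None, 0.0\n",
    "    correct = sum(pred == true for pred, true in kept)\n",
    "    precision = correct / len(kept)\n",
    "    coverage = len(kept) / len(records)\n",
    "    return precision, coverage\n",
    "\n",
    "# Made-up (top probability, predicted tag, true tag) triples.\n",
    "records = [\n",
    "    (0.95, 'train', 'train'),\n",
    "    (0.85, 'data', 'data'),\n",
    "    (0.60, 'tune', 'tune'),\n",
    "    (0.55, 'serve', 'ray-core'),\n",
    "    (0.35, 'rllib', 'tune'),\n",
    "    (0.30, 'data', 'data'),\n",
    "]\n",
    "\n",
    "for threshold in [0.0, 0.5, 0.9]:\n",
    "    precision, coverage = gated_precision(records, threshold)\n",
    "    print(f'threshold={threshold:.1f} precision={precision:.2f} coverage={coverage:.2f}')\n",
    "```\n",
    "\n",
    "On this toy data, raising the threshold trades coverage for precision: fewer queries get reranked, but the ones that do are more likely to be reranked toward the right tag."
   ]
  },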
6635
  {
6636
   "cell_type": "markdown",
6637
   "id": "9e2189b4-3b76-4770-b70e-f841a4133899",
6638
   "metadata": {},
6639
   "source": [
6640
    "## Testing"
6641
   ]
6642
  },
6643
  {
6644
   "cell_type": "markdown",
6645
   "id": "def64f1e-3c02-498f-b233-92894abc1892",
6646
   "metadata": {},
6647
   "source": [
    "Besides a metric-based evaluation, we also want to assess how our model performs on some minimum functionality tests. These basic sanity checks need to pass regardless of what type of model we use."
6649
   ]
6650
  },
6651
  {
6652
   "cell_type": "code",
6653
   "execution_count": null,
6654
   "id": "f6535156-fcc7-4bac-a1ba-4628171e7d27",
6655
   "metadata": {
6656
    "tags": []
6657
   },
6658
   "outputs": [],
6659
   "source": [
6660
    "def predict_proba(question, classifier):\n",
6661
    "    y_prob = classifier.predict_proba([question])\n",
6662
    "    zipped = list(zip(y_prob[0], classifier.classes_))\n",
6663
    "    return sorted(zipped, key=lambda x: x[0], reverse=True)"
6664
   ]
6665
  },
6666
  {
6667
   "cell_type": "code",
6668
   "execution_count": null,
6669
   "id": "70cf7edb-7eb9-4692-9b60-57cea8aa0ede",
6670
   "metadata": {
6671
    "tags": []
6672
   },
6673
   "outputs": [
6674
    {
6675
     "name": "stdout",
6676
     "output_type": "stream",
6677
     "text": [
6678
      "[85% train]: How to train a train an LLM using DeepSpeed? → ['how to train a train an ll ##m using deep speed ?']\n",
6679
      "[67% serve]: How does autoscaling work in a Ray Serve application? → ['how does auto ##sca ##ling work in a ray serve application ?']\n",
6680
      "[38% tune]: How to find the best checkpoint from the trial directory? → ['how to find the best checkpoint from the trial directory ?']\n",
6681
      "[88% data]: How do I avoid my dataset shuffling during a ray.data.map_batches? → ['how do i avoid my data ##set shuffling during a ray data map batch ##es ?']\n",
6682
      "[37% ray-core]: how to push a custom module to ray which is using by Actor? → ['how to push a custom module to ray which is using by actor ?']\n",
6683
      "[29% other]: How would you compare Spark, Ray, Dask? → ['how would you compare spark , ray , das ##k ?']\n",
6684
      "[59% ray-observability]: How do I enable Ray debug logs? → ['how do i enable ray de ##bu ##g logs ?']\n",
6685
      "[74% rllib]: How do I set a maximum episode length when training with Rllib → ['how do i set a maximum episode length when training with r ##lli ##b']\n"
6686
     ]
6687
    }
6688
   ],
6689
   "source": [
6690
    "# Basic tests\n",
6691
    "tests = [\n",
6692
    "    {\"question\": \"How to train a train an LLM using DeepSpeed?\", \"tag\": \"train\"},\n",
6693
    "    {\"question\": \"How does autoscaling work in a Ray Serve application?\", \"tag\": \"serve\"},\n",
6694
    "    {\"question\": \"How to find the best checkpoint from the trial directory?\", \"tag\": \"tune\"},\n",
6695
    "    {\"question\": \"How do I avoid my dataset shuffling during a ray.data.map_batches?\", \"tag\": \"data\"},\n",
6696
    "    {\"question\": \"how to push a custom module to ray which is using by Actor?\", \"tag\": \"ray-core\"},\n",
6697
    "    {\"question\": \"How would you compare Spark, Ray, Dask?\", \"tag\": \"other\"},\n",
6698
    "    {\"question\": \"How do I enable Ray debug logs?\", \"tag\": \"ray-observability\"},\n",
6699
    "    {\"question\": \"How do I set a maximum episode length when training with Rllib\", \"tag\": \"rllib\"}]\n",
6700
    "for test in tests:\n",
6701
    "    question = test[\"question\"]\n",
6702
    "    prob, pred = predict_proba(question=test[\"question\"], classifier=reranker)[0]\n",
6703
    "    print (f\"[{prob*100:.0f}% {pred}]: {question} → {preprocess([question])}\")\n",
6704
    "    assert (pred == test[\"tag\"])"
6705
   ]
6706
  },
6707
  {
6708
   "cell_type": "markdown",
6709
   "id": "f640f3ee-bc42-4f96-a0fb-f11c4015f2b4",
6710
   "metadata": {},
6711
   "source": [
6712
    "## Experiment"
6713
   ]
6714
  },
6715
  {
6716
   "cell_type": "markdown",
6717
   "id": "cd7be07c-fee9-4b04-bb31-d58e07a17700",
6718
   "metadata": {},
6719
   "source": [
    "Now we're ready to apply our reranking model post retrieval using these steps:\n",
    "1. Increase the number of retrieved chunks (we can experiment with this) so that we can apply reranking to yield a smaller subset (`num_chunks`). The intuition here is that we'll use semantic and lexical search to retrieve N chunks (N > k) and then use reranking to reorder them and keep the top k.\n",
    "2. Perform generation using the top k retrieved chunks."
6723
   ]
6724
  },
6725
  {
6726
   "cell_type": "markdown",
6727
   "id": "99ce7c04-9da0-43e3-b5db-2c59aff3fe87",
6728
   "metadata": {},
6729
   "source": [
6730
    "We're going to alter our `QueryAgent` class directly to include reranking:\n",
6731
    "\n",
6732
    "```python\n",
6733
    "class QueryAgent():\n",
    "    def __call__(self, query, rerank_threshold=0.3, rerank_k=7, **kwargs):\n",
6735
    "        # Rerank\n",
6736
    "        if self.reranker:\n",
6737
    "            predicted_tag = custom_predict(\n",
6738
    "                inputs=[query], classifier=self.reranker, threshold=rerank_threshold)[0]\n",
6739
    "            if predicted_tag != \"other\":\n",
6740
    "                sources = [item[\"source\"] for item in context_results]\n",
6741
    "                reranked_indices = get_reranked_indices(sources, predicted_tag)\n",
6742
    "                context_results = [context_results[i] for i in reranked_indices]\n",
6743
    "            context_results = context_results[:rerank_k]\n",
6744
    "```\n",
6745
    "\n",
6746
    "And the rerank function (along with other utilities specific to reranking) will now be placed in `rerank.py`:\n",
6747
    "```python\n",
6748
    "def get_reranked_indices(sources, predicted_tag):\n",
6749
    "    tags = [get_tag(source) for source in sources]\n",
6750
    "    reranked_indices = sorted(range(len(tags)), \n",
6751
    "        key=lambda i: (tags[i] != predicted_tag, i))\n",
6752
    "    return reranked_indices\n",
6753
    "```"
6754
   ]
6755
  },
6756
  {
6757
   "cell_type": "markdown",
6758
   "id": "61f82f9f-694d-4c49-b39d-8cab4c285b13",
6759
   "metadata": {},
6760
   "source": [
6761
    "And with that, let's use our query agent augmented with reranking on an evaluation run."
6762
   ]
6763
  },
6764
  {
6765
   "cell_type": "code",
6766
   "execution_count": null,
6767
   "id": "d2713153-da91-46e3-b1c7-5d06903d550d",
6768
   "metadata": {
6769
    "tags": []
6770
   },
6771
   "outputs": [],
6772
   "source": [
6773
    "experiment_names = []\n",
6774
    "rerank_threshold_list = [0, 0.3, 0.5, 0.7, 0.9]"
6775
   ]
6776
  },
6777
  {
6778
   "cell_type": "code",
6779
   "execution_count": null,
6780
   "id": "55f34f33-3764-49f1-8aed-7536c426ef9b",
6781
   "metadata": {
6782
    "tags": []
6783
   },
6784
   "outputs": [
6785
    {
6786
     "data": {
6787
      "application/json": {
6788
       "generated_answer": " To kill a specific serve replica, you can use the following steps:\n\n1. Get the handle to the replica using `ray.get_actor()` function with the NAME of the replica from the output of `ray list actors` command.\n2. Use `ray.kill()` function to kill the replica, passing the replica handle as an argument. You can also specify `no_restart=True` to prevent the replica from automatically restarting.\n\nHere is an example:\n```python\nimport ray\n\n# Get the handle to the replica\nreplica_handle = ray.get_actor(\"SERVE_REPLICA::SleepyPid#RlRptP\", namespace=\"serve\")\n\n# Kill the replica\nray.kill(replica_handle, no_restart=True)\n```\nNote: While the replica is restarted, the other replicas can continue processing requests. Eventually, the replica restarts and continues serving requests.",
6789
       "question": "how do I kill a specific serve replica",
6790
       "reasoning": "The generated answer is detailed, accurate, and provides a step-by-step guide on how to kill a specific serve replica. It even includes a code example for better understanding. The reference answer does not provide any useful information, making the generated answer superior.",
6791
       "reference_answer": "The context does not provide information on how to kill a specific serve replica.",
6792
       "score": 5,
6793
       "sources": [
6794
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
6795
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
6796
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
6797
        "https://docs.ray.io/en/master/serve/advanced-guides/dev-workflow.html#local-development-with-http-requests",
6798
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-deployments",
6799
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#deployment-replica-failure",
6800
        "https://docs.ray.io/en/master/serve/scaling-and-resource-allocation.html#autoscaling",
6801
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#worker-node-failure",
6802
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
6803
        "https://docs.ray.io/en/master/serve/api/index.html#delete-api-serve-applications",
6804
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#replica-health-checking",
6805
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#httpproxy-failure",
6806
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#serve-controller-failure",
6807
        "https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html#head-node-failure"
6808
       ]
6809
      },
6810
      "text/plain": [
6811
       "<IPython.core.display.JSON object>"
6812
      ]
6813
     },
6814
     "metadata": {
6815
      "application/json": {
6816
       "expanded": false,
6817
       "root": "root"
6818
      }
6819
     },
6820
     "output_type": "display_data"
6821
    },
6822
    {
6823
     "name": "stderr",
6824
     "output_type": "stream",
6825
     "text": [
6826
      "100%|██████████| 177/177 [17:03<00:00,  5.78s/it]\n"
6827
     ]
6828
    }
6829
   ],
6830
   "source": [
6831
    "# Experiment\n",
6832
    "use_reranking = True\n",
6833
    "num_chunks = 30  # increase number of chunks\n",
6834
    "rerank_k = NUM_CHUNKS + LEXICAL_SEARCH_K\n",
6835
    "for rerank_threshold in rerank_threshold_list:\n",
6836
    "    experiment_name = f\"rerank-{rerank_threshold}\"\n",
6837
    "    experiment_names.append(experiment_name)\n",
6838
    "    run_experiment(\n",
6839
    "        experiment_name=experiment_name, \n",
6840
    "        chunk_size=CHUNK_SIZE, \n",
6841
    "        chunk_overlap=CHUNK_OVERLAP, \n",
6842
    "        num_chunks=num_chunks,\n",
6843
    "        embedding_model_name=EMBEDDING_MODEL_NAME,\n",
6844
    "        embedding_dim=EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME],\n",
6845
    "        llm=LLM,\n",
6846
    "        evaluator=EVALUATOR,\n",
6847
    "        docs_dir=DOCS_DIR, \n",
6848
    "        experiments_dir=EXPERIMENTS_DIR, \n",
6849
    "        references_fp=REFERENCES_FILE_PATH,\n",
6850
    "        use_lexical_search=USE_LEXICAL_SEARCH,\n",
6851
    "        lexical_search_k=LEXICAL_SEARCH_K,\n",
6852
    "        use_reranking=use_reranking,\n",
6853
    "        rerank_threshold=rerank_threshold,\n",
6854
    "        rerank_k=rerank_k,\n",
6855
    "        num_samples=NUM_SAMPLES)"
6856
   ]
6857
  },
6858
  {
6859
   "cell_type": "code",
6860
   "execution_count": null,
6861
   "id": "166b11ce-df5b-46a4-949c-16d7df45ab6c",
6862
   "metadata": {
6863
    "tags": []
6864
   },
6865
   "outputs": [
6866
    {
6867
     "name": "stdout",
6868
     "output_type": "stream",
6869
     "text": [
6870
      "rerank-0\n",
6871
      "  retrieval score: 0.7457627118644068\n",
6872
      "  quality score: 3.9124293785310735\n",
6873
      "\n",
6874
      "rerank-0.3\n",
6875
      "  retrieval score: 0.7570621468926554\n",
6876
      "  quality score: 3.906779661016949\n",
6877
      "\n",
6878
      "rerank-0.5\n",
6879
      "  retrieval score: 0.7740112994350282\n",
6880
      "  quality score: 3.9915254237288136\n",
6881
      "\n",
6882
      "rerank-0.7\n",
6883
      "  retrieval score: 0.7853107344632768\n",
6884
      "  quality score: 3.9858757062146895\n",
6885
      "\n",
6886
      "rerank-0.9\n",
6887
      "  retrieval score: 0.7853107344632768\n",
6888
      "  quality score: 4.022598870056497\n",
6889
      "\n"
6890
     ]
6891
    },
6892
    {
6893
     "data": {
6894
      "image/png": "",
6895
      "text/plain": [
6896
       "<Figure size 1000x300 with 1 Axes>"
6897
      ]
6898
     },
6899
     "metadata": {},
6900
     "output_type": "display_data"
6901
    }
6902
   ],
6903
   "source": [
6904
    "scores = {}\n",
6905
    "for experiment_name in experiment_names:\n",
6906
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR)\n",
6907
    "plot_scores(scores=scores)"
6908
   ]
6909
  },
6910
  {
6911
   "cell_type": "code",
6912
   "execution_count": null,
6913
   "id": "65429a21-ee23-4d0a-9dee-60a4a4b09d09",
6914
   "metadata": {
6915
    "tags": []
6916
   },
6917
   "outputs": [],
6918
   "source": [
6919
    "original_num_chunks = NUM_CHUNKS\n",
6920
    "NUM_CHUNKS = 30\n",
6921
    "USE_RERANKING = True\n",
6922
    "RERANK_THRESHOLD = 0.9\n",
6923
    "RERANK_K = original_num_chunks + LEXICAL_SEARCH_K"
6924
   ]
6925
  },
6926
  {
6927
   "cell_type": "markdown",
6928
   "id": "628021f8-e79b-46e5-a8e0-ca7deb3be134",
6929
   "metadata": {},
6930
   "source": [
    "**Note**: there is still a lot more to experiment with for reranking (increasing the initial `num_chunks`, adding lexical search results *after* reranking, weighted reranking where we promote the top N classes, etc.)."
6932
   ]
6933
  },
6934
  {
6935
   "cell_type": "code",
6936
   "execution_count": null,
6937
   "id": "4193ac04-bf77-4ae6-8273-a613b91f81a4",
6938
   "metadata": {
6939
    "tags": []
6940
   },
6941
   "outputs": [
6942
    {
6943
     "data": {
6944
      "text/plain": [
6945
       "[('gpt-4-1106-preview',\n",
6946
       "  {'retrieval_score': 0.7288135593220338, 'quality_score': 4.209039548022599}),\n",
6947
       " ('rerank-0.9',\n",
6948
       "  {'retrieval_score': 0.7853107344632768, 'quality_score': 4.022598870056497}),\n",
6949
       " ('lexical-search-bm25-1',\n",
6950
       "  {'retrieval_score': 0.7853107344632768, 'quality_score': 4.019774011299435})]"
6951
      ]
6952
     },
6953
     "execution_count": null,
6954
     "metadata": {},
6955
     "output_type": "execute_result"
6956
    }
6957
   ],
6958
   "source": [
6959
    "# Top experiments (by quality and retrieval score)\n",
6960
    "top_n = 3\n",
6961
    "experiment_results = {}\n",
6962
    "all_experiments = [d for d in Path(EXPERIMENTS_DIR, \"evaluations\").iterdir() if d.is_file()]\n",
6963
    "for experiment_fp in all_experiments:\n",
6964
    "    with open(str(experiment_fp), \"r\") as fp:\n",
6965
    "        results = json.load(fp)\n",
6966
    "    experiment_results[results[\"config\"][\"experiment_name\"]] = {\n",
6967
    "        \"retrieval_score\": results[\"retrieval_score\"], \n",
6968
    "        \"quality_score\": results[\"quality_score\"]}\n",
6969
    "sorted(experiment_results.items(), key=lambda i: (i[1].get(\"quality_score\", float('inf')), i[1].get(\"retrieval_score\")), reverse=True)[:top_n]"
6970
   ]
6971
  },
6972
  {
6973
   "cell_type": "markdown",
6974
   "id": "f916ea05-4b99-4024-9211-ed35bb8ac9dc",
6975
   "metadata": {
6976
    "tags": [],
6977
    "toc-hr-collapsed": true
6978
   },
6979
   "source": [
6980
    "# Cost analysis"
6981
   ]
6982
  },
6983
  {
6984
   "cell_type": "markdown",
6985
   "id": "2c420359-c506-4257-b0d8-8b826134d001",
6986
   "metadata": {},
6987
   "source": [
    "Beyond performance, we also want to evaluate the cost of our configurations (especially given the high price points of larger LLMs). We’re going to break this down into prompt and sampled pricing. The prompt size is the number of tokens in our system, assistant and user contents (which includes the retrieved contexts), and the sampled size is the number of tokens the LLM generated in its response. We estimate the tokens in the retrieved contexts from the chunk size, assuming ~5 characters per word and 1 token ≈ 3/4 of a word."
6989
   ]
6990
  },
6991
  {
6992
   "cell_type": "markdown",
6993
   "id": "96c145a5-e97e-4811-ba12-c1a37d71c171",
6994
   "metadata": {
6995
    "tags": []
6996
   },
6997
   "source": [
6998
    "**Note**: Our OSS models are served via [Anyscale Endpoints](https://endpoints.anyscale.com/)."
6999
   ]
7000
  },
7001
  {
7002
   "cell_type": "code",
7003
   "execution_count": null,
7004
   "id": "7f020340-9357-4a48-976b-cdac2e467351",
7005
   "metadata": {
7006
    "tags": []
7007
   },
7008
   "outputs": [],
7009
   "source": [
7010
    "def cost_analysis(experiment_name):\n",
7011
    "    eval_fp = Path(ROOT_DIR, EXPERIMENTS_DIR, \"evaluations\", f\"{experiment_name}_{EVALUATOR}.json\")\n",
7012
    "    with open(eval_fp, \"r\") as fp:\n",
7013
    "        d = json.load(fp)\n",
7014
    "    num_samples = len(d[\"results\"])\n",
7015
    "    prompt_size, sampled_size = 0, 0\n",
7016
    "    for result in d[\"results\"]:\n",
7017
    "        prompt_size += get_num_tokens(result[\"question\"]) + \\\n",
7018
    "            ((CHUNK_SIZE/5)*(4/3) * original_num_chunks)  # 5 chars / word, 1 token = 3/4 word\n",
7019
    "        sampled_size += get_num_tokens(result[\"generated_answer\"])\n",
7020
    "    total_cost = PRICING[experiment_name][\"prompt\"]/1e6 * prompt_size + PRICING[experiment_name][\"sampled\"]/1e6 * sampled_size\n",
7021
    "    avg_cost = total_cost / num_samples\n",
7022
    "    \n",
7023
    "    print (experiment_name)\n",
7024
    "    print (f\"  prompted tokens (avg): {int(prompt_size/num_samples)}\")\n",
7025
    "    print (f\"  sampled tokens (avg): {int(sampled_size/num_samples)}\")\n",
7026
    "    print (f\"  total cost: ${total_cost:.4f}\")\n",
7027
    "    print (f\"  avg cost: ${avg_cost:.4f}\")\n",
7028
    "    print ()\n",
7029
    "    return avg_cost"
7030
   ]
7031
  },
7032
  {
7033
   "cell_type": "code",
7034
   "execution_count": null,
7035
   "id": "ee95574b-b355-4d5a-979b-467657bbd959",
7036
   "metadata": {
7037
    "tags": []
7038
   },
7039
   "outputs": [
7040
    {
7041
     "name": "stdout",
7042
     "output_type": "stream",
7043
     "text": [
7044
      "gpt-3.5-turbo\n",
7045
      "  prompted tokens (avg): 1695\n",
7046
      "  sampled tokens (avg): 73\n",
7047
      "  total cost: $0.4761\n",
7048
      "  avg cost: $0.0027\n",
7049
      "\n",
7050
      "gpt-4\n",
7051
      "  prompted tokens (avg): 1695\n",
7052
      "  sampled tokens (avg): 111\n",
7053
      "  total cost: $10.1884\n",
7054
      "  avg cost: $0.0576\n",
7055
      "\n",
7056
      "gpt-4-1106-preview\n",
7057
      "  prompted tokens (avg): 1695\n",
7058
      "  sampled tokens (avg): 200\n",
7059
      "  total cost: $4.0641\n",
7060
      "  avg cost: $0.0230\n",
7061
      "\n",
7062
      "llama-2-7b-chat-hf\n",
7063
      "  prompted tokens (avg): 1695\n",
7064
      "  sampled tokens (avg): 248\n",
7065
      "  total cost: $0.0516\n",
7066
      "  avg cost: $0.0003\n",
7067
      "\n",
7068
      "llama-2-13b-chat-hf\n",
7069
      "  prompted tokens (avg): 1695\n",
7070
      "  sampled tokens (avg): 226\n",
7071
      "  total cost: $0.0851\n",
7072
      "  avg cost: $0.0005\n",
7073
      "\n",
7074
      "llama-2-70b-chat-hf\n",
7075
      "  prompted tokens (avg): 1695\n",
7076
      "  sampled tokens (avg): 218\n",
7077
      "  total cost: $0.3388\n",
7078
      "  avg cost: $0.0019\n",
7079
      "\n",
7080
      "codellama-34b-instruct-hf\n",
7081
      "  prompted tokens (avg): 1695\n",
7082
      "  sampled tokens (avg): 187\n",
7083
      "  total cost: $0.3333\n",
7084
      "  avg cost: $0.0019\n",
7085
      "\n",
7086
      "mistral-7b-instruct-v0.1\n",
7087
      "  prompted tokens (avg): 1695\n",
7088
      "  sampled tokens (avg): 135\n",
7089
      "  total cost: $0.0486\n",
7090
      "  avg cost: $0.0003\n",
7091
      "\n",
7092
      "mixtral-8x7b-instruct-v0.1\n",
7093
      "  prompted tokens (avg): 1695\n",
7094
      "  sampled tokens (avg): 180\n",
7095
      "  total cost: $0.1661\n",
7096
      "  avg cost: $0.0009\n",
7097
      "\n"
7098
     ]
7099
    }
7100
   ],
7101
   "source": [
7102
    "scores = {}\n",
7103
    "for experiment_name in PRICING.keys():\n",
7104
    "    scores[experiment_name] = print_experiment(experiment_name, EXPERIMENTS_DIR, verbose=False)\n",
7105
    "    scores[experiment_name][\"average_cost\"] = cost_analysis(experiment_name=experiment_name)"
7106
   ]
7107
  },
7108
  {
7109
   "cell_type": "code",
7110
   "execution_count": null,
7111
   "id": "e58b0513-f630-47c2-a30f-aa900bfaa66c",
7112
   "metadata": {
7113
    "tags": []
7114
   },
7115
   "outputs": [
7116
    {
7117
     "data": {
7118
      "image/png": "",
7119
      "text/plain": [
7120
       "<Figure size 1000x400 with 1 Axes>"
7121
      ]
7122
     },
7123
     "metadata": {},
7124
     "output_type": "display_data"
7125
    }
7126
   ],
7127
   "source": [
7128
    "# Prepare data for plotting\n",
7129
    "experiment_names = list(scores.keys())\n",
7130
    "average_costs = [scores[experiment_name][\"average_cost\"] for experiment_name in experiment_names]\n",
7131
    "quality_scores = [scores[experiment_name][\"quality_score\"] for experiment_name in experiment_names]\n",
7132
    "sorted_data = sorted(zip(quality_scores, average_costs, experiment_names))\n",
7133
    "\n",
7134
    "# Plotting\n",
7135
    "plt.figure(figsize=(10, 4))\n",
7136
    "for i, (q_score, avg_cost, exp_name) in enumerate(sorted_data):\n",
7137
    "    plt.scatter(q_score, avg_cost, label=exp_name)\n",
7138
    "    ha = \"left\"  # Horizontal alignment\n",
7139
    "    va = \"bottom\"  # Vertical alignment\n",
7140
    "    offset = 0.02  # Small offset for labels\n",
7141
    "    if i > 0 and abs(sorted_data[i - 1][0] - q_score) < offset: ha = \"left\"  # Check left neighbor\n",
7142
    "    if i < len(sorted_data) - 1 and abs(sorted_data[i + 1][0] - q_score) < offset: ha = \"right\"  # Check right neighbor\n",
7143
    "    plt.text(q_score, avg_cost, exp_name, ha=ha, va=va, )\n",
7144
    "\n",
7145
    "# Add labels and title\n",
7146
    "plt.xlabel(\"Quality Score\")\n",
7147
    "plt.ylabel(\"Average cost / query ($)\")\n",
7148
    "plt.legend(title=\"Experiments\", loc=\"upper left\", fontsize=8)\n",
7149
    "plt.yscale(\"log\")\n",
7150
    "plt.show()"
7151
   ]
7152
  },
7153
  {
7154
   "cell_type": "markdown",
7155
   "id": "d03dabb0-5a77-42da-b6c7-a65e8a5920a7",
7156
   "metadata": {},
7157
   "source": [
7158
    "**Note**: This cost analysis is performed with our original experiments before lexical search, reranking, etc. since we haven't run experiments with these improvements on the other OSS and closed source LLMs yet."
7159
   ]
7160
  },
7161
  {
7162
   "cell_type": "markdown",
7163
   "id": "d223b361-4d14-4ea1-804b-0eb92eb1224e",
7164
   "metadata": {
7165
    "tags": []
7166
   },
7167
   "source": [
7168
    "# Routing"
7169
   ]
7170
  },
7171
  {
7172
   "cell_type": "markdown",
7173
   "id": "2463c757-f54e-4905-bd8b-d00a32c2ec7d",
7174
   "metadata": {},
7175
   "source": [
    "It seems that the most performant LLM, `gpt-4-turbo` (`gpt-4-1106-preview`), is also very expensive, while our OSS LLM (`mixtral-8x7b-instruct-v0.1`) is very close in quality but ~25X cheaper per query."
7177
   ]
7178
  },
7179
  {
7180
   "cell_type": "code",
7181
   "execution_count": null,
7182
   "id": "391b4116-4312-48ef-aa3d-2184d46ed64b",
7183
   "metadata": {
7184
    "tags": []
7185
   },
7186
   "outputs": [
7187
    {
7188
     "name": "stdout",
7189
     "output_type": "stream",
7190
     "text": [
7191
      "Cost multiplier compared to mixtral-8x7b-instruct-v0.1\n",
7192
      "  gpt-3.5-turbo: 2.87X\n",
7193
      "  gpt-4: 61.35X\n",
7194
      "  gpt-4-1106-preview: 24.47X\n",
7195
      "  llama-2-7b-chat-hf: 0.31X\n",
7196
      "  llama-2-13b-chat-hf: 0.51X\n",
7197
      "  llama-2-70b-chat-hf: 2.04X\n",
7198
      "  codellama-34b-instruct-hf: 2.01X\n",
7199
      "  mistral-7b-instruct-v0.1: 0.29X\n",
7200
      "  mixtral-8x7b-instruct-v0.1: 1.00X\n"
7201
     ]
7202
    }
7203
   ],
7204
   "source": [
7205
    "# Cost multiplier\n",
7206
    "chosen_model = LLM.split('/')[-1].lower()\n",
7207
    "print (f\"Cost multiplier compared to {chosen_model}\")\n",
7208
    "for model in scores:\n",
7209
    "    print (f\"  {model}: {scores[model]['average_cost']/scores[chosen_model]['average_cost']:.2f}X\")"
7210
   ]
7211
  },
7212
  {
7213
   "cell_type": "markdown",
7214
   "id": "29727066-6a67-4ff9-a4d4-b4517f5014c4",
7215
   "metadata": {},
7216
   "source": [
7217
    "However, we want to be able to serve the most performant and cost-effective solution. We can close this gap in performance between open source and proprietary models by routing queries to the right model according to the complexity or topic of the query. For example, in our application, open source models perform really well on simple queries where the answer can be easily inferred from the retrieved context. However, the OSS models fall short for queries that involve reasoning, numbers or code examples. To identify the appropriate LLM to use, we can train a classifier that takes the query and routes it to the best model.\n",
7218
    "\n",
7219
    "<img width=\"800\" src=\"https://images.ctfassets.net/xjan103pcp94/7FWrvPPlIdz5fs8wQgxLFz/fdae368044275028f0544a3d252fcfe4/image15.png\">\n",
7220
    "\n",
7221
    "**Note**: In part 2 of this series, we’ll fine-tune our embedding models and OSS LLMs to make them even more performant.\n",
7222
    "\n",
    "In order to implement this, we hand-annotated a [dataset of 1.8k queries](https://github.com/ray-project/llm-applications/blob/main/datasets/routing-dataset-train.jsonl) according to which model (`gpt-4` (label=0) or `codellama-34b` (label=1)) would be appropriate. By default we route to the OSS LLM, and only if the query needs more advanced capabilities do we send it to `gpt-4`. We then evaluate the performance of the model on a test set that has been scored with an evaluator."
7224
   ]
7225
  },
7226
  {
7227
   "cell_type": "markdown",
7228
   "id": "1a731efa-7990-4856-a9e0-250be556505b",
7229
   "metadata": {},
7230
   "source": [
7231
    "## Dataset"
7232
   ]
7233
  },
7234
  {
7235
   "cell_type": "markdown",
7236
   "id": "d506e529-518b-4b37-bd81-4d5e7532b528",
7237
   "metadata": {},
7238
   "source": [
    "Let's first load the training dataset [routing-dataset-train.jsonl](https://github.com/ray-project/llm-applications/blob/main/datasets/routing-dataset-train.jsonl):"
7240
   ]
7241
  },
7242
  {
7243
   "cell_type": "code",
7244
   "execution_count": null,
7245
   "id": "86e4a927-ab91-4189-97d2-35bd045c0360",
7246
   "metadata": {
7247
    "tags": []
7248
   },
7249
   "outputs": [],
7250
   "source": [
7251
    "with open(Path(ROOT_DIR, \"datasets\", \"routing-dataset-train.jsonl\")) as f:\n",
7252
    "    records = [json.loads(l) for l in f]\n",
7253
    "    texts = [record[\"question\"] for record in records]\n",
7254
    "    labels = [record[\"target\"] for record in records]"
7255
   ]
7256
  },
7257
  {
7258
   "cell_type": "code",
7259
   "execution_count": null,
7260
   "id": "9d9a4d97-a39b-409e-9b45-ae609c9f651b",
7261
   "metadata": {
7262
    "tags": []
7263
   },
7264
   "outputs": [
7265
    {
7266
     "name": "stdout",
7267
     "output_type": "stream",
7268
     "text": [
7269
      "Question for gpt-4:\n",
7270
      " {'question': 'if I am inside of a anyscale cluster how do I get my cluster-env-build-id', 'target': 0}\n",
7271
      "\n",
7272
      "Question for codellama-34b:\n",
7273
      " {'question': 'what is num_samples in tune?', 'target': 1}\n"
7274
     ]
7275
    }
7276
   ],
7277
   "source": [
7278
    "# Sample records (1 = can be handled by OSS LLM)\n",
7279
    "print (\"Question for gpt-4:\\n\", [record for record in records if record[\"target\"] == 0][0]) \n",
7280
    "print (\"\\nQuestion for OSS LLM:\\n\", [record for record in records if record[\"target\"] == 1][0])"
7281
   ]
7282
  },
7283
  {
7284
   "cell_type": "markdown",
7285
   "id": "ea873f02-d4e4-4bfa-b583-7d750c1144ff",
7286
   "metadata": {},
7287
   "source": [
7288
    "## Modeling"
7289
   ]
7290
  },
7291
  {
7292
   "cell_type": "code",
7293
   "execution_count": null,
7294
   "id": "25404ad8-8300-4e6d-ac98-4117e4608128",
7295
   "metadata": {
7296
    "tags": []
7297
   },
7298
   "outputs": [],
7299
   "source": [
7300
    "import pickle\n",
7301
    "from sklearn.pipeline import Pipeline\n",
7302
    "from sklearn.feature_extraction.text import CountVectorizer\n",
7303
    "from sklearn.linear_model import LogisticRegression"
7304
   ]
7305
  },
7306
  {
7307
   "cell_type": "code",
7308
   "execution_count": null,
7309
   "id": "3021c4fa-caf1-489d-90d5-cf1890214c0d",
7310
   "metadata": {
7311
    "tags": []
7312
   },
7313
   "outputs": [
7314
    {
7315
     "data": {
      "text/plain": [
7322
       "Pipeline(steps=[('vectorizer', CountVectorizer()),\n",
7323
       "                ('classifier', LogisticRegression(multi_class='multinomial'))])"
7324
      ]
7325
     },
7326
     "execution_count": null,
7327
     "metadata": {},
7328
     "output_type": "execute_result"
7329
    }
7330
   ],
7331
   "source": [
7332
    "# Train classifier\n",
7333
    "router = Pipeline([\n",
7334
    "    (\"vectorizer\", CountVectorizer()),\n",
7335
    "    (\"classifier\", LogisticRegression(multi_class=\"multinomial\", solver=\"lbfgs\"))\n",
7336
    "])\n",
7337
    "router.fit(texts, labels)"
7338
   ]
7339
  },
7340
  {
7341
   "cell_type": "markdown",
7342
   "id": "f7570fc7-158b-4799-87fb-33c1298b9594",
7343
   "metadata": {},
7344
   "source": [
    "**Note**: we also trained a BERT classifier, but it did not perform as well as our logistic regression model given the small size of our dataset."
7346
   ]
7347
  },
7348
  {
7349
   "cell_type": "markdown",
7350
   "id": "68dd8c1e-08b7-43ee-829a-f3c6c2e8d8b7",
7351
   "metadata": {},
7352
   "source": [
7353
    "## Inference"
7354
   ]
7355
  },
7356
  {
7357
   "cell_type": "code",
7358
   "execution_count": null,
7359
   "id": "5ca1ab5d-9a68-4062-a9f9-0dc267b1abe6",
7360
   "metadata": {
7361
    "tags": []
7362
   },
7363
   "outputs": [],
7364
   "source": [
7365
    "# Router file path\n",
7366
    "router_fp = Path(EFS_DIR, \"router.pkl\")"
7367
   ]
7368
  },
7369
  {
7370
   "cell_type": "code",
7371
   "execution_count": null,
7372
   "id": "b3b9169b-2534-4c55-bf4b-471a4e524ed9",
7373
   "metadata": {
7374
    "tags": []
7375
   },
7376
   "outputs": [],
7377
   "source": [
7378
    "# Save\n",
7379
    "with open(router_fp, \"wb\") as file:\n",
7380
    "    pickle.dump(router, file)"
7381
   ]
7382
  },
7383
  {
7384
   "cell_type": "code",
7385
   "execution_count": null,
7386
   "id": "82f07fc8-3cf6-4189-a8a4-593860a73acf",
7387
   "metadata": {
7388
    "tags": []
7389
   },
7390
   "outputs": [],
7391
   "source": [
7392
    "# Load\n",
7393
    "with open(router_fp, \"rb\") as file:\n",
7394
    "    router = pickle.load(file)"
7395
   ]
7396
  },
7397
  {
7398
   "cell_type": "code",
7399
   "execution_count": null,
7400
   "id": "57536c4d-34d8-48e1-9d86-b83a5d766409",
7401
   "metadata": {
7402
    "tags": []
7403
   },
7404
   "outputs": [
7405
    {
7406
     "name": "stdout",
7407
     "output_type": "stream",
7408
     "text": [
7409
      "[OSS] What is the default batch size for map_batches?\n"
7410
     ]
7411
    }
7412
   ],
7413
   "source": [
7414
    "# Inference\n",
7415
    "router_map = {0: \"GPT-4\", 1: \"OSS\"}\n",
7416
    "query = \"What is the default batch size for map_batches?\"\n",
7417
    "print(f\"[{router_map[router.predict([query])[0]]}]\", query)"
7418
   ]
7419
  },
7420
  {
7421
   "cell_type": "markdown",
7422
   "id": "77cc5449-5cfc-469d-9f84-42f4ebff4f3d",
7423
   "metadata": {},
7424
   "source": [
7425
    "## Evaluation"
7426
   ]
7427
  },
7428
  {
7429
   "cell_type": "markdown",
7430
   "id": "6cc6a7b8-427b-433c-98e3-c82b6ffaf042",
7431
   "metadata": {},
7432
   "source": [
7433
    "Now let's evaluate the performance on the [test dataset](https://github.com/ray-project/llm-applications/blob/main/datasets/routing-dataset-test.jsonl):"
7434
   ]
7435
  },
7436
  {
7437
   "cell_type": "code",
7438
   "execution_count": null,
7439
   "id": "073421c5-662c-4774-b256-103e11d81361",
7440
   "metadata": {
7441
    "tags": []
7442
   },
7443
   "outputs": [],
7444
   "source": [
7445
    "import numpy as np\n",
7446
    "import pickle\n",
7447
    "from sklearn.metrics import accuracy_score, classification_report\n",
7448
    "from sklearn.metrics import precision_recall_fscore_support"
7449
   ]
7450
  },
7451
  {
7452
   "cell_type": "code",
7453
   "execution_count": null,
7454
   "id": "15b20f4e-57dc-42a5-9fce-012ce2c672e1",
7455
   "metadata": {
7456
    "tags": []
7457
   },
7458
   "outputs": [],
7459
   "source": [
7460
    "# Load\n",
7461
    "with open(router_fp, \"rb\") as file:\n",
7462
    "    router = pickle.load(file)"
7463
   ]
7464
  },
7465
  {
7466
   "cell_type": "code",
7467
   "execution_count": null,
7468
   "id": "87181fbe-b56f-485e-8089-34a2b1de3f49",
7469
   "metadata": {
7470
    "tags": []
7471
   },
7472
   "outputs": [],
7473
   "source": [
7474
    "with open(Path(ROOT_DIR, \"datasets\", \"routing-dataset-test.jsonl\")) as f:\n",
7475
    "    records = [json.loads(line) for line in f]\n",
7476
    "    texts = [record[\"question\"] for record in records]\n",
7477
    "    y_test = [record[\"target\"] for record in records]\n",
7478
    "    score_test = [record[\"score\"] for record in records]"
7479
   ]
7480
  },
7481
  {
7482
   "cell_type": "code",
7483
   "execution_count": null,
7484
   "id": "4c2b3f1d-be35-4103-9aeb-ae8d562b3ee2",
7485
   "metadata": {
7486
    "tags": []
7487
   },
7488
   "outputs": [],
7489
   "source": [
7490
    "# Predictions\n",
7491
    "y_pred = router.predict(texts)"
7492
   ]
7493
  },
7494
  {
7495
   "cell_type": "code",
7496
   "execution_count": null,
7497
   "id": "beddd2f0-80a3-4c3d-a66c-7977fb1c43bc",
7498
   "metadata": {
7499
    "tags": []
7500
   },
7501
   "outputs": [
7502
    {
7503
     "name": "stdout",
7504
     "output_type": "stream",
7505
     "text": [
7506
      "{\n",
7507
      "    \"precision\": 0.9191264005602239,\n",
7508
      "    \"recall\": 0.9285714285714286,\n",
7509
      "    \"f1\": 0.9226432439812495,\n",
7510
      "    \"num_samples\": 574.0\n",
7511
      "}\n"
7512
     ]
7513
    }
7514
   ],
7515
   "source": [
7516
    "metrics = {}\n",
7517
    "performance = precision_recall_fscore_support(y_test, y_pred, average=\"weighted\")\n",
7518
    "metrics[\"precision\"] = performance[0]\n",
7519
    "metrics[\"recall\"] = performance[1]\n",
7520
    "metrics[\"f1\"] = performance[2]\n",
7521
    "metrics[\"num_samples\"] = np.float64(len(y_test))\n",
7522
    "print(json.dumps(metrics, indent=4))"
7523
   ]
7524
  },
7525
  {
7526
   "cell_type": "code",
7527
   "execution_count": null,
7528
   "id": "9deb50e5-5702-4740-a92c-602af580f572",
7529
   "metadata": {
7530
    "tags": []
7531
   },
7532
   "outputs": [
7533
    {
7534
     "name": "stdout",
7535
     "output_type": "stream",
7536
     "text": [
7537
      "# total samples 574\n",
7538
      "# samples for OSS models: 544 (94.8%)\n",
7539
      "Performance on samples predicted for codellama/CodeLlama-34b-Instruct-hf: 3.87\n",
7540
      "Performance on samples predicted for gpt-4: 3.55\n"
7541
     ]
7542
    }
7543
   ],
7544
   "source": [
7545
    "print(\"# total samples\", len(y_pred))\n",
7546
    "print(f\"# samples for OSS models: {sum(y_pred)} ({sum(y_pred)*100/len(y_pred):.1f}%)\")\n",
7547
    "print(f\"Performance on samples predicted for {LLM}: {np.mean([score_test[i] for i, p in enumerate(y_pred) if p]):.2f}\")\n",
7548
    "print(f\"Performance on samples predicted for gpt-4: {np.mean([score_test[i] for i, p in enumerate(y_pred) if not p]):.2f}\")"
7549
   ]
7550
  },
7551
  {
7552
   "cell_type": "markdown",
7553
   "id": "ec755df6-2568-445c-a9b0-7f3fa30fa9d4",
7554
   "metadata": {},
7555
   "source": [
7556
    "**Note**: For our dataset, a small logistic regression model is good enough to perform the routing. If your use case is more complex, consider training a more expressive model, such as a BERT-based classifier, to perform the classification. These models are still small enough that they wouldn’t introduce too much latency. Be sure to check out this [guide](https://github.com/GokuMohandas/Made-With-ML) if you want to learn how to train and deploy supervised deep learning models."
7557
   ]
7558
  },
7559
  {
7560
   "cell_type": "markdown",
7561
   "id": "a84b35b1-1342-47cb-a203-b8f24926cfab",
7562
   "metadata": {
7563
    "tags": []
7564
   },
7565
   "source": [
7566
    "# Serving"
7567
   ]
7568
  },
7569
  {
7570
   "cell_type": "markdown",
7571
   "id": "014a7e88-a8c1-4114-ae2a-0651363595da",
7572
   "metadata": {},
7573
   "source": [
7574
    "Now we're ready to start serving our Ray Assistant using our best configuration. We're going to use [Ray Serve](https://docs.ray.io/en/latest/serve/index.html) with [FastAPI](https://fastapi.tiangolo.com/) to develop and scale our service. First, we'll define some data structures like `Query` and `Answer` to represent the inputs and outputs of our service. We will also use a small function, `load_index`, to load our index (this assumes that the respective SQL dump file already exists). Finally, we'll wrap our `QueryAgent` in a deployment that serves `POST` requests containing the query. We can serve our agent at any scale we wish using the [@serve.deployment](https://docs.ray.io/en/latest/serve/api/doc/ray.serve.Deployment.html) decorator, where we can specify the number of replicas, compute resources, etc."
7575
   ]
7576
  },
7577
  {
7578
   "cell_type": "code",
7579
   "execution_count": null,
7580
   "id": "31cbea09-318f-49b7-864c-ba7b8072f5be",
7581
   "metadata": {
7582
    "tags": []
7583
   },
7584
   "outputs": [],
7585
   "source": [
7586
    "import pickle\n",
7587
    "import requests\n",
7588
    "from typing import List"
7589
   ]
7590
  },
7591
  {
7592
   "cell_type": "code",
7593
   "execution_count": null,
7594
   "id": "123e53fa-8b11-4781-91fd-0947a242e0d5",
7595
   "metadata": {
7596
    "tags": []
7597
   },
7598
   "outputs": [],
7599
   "source": [
7600
    "from fastapi import FastAPI\n",
7601
    "from pydantic import BaseModel\n",
7602
    "from ray import serve\n",
7603
    "from rag.generate import QueryAgent\n",
7604
    "from rag.index import load_index"
7605
   ]
7606
  },
7607
  {
7608
   "cell_type": "code",
7609
   "execution_count": null,
7610
   "id": "cab7b615-d03b-4fb2-b099-3a02dd7e8c2f",
7611
   "metadata": {
7612
    "tags": []
7613
   },
7614
   "outputs": [],
7615
   "source": [
7616
    "# Initialize application\n",
7617
    "app = FastAPI()"
7618
   ]
7619
  },
7620
  {
7621
   "cell_type": "code",
7622
   "execution_count": null,
7623
   "id": "e6d7ab2c-4946-4c6e-92da-08450ebc5a68",
7624
   "metadata": {
7625
    "tags": []
7626
   },
7627
   "outputs": [],
7628
   "source": [
7629
    "class Query(BaseModel):\n",
7630
    "    query: str"
7631
   ]
7632
  },
7633
  {
7634
   "cell_type": "code",
7635
   "execution_count": null,
7636
   "id": "00d51f9a-fff8-40d6-92d5-0d9984e311cc",
7637
   "metadata": {
7638
    "tags": []
7639
   },
7640
   "outputs": [],
7641
   "source": [
7642
    "class Answer(BaseModel):\n",
7643
    "    question: str\n",
7644
    "    sources: List[str]\n",
7645
    "    answer: str\n",
7646
    "    llm: str"
7647
   ]
7648
  },
7649
  {
7650
   "cell_type": "code",
7651
   "execution_count": null,
7652
   "id": "24aca110-8bff-4097-aaee-33a70043c73c",
7653
   "metadata": {
7654
    "tags": []
7655
   },
7656
   "outputs": [],
   "source": [
    "@serve.deployment(num_replicas=1, ray_actor_options={\"num_cpus\": 6, \"num_gpus\": 1})\n",
7667
    "@serve.ingress(app)\n",
7668
    "class RayAssistantDeployment:\n",
7669
    "    def __init__(self, chunk_size, chunk_overlap, num_chunks, \n",
7670
    "                 embedding_model_name, embedding_dim,\n",
7671
    "                 use_lexical_search, lexical_search_k, \n",
7672
    "                 use_reranking, rerank_threshold, rerank_k,\n",
7673
    "                 llm, sql_dump_fp=None):\n",
7674
    "        \n",
7675
    "        # Set up\n",
7676
    "        chunks = load_index(\n",
7677
    "            embedding_model_name=embedding_model_name, \n",
7678
    "            embedding_dim=embedding_dim, \n",
7679
    "            chunk_size=chunk_size, \n",
7680
    "            chunk_overlap=chunk_overlap,\n",
7681
    "            sql_dump_fp=sql_dump_fp,\n",
7682
    "        )\n",
7683
    "\n",
7684
    "        # Lexical index\n",
7685
    "        lexical_index = None\n",
7686
    "        self.lexical_search_k = lexical_search_k\n",
7687
    "        if use_lexical_search:\n",
7688
    "            texts = [re.sub(r\"[^a-zA-Z0-9]\", \" \", chunk[1]).lower().split() for chunk in chunks]\n",
7689
    "            lexical_index = BM25Okapi(texts)\n",
7690
    "\n",
7691
    "        # Reranker\n",
7692
    "        reranker = None\n",
7693
    "        self.rerank_threshold = rerank_threshold\n",
7694
    "        self.rerank_k = rerank_k\n",
7695
    "        if use_reranking:\n",
7696
    "            reranker_fp = Path(EFS_DIR, \"reranker.pkl\")\n",
7697
    "            with open(reranker_fp, \"rb\") as file:\n",
7698
    "                reranker = pickle.load(file)\n",
7699
    "\n",
7700
    "        # Query agent\n",
7701
    "        self.num_chunks = num_chunks\n",
7702
    "        system_content = \"Answer the query using the context provided. Be succinct. \" \\\n",
7703
    "            \"Contexts are organized in a list of dictionaries [{'text': <context>}, {'text': <context>}, ...]. \" \\\n",
7704
    "            \"Feel free to ignore any contexts in the list that don't seem relevant to the query. \"\n",
7705
    "        self.oss_agent = QueryAgent(\n",
7706
    "            embedding_model_name=embedding_model_name,\n",
7707
    "            chunks=chunks,\n",
7708
    "            lexical_index=lexical_index,\n",
7709
    "            reranker=reranker,\n",
7710
    "            llm=llm,\n",
7711
    "            max_context_length=MAX_CONTEXT_LENGTHS[llm],\n",
7712
    "            system_content=system_content)\n",
7713
    "        self.gpt_agent = QueryAgent(\n",
7714
    "            embedding_model_name=embedding_model_name,\n",
7715
    "            chunks=chunks,\n",
7716
    "            lexical_index=lexical_index,\n",
7717
    "            reranker=reranker,\n",
7718
    "            llm=\"gpt-4\",\n",
7719
    "            max_context_length=MAX_CONTEXT_LENGTHS[\"gpt-4\"],\n",
7720
    "            system_content=system_content)\n",
7721
    "\n",
7722
    "        # Router\n",
7723
    "        router_fp = Path(EFS_DIR, \"router.pkl\")\n",
7724
    "        with open(router_fp, \"rb\") as file:\n",
7725
    "            self.router = pickle.load(file)\n",
7726
    "\n",
7727
    "    @app.post(\"/query\")\n",
7728
    "    def query(self, query: Query) -> Answer:\n",
7729
    "        use_oss_agent = self.router.predict([query.query])[0]\n",
7730
    "        agent = self.oss_agent if use_oss_agent else self.gpt_agent\n",
7731
    "        result = agent(\n",
7732
    "            query=query.query, num_chunks=self.num_chunks, \n",
7733
    "            lexical_search_k=self.lexical_search_k, \n",
7734
    "            rerank_threshold=self.rerank_threshold, \n",
7735
    "            rerank_k=self.rerank_k, \n",
7736
    "            stream=False)\n",
7737
    "        return Answer.parse_obj(result)"
7738
   ]
7739
  },
7740
  {
7741
   "cell_type": "markdown",
7742
   "id": "79d4b6eb-cdbc-46cb-ab3d-adc246f9da9c",
7743
   "metadata": {},
7744
   "source": [
7745
    "**Note**: As we can see, Ray Serve makes [model composition](https://docs.ray.io/en/latest/serve/model_composition.html) extremely easy, and we could continue to make this more fine-grained. For example, we could train a classifier to discern between queries for `mixtral-8x7b-instruct-v0.1`, `codellama-34b-instruct-hf` (for code generation) and `gpt-4` (for highly complex queries). We can also use streaming end-to-end to reduce the time a user has to wait for the answer. Check out the `/stream` method in `rag/serve.py`."
7746
   ]
7747
  },
7748
  {
7749
   "cell_type": "code",
7750
   "execution_count": null,
7751
   "id": "80a097e8-2cbb-4193-87cd-eca48a143301",
7752
   "metadata": {
7753
    "tags": []
7754
   },
7755
   "outputs": [
7756
    {
7757
     "name": "stderr",
7758
     "output_type": "stream",
7759
     "text": [
7760
      "\u001b[2m\u001b[36m(ServeController pid=213991)\u001b[0m WARNING 2023-11-09 22:10:03,839 controller 213991 application_state.py:663 - The deployments ['RayAssistantDeployment'] are UNHEALTHY.\n",
7761
      "\u001b[2m\u001b[36m(ServeController pid=213991)\u001b[0m INFO 2023-11-09 22:10:03,942 controller 213991 deployment_state.py:1390 - Deploying new version of deployment RayAssistantDeployment in application 'default'.\n",
7762
      "\u001b[2m\u001b[36m(ServeController pid=213991)\u001b[0m INFO 2023-11-09 22:10:04,044 controller 213991 deployment_state.py:1679 - Adding 1 replica to deployment RayAssistantDeployment in application 'default'.\n",
7763
      "\u001b[2m\u001b[36m(ServeReplica:default:RayAssistantDeployment pid=217138)\u001b[0m Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
7764
      "\u001b[2m\u001b[36m(ServeReplica:default:RayAssistantDeployment pid=217138)\u001b[0m Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
7765
      "\u001b[2m\u001b[36m(ServeReplica:default:RayAssistantDeployment pid=217138)\u001b[0m INFO 2023-11-09 22:10:27,795 RayAssistantDeployment default#RayAssistantDeployment#ZuxuxM 2dacbeb8-f591-48d6-a4fe-b86043d7ceff / default replica.py:749 - __CALL__ OK 0.2ms\n",
7766
      "2023-11-09 22:10:27,934\tINFO router.py:1132 -- Using router <class 'ray.serve._private.router.PowerOfTwoChoicesReplicaScheduler'>.\n",
7767
      "2023-11-09 22:10:27,940\tINFO router.py:473 -- Got updated replicas for deployment 'RayAssistantDeployment' in application 'default': {'default#RayAssistantDeployment#ZuxuxM'}.\n"
7768
     ]
7769
    },
7770
    {
7771
     "data": {
7772
      "text/plain": [
7773
       "RayServeSyncHandle(deployment='RayAssistantDeployment')"
7774
      ]
7775
     },
7776
     "execution_count": null,
7777
     "metadata": {},
7778
     "output_type": "execute_result"
7779
    }
7780
   ],
7781
   "source": [
7782
    "# Deploy the Ray Serve application.\n",
7783
    "deployment = RayAssistantDeployment.bind(\n",
7784
    "    chunk_size=CHUNK_SIZE,\n",
7785
    "    chunk_overlap=CHUNK_OVERLAP,\n",
7786
    "    num_chunks=NUM_CHUNKS,\n",
7787
    "    embedding_model_name=EMBEDDING_MODEL_NAME,\n",
7788
    "    embedding_dim=EMBEDDING_DIMENSIONS[EMBEDDING_MODEL_NAME],\n",
7789
    "    use_lexical_search=USE_LEXICAL_SEARCH,\n",
7790
    "    lexical_search_k=LEXICAL_SEARCH_K, \n",
7791
    "    use_reranking=USE_RERANKING, \n",
7792
    "    rerank_threshold=RERANK_THRESHOLD, \n",
7793
    "    rerank_k=RERANK_K,\n",
7794
    "    llm=LLM)\n",
7795
    "serve.run(deployment, route_prefix=\"/\")"
7796
   ]
7797
  },
7798
  {
7799
   "cell_type": "code",
7800
   "execution_count": null,
7801
   "id": "0f37bd52-ee23-4451-950b-a4badc1c2e52",
7802
   "metadata": {
7803
    "tags": []
7804
   },
7805
   "outputs": [
7806
    {
7807
     "data": {
7808
      "text/plain": [
7809
       "{'question': 'What is the default batch size for map_batches?',\n",
7810
       " 'sources': ['https://docs.ray.io/en/master/data/batch_inference.html#configuring-batch-size',\n",
7811
       "  'https://docs.ray.io/en/master/data/transforming-data.html#configuring-batch-size',\n",
7812
       "  'https://docs.ray.io/en/master/data/examples/pytorch_resnet_batch_prediction.html#model-inference',\n",
7813
       "  'https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.map_batches.html#ray-data-dataset-map-batches',\n",
7814
       "  'https://docs.ray.io/en/master/data/batch_inference.html#configuring-batch-size',\n",
7815
       "  'https://docs.ray.io/en/master/data/examples/huggingface_vit_batch_prediction.html#step-3-scaling-up-to-the-full-dataset-with-ray-data',\n",
7816
       "  'https://docs.ray.io/en/master/serve/advanced-guides/dyn-req-batch.html#tips-for-fine-tuning-batching-parameters',\n",
7817
       "  'https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.map_batches.html#ray-data-dataset-map-batches',\n",
7818
       "  'https://docs.ray.io/en/master/data/examples/batch_inference_object_detection.html#model-inference'],\n",
7819
       " 'answer': '  The default batch size for map_batches is 4096.',\n",
7820
       " 'llm': 'codellama/CodeLlama-34b-Instruct-hf'}"
7821
      ]
7822
     },
7823
     "execution_count": null,
7824
     "metadata": {},
7825
     "output_type": "execute_result"
7826
    },
7827
    {
7828
     "name": "stderr",
7829
     "output_type": "stream",
7830
     "text": [
7831
      "\u001b[2m\u001b[36m(ServeReplica:default:RayAssistantDeployment pid=217138)\u001b[0m INFO 2023-11-09 22:10:35,329 RayAssistantDeployment default#RayAssistantDeployment#ZuxuxM 380d5b6f-c3f8-43fa-a3c9-c22d5b04b47d /query default replica.py:749 - __CALL__ OK 3430.1ms\n"
7832
     ]
7833
    }
7834
   ],
7835
   "source": [
7836
    "# Inference\n",
7837
    "data = {\"query\": \"What is the default batch size for map_batches?\"}\n",
7838
    "response = requests.post(\"http://127.0.0.1:8000/query\", json=data)\n",
7839
    "response.json()"
7840
   ]
7841
  },
7842
  {
7843
   "cell_type": "markdown",
7844
   "id": "1f0bc8bd-5d8a-4936-8460-8715464ed230",
7845
   "metadata": {},
7846
   "source": [
7847
    "**Note**: As we can see, Ray Serve makes [model composition](https://docs.ray.io/en/latest/serve/model_composition.html) extremely easy and we could continue to make this even more fine-grained with more workflow logic.\n",
7848
    "\n",
7849
    "Once our application is served, we’re free to use it anywhere we want. For example, we use it as a bot on our Slack channels and as a widget on our docs page (public release coming soon). We can use this to collect feedback from our users to continually improve the application (fine-tuning, UI/UX, etc.).\n",
7850
    "\n",
7851
    "<img width=\"600\" src=\"https://images.ctfassets.net/xjan103pcp94/7pyW8T7La5T51C8iXEwmAO/706dc8ed0ca75cdcbf971d9e74cd67b3/Screenshot_2023-10-24_at_12.56.39_PM.png\">"
7852
   ]
7853
  },
7854
  {
7855
   "cell_type": "code",
7856
   "execution_count": null,
7857
   "id": "74566075-7ad8-4269-8be5-5dc2602ef06b",
7858
   "metadata": {
7859
    "tags": []
7860
   },
7861
   "outputs": [
7862
    {
7863
     "name": "stderr",
7864
     "output_type": "stream",
7865
     "text": [
7866
      "2023-11-09 22:10:41,238\tINFO router.py:473 -- Got updated replicas for deployment 'RayAssistantDeployment' in application 'default': set().\n",
7867
      "\u001b[2m\u001b[36m(ServeController pid=213991)\u001b[0m INFO 2023-11-09 22:10:41,235 controller 213991 deployment_state.py:1707 - Removing 1 replica from deployment 'RayAssistantDeployment' in application 'default'.\n",
7868
      "\u001b[2m\u001b[36m(ServeController pid=213991)\u001b[0m INFO 2023-11-09 22:10:43,848 controller 213991 deployment_state.py:2027 - Replica default#RayAssistantDeployment#ZuxuxM is stopped.\n"
7869
     ]
7870
    }
7871
   ],
7872
   "source": [
7873
    "# Shutdown\n",
7874
    "serve.shutdown()"
7875
   ]
7876
  },
7877
  {
7878
   "cell_type": "markdown",
7879
   "id": "101dab0a-5a8d-4156-b6d2-2f38ff333add",
7880
   "metadata": {},
7881
   "source": [
7882
    "# Data flywheel"
7883
   ]
7884
  },
7885
  {
7886
   "cell_type": "markdown",
7887
   "id": "3d57ef4c-fa7c-491f-befe-f894ee036cab",
7888
   "metadata": {},
7889
   "source": [
7890
    "Creating an application like this is not a one-time task. It's extremely important that we continue to iterate and keep our application up to date. This includes continually reindexing our data, so that our application is working with the most up-to-date information, as well as rerunning our experiments to see if any of our decisions need to be altered. This process of continuous iteration can be achieved by mapping our workflows to [CI/CD pipelines](https://madewithml.com/courses/mlops/cicd/).\n",
7891
    "\n",
7892
    "A key part of iteration that goes beyond automated reindexing, evaluation, etc. involves fixing our data itself. In fact, we found that this is the **most** impactful lever (way beyond our retrieval and generation optimizations above) we could control. Here is an example workflow we've settled on:\n",
7893
    "1. Users use the RAG application to ask questions about the product.\n",
7894
    "2. Use feedback (👍/👎, visited source pages, top-k cosine scores, etc.) to identify underperforming queries.\n",
7895
    "3. Inspect the retrieved resources, tokenization, etc. to decide if it's a shortcoming of retrieval, generation or the underlying data source.\n",
7896
    "4. If something in the data can be improved, separated into sections/pages, etc. → fix it!\n",
7897
    "5. Reindex and deploy a new version of the application."
7898
   ]
7899
  },
7900
  {
7901
   "cell_type": "markdown",
7902
   "id": "cf3f9a4a-15bb-4567-82bb-e2a0808f1616",
7903
   "metadata": {
7904
    "tags": []
7905
   },
7906
   "source": [
7907
    "# Impact"
7908
   ]
7909
  },
7910
  {
7911
   "cell_type": "markdown",
7912
   "id": "615b2ad0-94a1-477c-92ee-6bdab3dcc99e",
7913
   "metadata": {},
7914
   "source": [
7915
    "## Products and productivity"
7916
   ]
7917
  },
7918
  {
7919
   "cell_type": "markdown",
7920
   "id": "6df6c44c-10e5-482f-a715-efd0b2c068a1",
7921
   "metadata": {},
7922
   "source": [
7923
    "Building an LLM application like this has had a tremendous impact on our products and company. As expected, there were first-order impacts on overall developer and user adoption of our products. The ability for users to interact with the application and solve the problems they run into, in a self-serve and immediate manner, is the type of feature that would improve the experience of any product. It makes it significantly easier for people to succeed, and it elevated the perception of LLM applications from a **nice-to-have** to a **must-have**."
7924
   ]
7925
  },
7926
  {
7927
   "cell_type": "markdown",
7928
   "id": "fce8c9b5-cce0-467c-81b9-1aefcb7437d3",
7929
   "metadata": {},
7930
   "source": [
7931
    "## Foundational agents"
7932
   ]
7933
  },
7934
  {
7935
   "cell_type": "markdown",
7936
   "id": "e895c161-f035-49dc-84d5-1f70d2b65ef9",
7937
   "metadata": {},
7938
   "source": [
7939
    "However, there were also some second-order impacts that we didn’t immediately realize. For example, when we further inspected user queries that yielded poor scores, the issue was often caused by a gap in our documentation. When we made the fix (ex. added the appropriate section to our docs), this improved both our product and the LLM application itself, creating a very valuable feedback flywheel. Furthermore, when internal teams learned of the capabilities of our LLM application, this led to the development of highly valuable LLM applications that depend on this Ray docs LLM application as one of the **foundational agents** they use to perform their tasks."
7940
   ]
7941
  },
7942
  {
7943
   "cell_type": "markdown",
7944
   "id": "5c79bcbe-89eb-469d-a809-acb780cec143",
7945
   "metadata": {},
7946
   "source": [
7947
    "<img width=\"700\" src=\"https://images.ctfassets.net/xjan103pcp94/2UF2tSV3kmXtrzmqMsYrLF/76bcc71b481986eb6cb3b06d60582ec5/image18.png\">"
7948
   ]
7949
  },
7950
  {
7951
   "cell_type": "markdown",
7952
   "id": "fc4b13ef-11a8-472d-85a5-48b480628e59",
7953
   "metadata": {},
7954
   "source": [
7955
    "For example, we’ve internally developed a feature called Anyscale Doctor that helps developers diagnose and debug issues during development. Issues in code can have a variety of causes, but when the issue is Ray-related, the LLM application we built here is called to help resolve it."
7956
   ]
7957
  },
7958
  {
7959
   "cell_type": "markdown",
7960
   "id": "0c61ff8d-d76b-47d5-8ae8-85d7aaf41b19",
7961
   "metadata": {},
7962
   "source": [
7963
    "# Learn more"
7964
   ]
7965
  },
7966
  {
7967
   "cell_type": "markdown",
7968
   "id": "6f451be7-3468-4dab-aed9-2e6621ad0a10",
7969
   "metadata": {},
7970
   "source": [
7971
    "- If your team is investing heavily in developing LLM applications, [reach out](mailto:endpoints-help@anyscale.com) to us to learn more about how [Ray](https://github.com/ray-project/ray) and [Anyscale](http://anyscale.com/) can help you scale and productionize everything.\n",
7972
    "- Start serving (and fine-tuning) OSS LLMs with [Anyscale Endpoints](https://www.anyscale.com/endpoints) ($1/M tokens for Llama-2-70b), with a free trial of 1M tokens.\n",
7973
    "- If you need to deploy on your own private cloud, check out [Anyscale Private Endpoints](https://www.anyscale.com/endpoints#private).\n",
7974
    "- Learn more about how companies like OpenAI, Netflix, Pinterest, Verizon, Instacart and others leverage Ray and Anyscale for their AI workloads at the [Ray Summit](https://raysummit.anyscale.com/).\n"
7975
   ]
7976
  }
7977
 ],
7978
 "metadata": {
7979
  "kernelspec": {
7980
   "display_name": "Python 3 (ipykernel)",
7981
   "language": "python",
7982
   "name": "python3"
7983
  },
7984
  "language_info": {
7985
   "codemirror_mode": {
7986
    "name": "ipython",
7987
    "version": 3
7988
   },
7989
   "file_extension": ".py",
7990
   "mimetype": "text/x-python",
7991
   "name": "python",
7992
   "nbconvert_exporter": "python",
7993
   "pygments_lexer": "ipython3",
7994
   "version": "3.10.8"
7995
  }
7996
 },
7997
 "nbformat": 4,
7998
 "nbformat_minor": 5
7999
}