Throughput_Across_Models.ipynb 
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🤗 Huggingface vs ⚡ FastEmbed\n",
    "\n",
    "Comparing the embedding throughput of Huggingface's 🤗 Transformers and ⚡ FastEmbed on a small set of short documents. Benchmark machine: Apple M2 Max, 32 GB RAM.\n",
    "\n",
    "## 📦 Imports\n",
    "\n",
    "Importing the necessary libraries for this comparison."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:33:35.753669Z",
     "start_time": "2024-03-30T00:33:34.371658Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/joein/work/qdrant/fastembed/venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "from typing import Callable, List, Tuple\n",
    "\n",
    "import torch.nn.functional as F\n",
    "from fastembed import TextEmbedding\n",
    "import matplotlib.pyplot as plt\n",
    "from transformers import AutoModel, AutoTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📖 Data\n",
    "\n",
    "`documents` is a list of strings; each string is one document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:33:35.766679Z",
     "start_time": "2024-03-30T00:33:35.755112Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "12"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "documents: List[str] = [\n",
    "    \"Chandrayaan-3 is India's third lunar mission\",\n",
    "    \"It aimed to land a rover on the Moon's surface - joining the US, China and Russia\",\n",
    "    \"The mission is a follow-up to Chandrayaan-2, which had partial success\",\n",
    "    \"Chandrayaan-3 will be launched by the Indian Space Research Organisation (ISRO)\",\n",
    "    \"The estimated cost of the mission is around $35 million\",\n",
    "    \"It will carry instruments to study the lunar surface and atmosphere\",\n",
    "    \"Chandrayaan-3 landed on the Moon's surface on 23rd August 2023\",\n",
    "    \"It consists of a lander named Vikram and a rover named Pragyan similar to Chandrayaan-2. Its propulsion module would act like an orbiter.\",\n",
    "    \"The propulsion module carries the lander and rover configuration until the spacecraft is in a 100-kilometre (62 mi) lunar orbit\",\n",
    "    \"The mission used GSLV Mk III rocket for its launch\",\n",
    "    \"Chandrayaan-3 was launched from the Satish Dhawan Space Centre in Sriharikota\",\n",
    "    \"Chandrayaan-3 was launched earlier in the year 2023\",\n",
    "]\n",
    "len(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:33:35.766791Z",
     "start_time": "2024-03-30T00:33:35.756803Z"
    }
   },
   "outputs": [],
   "source": [
    "model_id = \"BAAI/bge-small-en\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up 🤗 Huggingface\n",
    "\n",
    "We'll use the [Huggingface Transformers](https://huggingface.co/transformers/) library with PyTorch to generate embeddings. Both libraries run the same model, which makes the comparison fairer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:34:03.988Z",
     "start_time": "2024-03-30T00:33:37.460865Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "config.json: 100%|██████████| 684/684 [00:00<00:00, 491kB/s]\n",
      "model.safetensors: 100%|██████████| 133M/133M [00:21<00:00, 6.24MB/s] \n",
      "tokenizer_config.json: 100%|██████████| 366/366 [00:00<00:00, 4.06MB/s]\n",
      "vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.12MB/s]\n",
      "tokenizer.json: 100%|██████████| 711k/711k [00:00<00:00, 1.59MB/s]\n",
      "special_tokens_map.json: 100%|██████████| 125/125 [00:00<00:00, 399kB/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": "torch.Size([12, 384])"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class HF:\n",
    "    \"\"\"\n",
    "    HuggingFace Transformer implementation of FlagEmbedding\n",
    "    Based on https://huggingface.co/BAAI/bge-base-en\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_id: str):\n",
    "        self.model = AutoModel.from_pretrained(model_id)\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "    def embed(self, texts: List[str]):\n",
    "        # Tokenize the batch, padding to the longest text and truncating at 512 tokens\n",
    "        encoded_input = self.tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors=\"pt\")\n",
    "        model_output = self.model(**encoded_input)\n",
    "        # Use the last hidden state of the first ([CLS]) token as the sentence embedding\n",
    "        sentence_embeddings = model_output[0][:, 0]\n",
    "        # L2-normalize each embedding (F.normalize defaults to p=2, dim=1)\n",
    "        sentence_embeddings = F.normalize(sentence_embeddings)\n",
    "        return sentence_embeddings\n",
    "\n",
    "\n",
    "hf = HF(model_id=model_id)\n",
    "hf.embed(documents).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up ⚡️FastEmbed\n",
    "\n",
    "There is not much to set up here. We pass the same `model_id` to `TextEmbedding`, so FastEmbed runs the same Flag Embedding model as the Huggingface setup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:34:04.076422Z",
     "start_time": "2024-03-30T00:34:03.987162Z"
    }
   },
   "outputs": [],
   "source": [
    "embedding_model = TextEmbedding(model_name=model_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📊 Comparison\n",
    "\n",
    "We'll compare the mean, maximum, and minimum wall-clock time of an embedding call across k runs. Let's write a function to do that:\n",
    "\n",
    "### 🚀 Calculating Stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:34:06.543782Z",
     "start_time": "2024-03-30T00:34:06.357816Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Huggingface Transformers (Average, Max, Min): (0.05358994007110596, 0.0568850040435791, 0.05029487609863281)\n",
      "FastEmbed (Average, Max, Min): (0.035953521728515625, 0.03631591796875, 0.03559112548828125)\n"
     ]
    }
   ],
   "source": [
    "import types\n",
    "\n",
    "\n",
    "def calculate_time_stats(embed_func: Callable, documents: list, k: int) -> Tuple[float, float, float]:\n",
    "    times = []\n",
    "    for _ in range(k):\n",
    "        # Timing the embed_func call\n",
    "        start_time = time.time()\n",
    "        embeddings = embed_func(documents)\n",
    "        # Force computation if embed_func returns a generator\n",
    "        if isinstance(embeddings, types.GeneratorType):\n",
    "            list(embeddings)\n",
    "\n",
    "        end_time = time.time()\n",
    "        times.append(end_time - start_time)\n",
    "\n",
    "    # Returning mean, max, and min time for the call\n",
    "    return (sum(times) / k, max(times), min(times))\n",
    "\n",
    "\n",
    "hf_stats = calculate_time_stats(hf.embed, documents, k=2)\n",
    "print(f\"Huggingface Transformers (Average, Max, Min): {hf_stats}\")\n",
    "fst_stats = calculate_time_stats(embedding_model.embed, documents, k=2)\n",
    "print(f\"FastEmbed (Average, Max, Min): {fst_stats}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📈 Results\n",
    "\n",
    "Let's plot the throughput of both libraries in characters per second, computed from each library's mean time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:34:11.032206Z",
     "start_time": "2024-03-30T00:34:10.828410Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 640x480 with 1 Axes>",
      "image/png": ""
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def plot_character_per_second_comparison(\n",
    "    hf_stats: Tuple[float, float, float], fst_stats: Tuple[float, float, float], documents: list\n",
    "):\n",
    "    # Calculating total characters in documents\n",
    "    total_characters = sum(len(doc) for doc in documents)\n",
    "\n",
    "    # Calculating characters per second for each model\n",
    "    hf_chars_per_sec = total_characters / hf_stats[0]  # Mean time is at index 0\n",
    "    fst_chars_per_sec = total_characters / fst_stats[0]\n",
    "\n",
    "    # Plotting the bar chart\n",
    "    models = [\"HF Embed (Torch)\", \"FastEmbed\"]\n",
    "    chars_per_sec = [hf_chars_per_sec, fst_chars_per_sec]\n",
    "\n",
    "    bars = plt.bar(models, chars_per_sec, color=[\"#1f356c\", \"#dd1f4b\"])\n",
    "    plt.ylabel(\"Characters per Second\")\n",
    "    plt.title(\"Characters Processed per Second Comparison\")\n",
    "\n",
    "    # Adding the number at the top of each bar\n",
    "    for bar, chars in zip(bars, chars_per_sec):\n",
    "        plt.text(\n",
    "            bar.get_x() + bar.get_width() / 2,\n",
    "            bar.get_height(),\n",
    "            f\"{chars:.1f}\",\n",
    "            ha=\"center\",\n",
    "            va=\"bottom\",\n",
    "            color=\"#1f356c\",\n",
    "            fontsize=12,\n",
    "        )\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_character_per_second_comparison(hf_stats, fst_stats, documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fst",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}