FastEmbed_vs_HF_Comparison.ipynb 
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🤗 Hugging Face vs ⚡ FastEmbed\n",
    "\n",
    "Comparing the embedding speed of Hugging Face 🤗 Transformers and ⚡ FastEmbed on a small set of documents. All timings were measured on an Apple M2 Max with 32 GB RAM.\n",
    "\n",
    "## 📦 Imports\n",
    "\n",
    "Importing the necessary libraries for this comparison."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: matplotlib in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (3.8.3)\n",
      "Requirement already satisfied: transformers in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (4.39.2)\n",
      "Requirement already satisfied: torch in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (2.2.0)\n",
      "Requirement already satisfied: contourpy>=1.0.1 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (1.2.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (0.12.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (4.50.0)\n",
      "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (1.4.5)\n",
      "Requirement already satisfied: numpy<2,>=1.21 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (1.26.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (24.0)\n",
      "Requirement already satisfied: pillow>=8 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (10.2.0)\n",
      "Requirement already satisfied: pyparsing>=2.3.1 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (3.1.2)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from matplotlib) (2.9.0.post0)\n",
      "Requirement already satisfied: filelock in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from transformers) (3.13.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from transformers) (0.20.3)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from transformers) (6.0.1)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from transformers) (2023.12.25)\n",
      "Requirement already satisfied: requests in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from transformers) (2.31.0)\n",
      "Requirement already satisfied: tokenizers<0.19,>=0.14 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from transformers) (0.15.2)\n",
      "Requirement already satisfied: safetensors>=0.4.1 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from transformers) (0.4.2)\n",
      "Requirement already satisfied: tqdm>=4.27 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from transformers) (4.66.2)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from torch) (4.10.0)\n",
      "Requirement already satisfied: sympy in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from torch) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from torch) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from torch) (3.1.3)\n",
      "Requirement already satisfied: fsspec in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from torch) (2024.2.0)\n",
      "Requirement already satisfied: six>=1.5 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from jinja2->torch) (2.1.5)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from requests->transformers) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from requests->transformers) (3.6)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from requests->transformers) (2.2.1)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from requests->transformers) (2024.2.2)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install matplotlib transformers torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:38:48.671752Z",
     "start_time": "2024-03-30T00:38:48.669409Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "from typing import Callable, List, Tuple\n",
    "\n",
    "import torch.nn.functional as F\n",
    "from fastembed import TextEmbedding\n",
    "import matplotlib.pyplot as plt\n",
    "from torch import Tensor\n",
    "from transformers import AutoModel, AutoTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📖 Data\n",
    "\n",
    "`documents` is a list of strings; each string is one document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:43:34.512097Z",
     "start_time": "2024-03-30T00:43:34.509352Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "documents: List[str] = [\n",
    "    \"Chandrayaan-3 is India's third lunar mission\",\n",
    "    \"It aimed to land a rover on the Moon's surface - joining the US, China and Russia\",\n",
    "    \"The mission is a follow-up to Chandrayaan-2, which had partial success\",\n",
    "    \"Chandrayaan-3 will be launched by the Indian Space Research Organisation (ISRO)\",\n",
    "    \"The estimated cost of the mission is around $35 million\",\n",
    "    \"It will carry instruments to study the lunar surface and atmosphere\",\n",
    "    \"Chandrayaan-3 landed on the Moon's surface on 23rd August 2023\",\n",
    "    \"It consists of a lander named Vikram and a rover named Pragyan similar to Chandrayaan-2. Its propulsion module would act like an orbiter.\",\n",
    "    \"The propulsion module carries the lander and rover configuration until the spacecraft is in a 100-kilometre (62 mi) lunar orbit\",\n",
    "    \"The mission used GSLV Mk III rocket for its launch\",\n",
    "    \"Chandrayaan-3 was launched from the Satish Dhawan Space Centre in Sriharikota\",\n",
    "    \"Chandrayaan-3 was launched earlier in the year 2023\",\n",
    "]\n",
    "len(documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up 🤗 Hugging Face\n",
    "\n",
    "We'll use the [Hugging Face Transformers](https://huggingface.co/transformers/) library with PyTorch to generate embeddings. Both libraries load the same model for a fair(er?) comparison."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:43:35.417504Z",
     "start_time": "2024-03-30T00:43:34.800606Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 384])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class HF:\n",
    "    \"\"\"\n",
    "    HuggingFace Transformer implementation of FlagEmbedding\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_id: str):\n",
    "        self.model = AutoModel.from_pretrained(model_id)\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "    def embed(self, texts: List[str]):\n",
    "        encoded_input = self.tokenizer(\n",
    "            texts, max_length=512, padding=True, truncation=True, return_tensors=\"pt\"\n",
    "        )\n",
    "        model_output = self.model(**encoded_input)\n",
    "        # CLS pooling: use the hidden state of the first token as the sentence embedding\n",
    "        sentence_embeddings = model_output[0][:, 0]\n",
    "        # L2-normalize each embedding (F.normalize defaults to p=2, dim=1)\n",
    "        sentence_embeddings = F.normalize(sentence_embeddings)\n",
    "        return sentence_embeddings\n",
    "\n",
    "\n",
    "model_id = \"BAAI/bge-small-en\"\n",
    "hf = HF(model_id=model_id)\n",
    "hf.embed(documents).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up ⚡ FastEmbed\n",
    "\n",
    "There isn't much to set up here. We load the same Flag Embedding model (`BAAI/bge-small-en`) that the Hugging Face wrapper uses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:43:35.486719Z",
     "start_time": "2024-03-30T00:43:35.416166Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "100%|██████████| 77.7M/77.7M [00:15<00:00, 4.99MiB/s]\n"
     ]
    }
   ],
   "source": [
    "embedding_model = TextEmbedding(model_name=model_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📊 Comparison\n",
    "\n",
    "We'll compare the mean, maximum, and minimum wall-clock time of an embedding call across k runs. Let's write a function to do that:\n",
    "\n",
    "### 🚀 Calculating Stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:43:35.693539Z",
     "start_time": "2024-03-30T00:43:35.488973Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Huggingface Transformers (Average, Max, Min): (0.07326400279998779, 0.08054399490356445, 0.06598401069641113)\n",
      "FastEmbed (Average, Max, Min): (0.09258604049682617, 0.13388705253601074, 0.0512850284576416)\n"
     ]
    }
   ],
   "source": [
    "import types\n",
    "\n",
    "\n",
    "def calculate_time_stats(\n",
    "    embed_func: Callable, documents: list, k: int\n",
    ") -> Tuple[float, float, float]:\n",
    "    \"\"\"Return (mean, max, min) wall-clock seconds over k calls of embed_func.\"\"\"\n",
    "    times = []\n",
    "    for _ in range(k):\n",
    "        # Timing the embed_func call\n",
    "        start_time = time.time()\n",
    "        embeddings = embed_func(documents)\n",
    "        # Force computation if embed_func returns a generator\n",
    "        if isinstance(embeddings, types.GeneratorType):\n",
    "            list(embeddings)\n",
    "\n",
    "        end_time = time.time()\n",
    "        times.append(end_time - start_time)\n",
    "\n",
    "    # Returning mean, max, and min time for the call\n",
    "    return (sum(times) / k, max(times), min(times))\n",
    "\n",
    "\n",
    "hf_stats = calculate_time_stats(hf.embed, documents, k=2)\n",
    "print(f\"Huggingface Transformers (Average, Max, Min): {hf_stats}\")\n",
    "fst_stats = calculate_time_stats(embedding_model.embed, documents, k=2)\n",
    "print(f\"FastEmbed (Average, Max, Min): {fst_stats}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📈 Results\n",
    "\n",
    "Let's run the comparison and see the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:43:35.746781Z",
     "start_time": "2024-03-30T00:43:35.698423Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def plot_character_per_second_comparison(\n",
    "    hf_stats: Tuple[float, float, float], fst_stats: Tuple[float, float, float], documents: list\n",
    "):\n",
    "    # Calculating total characters in documents\n",
    "    total_characters = sum(len(doc) for doc in documents)\n",
    "\n",
    "    # Calculating characters per second for each model\n",
    "    hf_chars_per_sec = total_characters / hf_stats[0]  # Mean time is at index 0\n",
    "    fst_chars_per_sec = total_characters / fst_stats[0]\n",
    "\n",
    "    # Plotting the bar chart\n",
    "    models = [\"HF Embed (Torch)\", \"FastEmbed\"]\n",
    "    chars_per_sec = [hf_chars_per_sec, fst_chars_per_sec]\n",
    "\n",
    "    bars = plt.bar(models, chars_per_sec, color=[\"#1f356c\", \"#dd1f4b\"])\n",
    "    plt.ylabel(\"Characters per Second\")\n",
    "    plt.title(\"Characters Processed per Second Comparison\")\n",
    "\n",
    "    # Adding the number at the top of each bar\n",
    "    for bar, chars in zip(bars, chars_per_sec):\n",
    "        plt.text(\n",
    "            bar.get_x() + bar.get_width() / 2,\n",
    "            bar.get_height(),\n",
    "            f\"{chars:.1f}\",\n",
    "            ha=\"center\",\n",
    "            va=\"bottom\",\n",
    "            color=\"#1f356c\",\n",
    "            fontsize=12,\n",
    "        )\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_character_per_second_comparison(hf_stats, fst_stats, documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Are the Embeddings the same?\n",
    "\n",
    "This is a very important question. Let's see if the embeddings are the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:43:25.537072Z",
     "start_time": "2024-03-30T00:43:25.419184Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9999997019767761"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def calculate_cosine_similarity(embeddings1: Tensor, embeddings2: Tensor) -> float:\n",
    "    \"\"\"\n",
    "    Calculate the mean row-wise cosine similarity between two sets of embeddings\n",
    "    \"\"\"\n",
    "    return F.cosine_similarity(embeddings1, embeddings2).mean().item()\n",
    "\n",
    "\n",
    "# Stack the FastEmbed arrays into one ndarray first; building a tensor from a list of arrays is slow\n",
    "fst_embeddings = Tensor(np.array(list(embedding_model.embed(documents))))\n",
    "calculate_cosine_similarity(hf.embed(documents), fst_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This indicates the embeddings are very close to each other: the mean cosine similarity is above 0.99 for BAAI/bge-small-en (0.9999997 in the run above) and 0.92 for BAAI/bge-small-en-v1.5. This gives us confidence that the two libraries produce nearly identical embeddings for BAAI/bge-small-en, so we are not sacrificing accuracy for speed."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fst",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
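
The timing loop in this notebook uses only k=2 runs and counts any first-call overhead, so its numbers are likely to be noisy. Here is a minimal sketch of a steadier harness, assuming a few untimed warm-up calls are acceptable and that `time.perf_counter` (a monotonic, high-resolution clock) is preferred over `time.time`. The names `benchmark`, `warmup`, and `fake_embed` are hypothetical and introduced only for illustration; they are not part of the notebook or of FastEmbed.

```python
import statistics
import time
import types
from typing import Callable, Dict, List


def benchmark(
    embed_func: Callable, documents: List[str], k: int = 10, warmup: int = 2
) -> Dict[str, float]:
    """Time embed_func over k runs, after `warmup` untimed runs."""

    def run_once() -> None:
        result = embed_func(documents)
        # Generators are lazy, so consume them to force the actual work.
        if isinstance(result, types.GeneratorType):
            list(result)

    # Warm-up calls absorb one-time costs (lazy initialization, caches).
    for _ in range(warmup):
        run_once()

    times = []
    for _ in range(k):
        start = time.perf_counter()
        run_once()
        times.append(time.perf_counter() - start)

    return {
        "mean": statistics.mean(times),
        "median": statistics.median(times),
        "min": min(times),
        "max": max(times),
    }


# Stand-in embedder (an assumption) so the sketch runs without a model download.
def fake_embed(docs):
    for doc in docs:
        yield [float(len(doc))]


stats = benchmark(fake_embed, ["a", "bb", "ccc"], k=5, warmup=1)
print(stats)
```

In the notebook, `hf.embed` or `embedding_model.embed` would take the place of `fake_embed`. The median is less sensitive than the mean to a single slow run, which matters when k is small.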