examples

Форк
0
163 строки · 4.7 Кб
1
{
2
 "cells": [
3
  {
4
   "attachments": {},
5
   "cell_type": "markdown",
6
   "metadata": {},
7
   "source": [
8
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/03-ce-scoring.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/03-ce-scoring.ipynb)\n",
9
    "\n",
10
    "# Cross-Encoder Scoring\n",
11
    "\n",
12
    "The final step in preparing our training data for GPL is a cross encoder scoring step. Given all of the query, passage pairs we've generated, both:\n",
13
    "\n",
14
    "$$Positives = (Q, P^+)$$\n",
15
    "\n",
16
    "<center>and</center>\n",
17
    "\n",
18
    "$$Negatives = (Q, P^-)$$\n",
19
    "\n",
20
    "We pass these into a cross encoder model that is trained to predict similarity scores of *(Q, P)* pairs.\n",
21
    "\n",
22
    "First, we will load our cross encoder model."
23
   ]
24
  },
25
  {
26
   "cell_type": "code",
27
   "execution_count": 1,
28
   "metadata": {},
29
   "outputs": [
30
    {
31
     "data": {
32
      "text/plain": [
33
       "<sentence_transformers.cross_encoder.CrossEncoder.CrossEncoder at 0x7ff671abbe50>"
34
      ]
35
     },
36
     "execution_count": 1,
37
     "metadata": {},
38
     "output_type": "execute_result"
39
    }
40
   ],
41
   "source": [
42
    "from sentence_transformers import CrossEncoder\n",
43
    "\n",
44
    "model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')\n",
45
    "model"
46
   ]
47
  },
48
  {
49
   "cell_type": "markdown",
50
   "metadata": {},
51
   "source": [
52
    "Let's define a generator function to read pairs from file."
53
   ]
54
  },
55
  {
56
   "cell_type": "code",
57
   "execution_count": 2,
58
   "metadata": {},
59
   "outputs": [],
60
   "source": [
61
    "from tqdm.auto import tqdm\n",
62
    "\n",
63
    "def get_lines():\n",
64
    "    # loop through each file\n",
65
    "    with open('data/triplets.tsv', 'r', encoding='utf-8') as fp:\n",
66
    "        lines = fp.read().split('\\n')\n",
67
    "    # loop through each line in the current file\n",
68
    "    for line in tqdm(lines):\n",
69
    "        q, p, n = line.split('\\t')\n",
70
    "        # return the query, positive, negative\n",
71
    "        yield q, p, n"
72
   ]
73
  },
74
  {
75
   "cell_type": "markdown",
76
   "metadata": {},
77
   "source": [
78
    "We use the cross encoder to calculate the similarity of the positive pair, and negative pair, and then take the *score* as the margin between the two similarity scores. We are taking the margin as we will be training our bi-encoder model using Margin MSE loss (which requires the margin/difference).\n",
79
    "\n",
80
    "$$ Margin = sim(Q, P^+) - sim(Q, P^-) $$"
81
   ]
82
  },
83
  {
84
   "cell_type": "code",
85
   "execution_count": 3,
86
   "metadata": {},
87
   "outputs": [
88
    {
89
     "data": {
90
      "application/vnd.jupyter.widget-view+json": {
91
       "model_id": "61f52178f8f2431abd642f65a719acfd",
92
       "version_major": 2,
93
       "version_minor": 0
94
      },
95
      "text/plain": [
96
       "  0%|          | 0/200000 [00:00<?, ?it/s]"
97
      ]
98
     },
99
     "metadata": {},
100
     "output_type": "display_data"
101
    }
102
   ],
103
   "source": [
104
    "lines = get_lines()\n",
105
    "label_lines = []\n",
106
    "\n",
107
    "for line in lines:\n",
108
    "    q, p, n = line\n",
109
    "    p_score = model.predict((q, p))\n",
110
    "    n_score = model.predict((q, n))\n",
111
    "    margin = p_score - n_score\n",
112
    "    # append pairs to label_lines with margin score\n",
113
    "    label_lines.append(\n",
114
    "        q + '\\t' + p + '\\t' + n + '\\t' + str(margin)\n",
115
    "    )\n",
116
    "\n",
117
    "with open(\"data/triplets_margin.tsv\", 'w', encoding='utf-8') as fp:\n",
118
    "    fp.write('\\n'.join(label_lines))"
119
   ]
120
  },
121
  {
122
   "cell_type": "markdown",
123
   "metadata": {},
124
   "source": [
125
    "Now we have our *(Q, P<sup>+</sup>, P<sup>-</sup>)* pairs, we can move on to fine-tuning with MarginMSE loss.\n",
126
    "\n",
127
    "---"
128
   ]
129
  }
130
 ],
131
 "metadata": {
132
  "environment": {
133
   "kernel": "python3",
134
   "name": "common-cu110.m91",
135
   "type": "gcloud",
136
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
137
  },
138
  "kernelspec": {
139
   "display_name": "Python 3",
140
   "language": "python",
141
   "name": "python3"
142
  },
143
  "language_info": {
144
   "codemirror_mode": {
145
    "name": "ipython",
146
    "version": 3
147
   },
148
   "file_extension": ".py",
149
   "mimetype": "text/x-python",
150
   "name": "python",
151
   "nbconvert_exporter": "python",
152
   "pygments_lexer": "ipython3",
153
   "version": "3.10.7 (main, Sep 14 2022, 22:38:23) [Clang 14.0.0 (clang-1400.0.29.102)]"
154
  },
155
  "vscode": {
156
   "interpreter": {
157
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
158
   }
159
  }
160
 },
161
 "nbformat": 4,
162
 "nbformat_minor": 4
163
}
164

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.