{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run main.py\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# main.py provides the path constants and the S3 helper used below\n",
    "!mkdir -p {DATA_DIR} {RUBERT_DIR} {MODEL_DIR}\n",
    "s3 = S3()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fetch the train/test line corpora from S3 if not already on disk\n",
    "if not exists(TEST):\n",
    "    s3.download(S3_TEST, TEST)\n",
    "    s3.download(S3_TRAIN, TRAIN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fetch the pretrained RuBERT vocab and weights (embeddings, encoder, MLM head)\n",
    "if not exists(RUBERT_VOCAB):\n",
    "    s3.download(S3_RUBERT_VOCAB, RUBERT_VOCAB)\n",
    "    s3.download(S3_RUBERT_EMB, RUBERT_EMB)\n",
    "    s3.download(S3_RUBERT_ENCODER, RUBERT_ENCODER)\n",
    "    s3.download(S3_RUBERT_MLM, RUBERT_MLM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = BERTVocab.load(RUBERT_VOCAB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = RuBERTConfig()\n",
    "emb = BERTEmbedding.from_config(config)\n",
    "encoder = BERTEncoder.from_config(config)\n",
    "head = BERTMLMHead(config.emb_dim, config.vocab_size)\n",
    "model = BERTMLM(emb, encoder, head)\n",
    "\n",
    "# fix pos emb, train on short seqs\n",
    "emb.position.weight.requires_grad = False\n",
    "\n",
    "model.emb.load(RUBERT_EMB)\n",
    "model.encoder.load(RUBERT_ENCODER)\n",
    "model.head.load(RUBERT_MLM)\n",
    "model = model.to(DEVICE)\n",
    "\n",
    "criterion = masked_flatten_cross_entropy"
   ]
  },
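  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`masked_flatten_cross_entropy` is defined in the project code loaded at the top of the notebook and is not shown here. A minimal sketch of what such a criterion typically does, assuming the batch carries per-position targets plus a boolean mask over the positions to predict (names below are hypothetical):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch, not the project implementation: cross-entropy over\n",
    "# the masked positions only, with predictions and targets flattened first.\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def masked_flatten_cross_entropy_sketch(pred, target, mask):\n",
    "    # pred: [batch, seq, vocab], target: [batch, seq], mask: [batch, seq] bool\n",
    "    return F.cross_entropy(pred[mask], target[mask])"
   ]
  },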
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(1)\n",
    "seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encode = BERTMLMTrainEncoder(\n",
    "    vocab,\n",
    "    seq_len=128,\n",
    "    batch_size=32,\n",
    "    shuffle_size=10000\n",
    ")\n",
    "\n",
    "# test batches are materialized on the device; train batches stay a lazy, single-pass generator\n",
    "lines = load_lines(TEST)\n",
    "batches = encode(lines)\n",
    "test_batches = [_.to(DEVICE) for _ in batches]\n",
    "\n",
    "lines = load_lines(TRAIN)\n",
    "batches = encode(lines)\n",
    "train_batches = (_.to(DEVICE) for _ in batches)"
   ]
  },
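  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`BERTMLMTrainEncoder` is also project code; given its arguments, it presumably tokenizes the lines with the BERT vocab, packs them into `seq_len` by `batch_size` batches and applies MLM masking. The standard BERT recipe is to pick about 15% of the positions, replace 80% of those with `[MASK]`, 10% with a random token and leave 10% unchanged; a rough sketch of that rule in plain PyTorch (illustrative only, not the project's encoder):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch of the usual BERT masking rule, not the project's encoder.\n",
    "import torch\n",
    "\n",
    "def mask_tokens_sketch(input_ids, mask_id, vocab_size, prob=0.15):\n",
    "    target = input_ids.clone()\n",
    "    picked = torch.rand(input_ids.shape) < prob    # positions to predict\n",
    "    roll = torch.rand(input_ids.shape)\n",
    "    input_ids = input_ids.clone()\n",
    "    input_ids[picked & (roll < 0.8)] = mask_id     # 80% -> [MASK]\n",
    "    swap = picked & (roll >= 0.8) & (roll < 0.9)   # 10% -> random token\n",
    "    input_ids[swap] = torch.randint(vocab_size, input_ids.shape)[swap]\n",
    "    return input_ids, target, picked               # remaining 10% left unchanged"
   ]
  },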
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "board = TensorBoard(BOARD_NAME, RUNS_DIR)\n",
    "train_board = board.section(TRAIN_BOARD)\n",
    "test_board = board.section(TEST_BOARD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
    "model, optimizer = amp.initialize(model, optimizer, opt_level='O2')\n",
    "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.999)"
   ]
  },
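  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`amp` here is NVIDIA Apex mixed precision; `O2` casts the model to FP16 while the optimizer keeps FP32 master weights. On PyTorch 1.6+ the built-in `torch.cuda.amp` gives the same effect without Apex; the next cell sketches that alternative for this cell plus the accumulation loop below. It is only a sketch, not a drop-in change to this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the same setup with built-in PyTorch AMP instead of Apex.\n",
    "# Assumption: PyTorch >= 1.6; the notebook itself keeps using apex.amp.\n",
    "scaler = torch.cuda.amp.GradScaler()\n",
    "# inside the training loop:\n",
    "#     with torch.cuda.amp.autocast():\n",
    "#         batch = process_batch(model, criterion, batch)\n",
    "#     scaler.scale(batch.loss / accum_steps).backward()\n",
    "# and on accumulation boundaries:\n",
    "#     scaler.step(optimizer)\n",
    "#     scaler.update()\n",
    "#     optimizer.zero_grad()"
   ]
  },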
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_meter = MLMScoreMeter()\n",
    "test_meter = MLMScoreMeter()\n",
    "\n",
    "accum_steps = 64  # batch_size 32 * 64 accum steps ~ 2K effective batch\n",
    "log_steps = 256\n",
    "eval_steps = 512\n",
    "save_steps = eval_steps * 10\n",
    "\n",
    "model.train()\n",
    "optimizer.zero_grad()\n",
    "\n",
    "for step, batch in log_progress(enumerate(train_batches)):\n",
    "    batch = process_batch(model, criterion, batch)\n",
    "    batch.loss /= accum_steps  # gradient accumulation: scale loss per micro-batch\n",
    "\n",
    "    with amp.scale_loss(batch.loss, optimizer) as scaled:\n",
    "        scaled.backward()\n",
    "\n",
    "    score = score_mlm_batch(batch, ks=())\n",
    "    train_meter.add(score)\n",
    "\n",
    "    if every(step, log_steps):\n",
    "        train_meter.write(train_board)\n",
    "        train_meter.reset()\n",
    "\n",
    "    if every(step, accum_steps):\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        if every(step, eval_steps):\n",
    "            batches = infer_batches(model, criterion, test_batches)\n",
    "            scores = score_mlm_batches(batches)\n",
    "            test_meter.extend(scores)\n",
    "            test_meter.write(test_board)\n",
    "            test_meter.reset()\n",
    "\n",
    "    if every(step, save_steps):\n",
    "        model.emb.dump(MODEL_EMB)\n",
    "        model.encoder.dump(MODEL_ENCODER)\n",
    "        model.head.dump(MODEL_MLM)\n",
    "\n",
    "        s3.upload(MODEL_EMB, S3_MODEL_EMB)\n",
    "        s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)\n",
    "        s3.upload(MODEL_MLM, S3_MODEL_MLM)\n",
    "\n",
    "    board.step()"
   ]
  },
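  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`log_progress`, `process_batch`, `infer_batches`, the meters and `every` all come from the project code. The loop above relies on `every(step, n)` firing once per `n` steps; a plausible definition (an assumption, not the project's source):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plausible definition of the `every` helper used above (assumption).\n",
    "def every_sketch(step, period):\n",
    "    # whether step 0 counts depends on the real helper; here it does not\n",
    "    return step > 0 and step % period == 0"
   ]
  },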
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}