slovnet
211 lines · 5.2 KB
1{
2"cells": [
3{
4"cell_type": "code",
5"execution_count": null,
6"metadata": {},
7"outputs": [],
8"source": [
9"%run main.py\n",
10"%load_ext autoreload\n",
11"%autoreload 2\n",
12"\n",
13"!mkdir -p {DATA_DIR} {RUBERT_DIR} {MODEL_DIR}\n",
14"s3 = S3()"
15]
16},
17{
18"cell_type": "code",
19"execution_count": null,
20"metadata": {},
21"outputs": [],
22"source": [
23"if not exists(TEST):\n",
24" s3.download(S3_TEST, TEST)\n",
25" s3.download(S3_TRAIN, TRAIN)"
26]
27},
28{
29"cell_type": "code",
30"execution_count": null,
31"metadata": {},
32"outputs": [],
33"source": [
34"if not exists(RUBERT_VOCAB):\n",
35" s3.download(S3_RUBERT_VOCAB, RUBERT_VOCAB)\n",
36" s3.download(S3_RUBERT_EMB, RUBERT_EMB)\n",
37" s3.download(S3_RUBERT_ENCODER, RUBERT_ENCODER)\n",
38" s3.download(S3_RUBERT_MLM, RUBERT_MLM)"
39]
40},
41{
42"cell_type": "code",
43"execution_count": null,
44"metadata": {},
45"outputs": [],
46"source": [
47"vocab = BERTVocab.load(RUBERT_VOCAB)"
48]
49},
50{
51"cell_type": "code",
52"execution_count": null,
53"metadata": {},
54"outputs": [],
55"source": [
56"config = RuBERTConfig()\n",
57"emb = BERTEmbedding.from_config(config)\n",
58"encoder = BERTEncoder.from_config(config)\n",
59"head = BERTMLMHead(config.emb_dim, config.vocab_size)\n",
60"model = BERTMLM(emb, encoder, head)\n",
61"\n",
62"# fix pos emb, train on short seqs\n",
63"emb.position.weight.requires_grad = False\n",
64"\n",
65"model.emb.load(RUBERT_EMB)\n",
66"model.encoder.load(RUBERT_ENCODER)\n",
67"model.head.load(RUBERT_MLM)\n",
68"model = model.to(DEVICE)\n",
69"\n",
70"criterion = masked_flatten_cross_entropy"
71]
72},
73{
74"cell_type": "code",
75"execution_count": null,
76"metadata": {},
77"outputs": [],
78"source": [
79"torch.manual_seed(1)\n",
80"seed(1)"
81]
82},
83{
84"cell_type": "code",
85"execution_count": null,
86"metadata": {},
87"outputs": [],
88"source": [
89"encode = BERTMLMTrainEncoder(\n",
90" vocab,\n",
91" seq_len=128,\n",
92" batch_size=32,\n",
93" shuffle_size=10000\n",
94")\n",
95"\n",
96"lines = load_lines(TEST)\n",
97"batches = encode(lines)\n",
98"test_batches = [_.to(DEVICE) for _ in batches]\n",
99"\n",
100"lines = load_lines(TRAIN)\n",
101"batches = encode(lines)\n",
102"train_batches = (_.to(DEVICE) for _ in batches)"
103]
104},
105{
106"cell_type": "code",
107"execution_count": null,
108"metadata": {},
109"outputs": [],
110"source": [
111"board = TensorBoard(BOARD_NAME, RUNS_DIR)\n",
112"train_board = board.section(TRAIN_BOARD)\n",
113"test_board = board.section(TEST_BOARD)"
114]
115},
116{
117"cell_type": "code",
118"execution_count": null,
119"metadata": {},
120"outputs": [],
121"source": [
122"optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
123"model, optimizer = amp.initialize(model, optimizer, opt_level='O2')\n",
124"scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.999)"
125]
126},
127{
128"cell_type": "code",
129"execution_count": null,
130"metadata": {},
131"outputs": [],
132"source": [
133"train_meter = MLMScoreMeter()\n",
134"test_meter = MLMScoreMeter()\n",
135"\n",
136"accum_steps = 64 # 2K batch\n",
137"log_steps = 256\n",
138"eval_steps = 512\n",
139"save_steps = eval_steps * 10\n",
140"\n",
141"model.train()\n",
142"optimizer.zero_grad()\n",
143"\n",
144"for step, batch in log_progress(enumerate(train_batches)):\n",
145" batch = process_batch(model, criterion, batch)\n",
146" batch.loss /= accum_steps\n",
147" \n",
148" with amp.scale_loss(batch.loss, optimizer) as scaled:\n",
149" scaled.backward()\n",
150"\n",
151" score = score_mlm_batch(batch, ks=())\n",
152" train_meter.add(score)\n",
153"\n",
154" if every(step, log_steps):\n",
155" train_meter.write(train_board)\n",
156" train_meter.reset()\n",
157"\n",
158" if every(step, accum_steps):\n",
159" optimizer.step()\n",
160" scheduler.step()\n",
161" optimizer.zero_grad()\n",
162"\n",
163" if every(step, eval_steps):\n",
164" batches = infer_batches(model, criterion, test_batches)\n",
165" scores = score_mlm_batches(batches)\n",
166" test_meter.extend(scores)\n",
167" test_meter.write(test_board)\n",
168" test_meter.reset()\n",
169" \n",
170" if every(step, save_steps):\n",
171" model.emb.dump(MODEL_EMB)\n",
172" model.encoder.dump(MODEL_ENCODER)\n",
173" model.head.dump(MODEL_MLM)\n",
174" \n",
175" s3.upload(MODEL_EMB, S3_MODEL_EMB)\n",
176" s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)\n",
177" s3.upload(MODEL_MLM, S3_MODEL_MLM)\n",
178" \n",
179" board.step()"
180]
181},
182{
183"cell_type": "code",
184"execution_count": null,
185"metadata": {},
186"outputs": [],
187"source": []
188}
189],
190"metadata": {
191"kernelspec": {
192"display_name": "Python 3",
193"language": "python",
194"name": "python3"
195},
196"language_info": {
197"codemirror_mode": {
198"name": "ipython",
199"version": 3
200},
201"file_extension": ".py",
202"mimetype": "text/x-python",
203"name": "python",
204"nbconvert_exporter": "python",
205"pygments_lexer": "ipython3",
206"version": "3.6.9"
207}
208},
209"nbformat": 4,
210"nbformat_minor": 2
211}
212