# slovnet
# (file-viewer header of the original listing: 107 lines · 2.3 KB)

import re
from os import listdir, makedirs
from os.path import exists, join, expanduser
from random import seed, shuffle

from tqdm.notebook import tqdm as log_progress

import torch
from torch import optim

from apex import amp
O2 = 'O2'

from corus import (
    load_buriy_news,
    load_taiga_fontanka,
    load_ods_gazeta,
    load_ods_interfax,
    load_lenta
)

from slovnet.io import (
    load_lines,
    dump_lines
)
from slovnet.s3 import S3
from slovnet.board import TensorBoard
from slovnet.const import CUDA0

from slovnet.model.bert import (
    RuBERTConfig,
    BERTEmbedding,
    BERTEncoder,
    BERTMLMHead,
    BERTMLM
)
from slovnet.vocab import BERTVocab
from slovnet.encoders.bert import BERTMLMTrainEncoder
from slovnet.score import (
    MLMScoreMeter,
    score_mlm_batch,
    score_mlm_batches
)
from slovnet.loss import masked_flatten_cross_entropy


# Local directory names; RAW_DIR is nested under DATA_DIR.
DATA_DIR = 'data'
MODEL_DIR = 'model'
RUBERT_DIR = 'rubert'
RAW_DIR = join(DATA_DIR, 'raw')

# Local train/test text files inside DATA_DIR.
TRAIN = join(DATA_DIR, 'train.txt')
TEST = join(DATA_DIR, 'test.txt')

# The same relative paths prefixed with S3_DIR
# (presumably the remote S3 locations of those files; confirm against S3 usage).
S3_DIR = '01_bert_news'
S3_TRAIN = join(S3_DIR, TRAIN)
S3_TEST = join(S3_DIR, TEST)

# File names shared by the RUBERT_DIR and MODEL_DIR layouts below.
VOCAB = 'vocab.txt'
EMB = 'emb.pt'
ENCODER = 'encoder.pt'
MLM = 'mlm.pt'

# Files under RUBERT_DIR.
RUBERT_VOCAB = join(RUBERT_DIR, VOCAB)
RUBERT_EMB = join(RUBERT_DIR, EMB)
RUBERT_ENCODER = join(RUBERT_DIR, ENCODER)
RUBERT_MLM = join(RUBERT_DIR, MLM)

# RUBERT_DIR files prefixed with S3_DIR.
S3_RUBERT_VOCAB = join(S3_DIR, RUBERT_VOCAB)
S3_RUBERT_EMB = join(S3_DIR, RUBERT_EMB)
S3_RUBERT_ENCODER = join(S3_DIR, RUBERT_ENCODER)
S3_RUBERT_MLM = join(S3_DIR, RUBERT_MLM)

# Files under MODEL_DIR.
MODEL_EMB = join(MODEL_DIR, EMB)
MODEL_ENCODER = join(MODEL_DIR, ENCODER)
MODEL_MLM = join(MODEL_DIR, MLM)

# MODEL_DIR files prefixed with S3_DIR.
S3_MODEL_EMB = join(S3_DIR, MODEL_EMB)
S3_MODEL_ENCODER = join(S3_DIR, MODEL_ENCODER)
S3_MODEL_MLM = join(S3_DIR, MODEL_MLM)

# Names used for board logging (presumably TensorBoard runs; confirm at call site).
BOARD_NAME = '01_bert_news'
RUNS_DIR = 'runs'

TRAIN_BOARD = '01_train'
TEST_BOARD = '02_test'

# Device constant imported from slovnet.const.
DEVICE = CUDA0


def every(step, period):
    """Return True when ``step`` is positive and a multiple of ``period``.

    Step 0 never matches, so a periodic action does not fire on the very
    first step.
    """
    return step > 0 and step % period == 0


def process_batch(model, criterion, batch):
    """Run ``model`` on one batch and attach the loss and prediction to it.

    The model is called on ``batch.input``; ``criterion`` receives the
    prediction together with ``batch.target.value`` and ``batch.target.mask``.

    Returns whatever ``batch.processed(loss, pred)`` returns.
    """
    pred = model(batch.input)
    loss = criterion(pred, batch.target.value, batch.target.mask)
    return batch.processed(loss, pred)


def infer_batches(model, criterion, batches):
    """Lazily evaluate ``batches`` with the model in eval mode, without gradients.

    Yields the result of ``process_batch(model, criterion, batch)`` for each
    batch. The model's previous ``training`` flag is restored when iteration
    ends.

    Because this is a generator, nothing happens until it is iterated.
    """
    training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in batches:
                yield process_batch(model, criterion, batch)
    finally:
        # Restore the caller's mode even if a batch raises or the consumer
        # stops early (break / close()); otherwise the model would silently
        # stay in eval mode for subsequent training steps.
        model.train(training)