slovnet

Форк
0
107 строк · 2.3 Кб
1

2
import re
3
from os import listdir, makedirs
4
from os.path import exists, join, expanduser
5
from random import seed, shuffle
6

7
from tqdm.notebook import tqdm as log_progress
8

9
import torch
10
from torch import optim
11

12
from apex import amp
13
# Presumably the apex `amp` opt_level identifier — confirm where it is passed.
O2 = 'O2'
14

15
from corus import (
16
    load_buriy_news,
17
    load_taiga_fontanka,
18
    load_ods_gazeta,
19
    load_ods_interfax,
20
    load_lenta
21
)
22

23
from slovnet.io import (
24
    load_lines,
25
    dump_lines
26
)
27
from slovnet.s3 import S3
28
from slovnet.board import TensorBoard
29
from slovnet.const import CUDA0
30

31
from slovnet.model.bert import (
32
    RuBERTConfig,
33
    BERTEmbedding,
34
    BERTEncoder,
35
    BERTMLMHead,
36
    BERTMLM
37
)
38
from slovnet.vocab import BERTVocab
39
from slovnet.encoders.bert import BERTMLMTrainEncoder
40
from slovnet.score import (
41
    MLMScoreMeter,
42
    score_mlm_batch,
43
    score_mlm_batches
44
)
45
from slovnet.loss import masked_flatten_cross_entropy
46

47

48
# Local directory layout (all paths are relative).
DATA_DIR = 'data'
MODEL_DIR = 'model'
RUBERT_DIR = 'rubert'
RAW_DIR = join(DATA_DIR, 'raw')

# Train / test line files stored under DATA_DIR.
TRAIN = join(DATA_DIR, 'train.txt')
TEST = join(DATA_DIR, 'test.txt')

# Each S3_* constant is the matching local path prefixed with S3_DIR.
S3_DIR = '01_bert_news'
S3_TRAIN = join(S3_DIR, TRAIN)
S3_TEST = join(S3_DIR, TEST)

# File names shared by the RUBERT_DIR and MODEL_DIR layouts.
VOCAB = 'vocab.txt'
EMB = 'emb.pt'
ENCODER = 'encoder.pt'
MLM = 'mlm.pt'

# Files under RUBERT_DIR.
RUBERT_VOCAB = join(RUBERT_DIR, VOCAB)
RUBERT_EMB = join(RUBERT_DIR, EMB)
RUBERT_ENCODER = join(RUBERT_DIR, ENCODER)
RUBERT_MLM = join(RUBERT_DIR, MLM)

S3_RUBERT_VOCAB = join(S3_DIR, RUBERT_VOCAB)
S3_RUBERT_EMB = join(S3_DIR, RUBERT_EMB)
S3_RUBERT_ENCODER = join(S3_DIR, RUBERT_ENCODER)
S3_RUBERT_MLM = join(S3_DIR, RUBERT_MLM)

# Files under MODEL_DIR (no vocab entry is defined for this layout).
MODEL_EMB = join(MODEL_DIR, EMB)
MODEL_ENCODER = join(MODEL_DIR, ENCODER)
MODEL_MLM = join(MODEL_DIR, MLM)

S3_MODEL_EMB = join(S3_DIR, MODEL_EMB)
S3_MODEL_ENCODER = join(S3_DIR, MODEL_ENCODER)
S3_MODEL_MLM = join(S3_DIR, MODEL_MLM)

# Board naming; presumably consumed by slovnet.board.TensorBoard — usage is
# outside this view, confirm before relying on it.
BOARD_NAME = '01_bert_news'
RUNS_DIR = 'runs'

TRAIN_BOARD = '01_train'
TEST_BOARD = '02_test'
88

89
# CUDA0 is imported from slovnet.const; presumably the device the model and
# batches are placed on — usage is outside this view, confirm.
DEVICE = CUDA0
90

91

92
def every(step, period):
    """Tell whether ``step`` is a positive multiple of ``period``.

    Step 0 (and any non-positive step) never counts, so a periodic action
    guarded by this check does not fire on the very first step.
    """
    if not step > 0:
        return False
    remainder = step % period
    return remainder == 0
94

95

96
def process_batch(model, criterion, batch):
    """Run the model on one batch and attach the resulting loss to it.

    ``model`` is called on ``batch.input``; ``criterion`` is then called with
    the prediction, ``batch.target.value`` and ``batch.target.mask``.

    Returns whatever ``batch.processed(loss, pred)`` returns.
    """
    prediction = model(batch.input)
    target = batch.target
    loss = criterion(prediction, target.value, target.mask)
    return batch.processed(loss, prediction)
100

101

102
def infer_batches(model, criterion, batches):
    """Lazily yield ``process_batch`` results with the model in eval mode.

    This is a generator: nothing happens until the first item is requested.
    At that point the model's current ``training`` flag is remembered and the
    model is switched to eval mode. The flag is restored when iteration ends
    for any reason: normal exhaustion, an early ``break`` / ``close()`` by the
    consumer, or an exception raised while processing a batch.

    Args:
        model: object exposing ``training``, ``eval()`` and ``train(mode)``.
        criterion: loss callable forwarded to ``process_batch``.
        batches: iterable of batches accepted by ``process_batch``.

    Yields:
        The value returned by ``process_batch`` for each batch, in order.
    """
    training = model.training
    model.eval()
    try:
        for batch in batches:
            # Keep no_grad scoped to the forward pass. Holding the context
            # open across ``yield`` would leave gradients disabled in the
            # consumer's own code while this generator is suspended.
            with torch.no_grad():
                processed = process_batch(model, criterion, batch)
            yield processed
    finally:
        # Runs on exhaustion, on generator close, and on errors alike, so the
        # caller's train/eval mode is never left silently flipped.
        model.train(training)
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.