# slovnet -- source listing (originally shown as 135 lines, 3.2 KB)

from os import getenv, environ
from os.path import exists, join
from itertools import chain, islice as head
from random import seed, sample, randint, uniform
from subprocess import run

from tqdm.notebook import tqdm as log_progress

import torch
from torch import optim

from naeval.syntax.datasets import load_dataset

from slovnet.s3 import S3
from slovnet.io import (
    format_jl,
    parse_jl,

    load_gz_lines,
    dump_gz_lines
)
from slovnet.board import (
    TensorBoard,
    LogBoard,
    MultiBoard
)
from slovnet.const import (
    TRAIN, TEST,
    PAD, CUDA0,
)

from slovnet.model.bert import (
    RuBERTConfig,
    BERTEmbedding,
    BERTEncoder,
    BERTSyntaxHead,
    BERTSyntaxRel,
    BERTSyntax
)
from slovnet.markup import SyntaxMarkup
from slovnet.vocab import BERTVocab, Vocab
from slovnet.encoders.bert import BERTSyntaxTrainEncoder
from slovnet.loss import masked_flatten_cross_entropy
from slovnet.score import (
    SyntaxScoreMeter,
    score_syntax_batch
)
from slovnet.mask import (
    Masked,
    split_masked,
    pad_masked
)
54
55
# Local working directories.
DATA_DIR = 'data'
MODEL_DIR = 'model'
BERT_DIR = 'bert'
RAW_DIR = join(DATA_DIR, 'raw')

# Local dataset archives, plus the GramEval2020 .conllu files keyed by the
# archive they are associated with (paths are relative; presumably resolved
# against GRAMRU_DIR by code not shown here -- TODO confirm).
NEWS = join(DATA_DIR, 'news.jl.gz')
FICTION = join(DATA_DIR, 'fiction.jl.gz')
GRAMRU_DIR = join(RAW_DIR, 'GramEval2020-master')
GRAMRU_FILES = {
    NEWS: [
        'dataOpenTest/GramEval2020-RuEval2017-Lenta-news-dev.conllu',
        'dataTrain/MorphoRuEval2017-Lenta-train.conllu',
    ],
    FICTION: [
        'dataOpenTest/GramEval2020-SynTagRus-dev.conllu',
        'dataTrain/GramEval2020-SynTagRus-train-v2.conllu',
        'dataTrain/MorphoRuEval2017-JZ-gold.conllu',
    ],
}

# S3-side names: the local relative paths re-rooted under S3_DIR.
S3_DIR = '04_bert_syntax'
S3_NEWS = join(S3_DIR, NEWS)
S3_FICTION = join(S3_DIR, FICTION)

# Bare file names shared by the BERT and model directories.
VOCAB = 'vocab.txt'
EMB = 'emb.pt'
ENCODER = 'encoder.pt'
HEAD = 'head.pt'
REL = 'rel.pt'

# Local paths of the BERT files.
BERT_VOCAB = join(BERT_DIR, VOCAB)
BERT_EMB = join(BERT_DIR, EMB)
BERT_ENCODER = join(BERT_DIR, ENCODER)

# S3-side names of the BERT files; the vocab is taken from the rubert
# directory, embedding and encoder from the model directory.
S3_RUBERT_DIR = '01_bert_news/rubert'
S3_MLM_DIR = '01_bert_news/model'
S3_BERT_VOCAB = join(S3_RUBERT_DIR, VOCAB)
S3_BERT_EMB = join(S3_MLM_DIR, EMB)
S3_BERT_ENCODER = join(S3_MLM_DIR, ENCODER)

# Local paths of the syntax model files.
RELS_VOCAB = join(MODEL_DIR, 'rels_vocab.txt')
MODEL_ENCODER = join(MODEL_DIR, ENCODER)
MODEL_HEAD = join(MODEL_DIR, HEAD)
MODEL_REL = join(MODEL_DIR, REL)

# S3-side names of the syntax model files (local paths under S3_DIR).
S3_RELS_VOCAB = join(S3_DIR, RELS_VOCAB)
S3_MODEL_ENCODER = join(S3_DIR, MODEL_ENCODER)
S3_MODEL_HEAD = join(S3_DIR, MODEL_HEAD)
S3_MODEL_REL = join(S3_DIR, MODEL_REL)

# Board naming; BOARD_NAME can be overridden via the `board_name`
# environment variable.
BOARD_NAME = getenv('board_name', '04_bert_syntax_01')
RUNS_DIR = 'runs'

TRAIN_BOARD = '01_train'
TEST_BOARD = '02_test'
# Run configuration. Each value is read from the environment variable named
# in the getenv call and falls back to the literal default when it is unset.
# getenv yields a string when the variable is set, so the numeric settings
# are normalised through int()/float() in both cases.
SEED = int(getenv('seed', 50))
DEVICE = getenv('device', CUDA0)
BERT_LR = float(getenv('bert_lr', 0.000058))
LR = float(getenv('lr', 0.00012))
LR_GAMMA = float(getenv('lr_gamma', 0.29))
EPOCHS = int(getenv('epochs', 2))
119
def process_batch(model, criterion, batch):
    """Run one batch through the model and return it with loss and predictions.

    Args:
        model: Callable invoked as
            ``model(word_id, word_mask, pad_mask, target_mask, target_head_id)``.
            Its result must expose ``head_id`` and ``rel_id`` attributes.
            The model must also provide ``head.decode`` and ``rel.decode``.
        criterion: Callable invoked as ``criterion(pred, target, mask)``;
            the two results (heads and relations) are added together.
        batch: Object that unpacks into an ``(input, target)`` pair and
            offers ``processed(loss, pred)``. The input side provides
            ``word_id``, ``word_mask`` and ``pad_mask``; the target side
            provides ``mask``, ``head_id`` and ``rel_id``.

    Returns:
        Whatever ``batch.processed(loss, pred)`` returns.
    """
    # Named `features` rather than `input` so the builtin is not shadowed.
    features, target = batch

    pred = model(
        features.word_id, features.word_mask, features.pad_mask,
        target.mask, target.head_id
    )

    # The loss is computed on the raw model outputs, before they are
    # replaced by their decoded form below.
    loss = (
        criterion(pred.head_id, target.head_id, target.mask)
        + criterion(pred.rel_id, target.rel_id, target.mask)
    )

    # Overwrite the prediction fields in place with the decoded values.
    pred.head_id = model.head.decode(pred.head_id, target.mask)
    pred.rel_id = model.rel.decode(pred.rel_id, target.mask)

    return batch.processed(loss, pred)