slovnet

Форк
0
93 строки · 2.2 Кб
1

2
from os import getenv
3

4
import logging
5
logging.basicConfig(
6
    level=logging.INFO,
7
    format='%(asctime)-15s %(message)s'
8
)
9
log = logging.info
10

11
from aiohttp import web
12

13
import torch
14
torch.set_grad_enabled(False)
15

16
from slovnet.const import CUDA0
17
from slovnet.vocab import (
18
    BERTVocab,
19
    Vocab
20
)
21
from slovnet.model.bert import (
22
    RuBERTConfig,
23
    BERTEmbedding,
24
    BERTEncoder,
25
    BERTMorphHead,
26
    BERTMorph
27
)
28
from slovnet.encoders.bert import BERTInferEncoder
29
from slovnet.infer.bert import BERTMorphInfer, BERTTagDecoder
30

31

32
WORDS_VOCAB = getenv('WORDS_VOCAB', 'vocab.txt')
33
TAGS_VOCAB = getenv('TAGS_VOCAB', 'tags_vocab.txt')
34
EMB = getenv('EMB', 'emb.pt')
35
ENCODER = getenv('ENCODER', 'encoder.pt')
36
MORPH = getenv('MORPH', 'morph.pt')
37

38
DEVICE = getenv('DEVICE', CUDA0)
39
SEQ_LEN = int(getenv('SEQ_LEN', 256))
40
BATCH_SIZE = int(getenv('BATCH_SIZE', 64))
41

42
HOST = getenv('HOST', '0.0.0.0')
43
PORT = int(getenv('PORT', 8080))
44
MB = 1024 * 1024
45
MAX_SIZE = int(getenv('MAX_SIZE', 100 * MB))
46

47

48
log('Load vocabs: %r, %r' % (WORDS_VOCAB, TAGS_VOCAB))
49
words_vocab = BERTVocab.load(WORDS_VOCAB)
50
tags_vocab = Vocab.load(TAGS_VOCAB)
51

52
config = RuBERTConfig()
53
emb = BERTEmbedding.from_config(config)
54
encoder = BERTEncoder.from_config(config)
55
morph = BERTMorphHead(config.emb_dim, len(tags_vocab))
56
model = BERTMorph(emb, encoder, morph)
57
model.eval()
58

59
log('Load emb: %r' % EMB)
60
model.emb.load(EMB)
61
log('Load encoder: %r' % ENCODER)
62
model.encoder.load(ENCODER)
63
log('Load morph: %r' % MORPH)
64
model.head.load(MORPH)
65
log('Device: %r' % DEVICE)
66
model = model.to(DEVICE)
67

68
log('Seq len: %r' % SEQ_LEN)
69
log('Batch size: %r' % BATCH_SIZE)
70
encoder = BERTInferEncoder(
71
    words_vocab,
72
    seq_len=SEQ_LEN, batch_size=BATCH_SIZE
73
)
74
decoder = BERTTagDecoder(tags_vocab)
75
infer = BERTMorphInfer(model, encoder, decoder)
76

77

78
async def handle(request):
79
    chunk = await request.json()
80
    log('Post chunk size: %r' % len(chunk))
81
    markups = list(infer(chunk))
82

83
    tokens = sum(len(_.tokens) for _ in markups)
84
    log('Infer tokens: %r', tokens)
85

86
    data = [_.as_json for _ in markups]
87
    return web.json_response(data)
88

89

90
log('Max size: %r' % (MAX_SIZE // MB))
91
app = web.Application(client_max_size=MAX_SIZE)
92
app.add_routes([web.post('/', handle)])
93

94
web.run_app(app, host=HOST, port=PORT)
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.