slovnet
93 строки · 2.2 Кб
1
2from os import getenv3
import logging

# Configure the root logger once for the whole process: INFO and above,
# every record rendered as "<timestamp> <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)-15s %(message)s',
)

# Short alias used by the rest of the script; log(...) is logging.info(...),
# so it accepts the same "format string + lazy args" call form.
log = logging.info
11from aiohttp import web12
import torch

# Turn autograd tracking off for the process; nothing in this script calls
# backward(), the model is only run forward through `infer`.
torch.set_grad_enabled(False)
16from slovnet.const import CUDA017from slovnet.vocab import (18BERTVocab,19Vocab
20)
21from slovnet.model.bert import (22RuBERTConfig,23BERTEmbedding,24BERTEncoder,25BERTMorphHead,26BERTMorph
27)
28from slovnet.encoders.bert import BERTInferEncoder29from slovnet.infer.bert import BERTMorphInfer, BERTTagDecoder30
31
# Paths of the vocabularies and weight files. Each one is read from the
# environment variable of the same name, falling back to the default shown.
WORDS_VOCAB = getenv('WORDS_VOCAB', 'vocab.txt')
TAGS_VOCAB = getenv('TAGS_VOCAB', 'tags_vocab.txt')
EMB = getenv('EMB', 'emb.pt')
ENCODER = getenv('ENCODER', 'encoder.pt')
MORPH = getenv('MORPH', 'morph.pt')

# Inference settings: target device for the model, plus the seq_len and
# batch_size handed to BERTInferEncoder further down.
DEVICE = getenv('DEVICE', CUDA0)
SEQ_LEN = int(getenv('SEQ_LEN', 256))
BATCH_SIZE = int(getenv('BATCH_SIZE', 64))

# HTTP server settings. MAX_SIZE is passed to aiohttp as client_max_size;
# MB is also used to report that limit in whole megabytes at startup.
HOST = getenv('HOST', '0.0.0.0')
PORT = int(getenv('PORT', 8080))
MB = 1024 * 1024
MAX_SIZE = int(getenv('MAX_SIZE', 100 * MB))
47
# --- Vocabularies -----------------------------------------------------------
log('Load vocabs: %r, %r', WORDS_VOCAB, TAGS_VOCAB)
words_vocab = BERTVocab.load(WORDS_VOCAB)
tags_vocab = Vocab.load(TAGS_VOCAB)

# --- Model skeleton ---------------------------------------------------------
# Build the three parts from the RuBERT config and assemble them; the tagging
# head is sized by the number of entries in the tags vocabulary.
config = RuBERTConfig()
emb = BERTEmbedding.from_config(config)
encoder = BERTEncoder.from_config(config)
morph = BERTMorphHead(config.emb_dim, len(tags_vocab))
model = BERTMorph(emb, encoder, morph)
model.eval()

# --- Weights ----------------------------------------------------------------
# Each part loads its own file, then the whole model is moved to DEVICE.
log('Load emb: %r', EMB)
model.emb.load(EMB)
log('Load encoder: %r', ENCODER)
model.encoder.load(ENCODER)
log('Load morph: %r', MORPH)
model.head.load(MORPH)
log('Device: %r', DEVICE)
model = model.to(DEVICE)

# --- Inference pipeline -----------------------------------------------------
log('Seq len: %r', SEQ_LEN)
log('Batch size: %r', BATCH_SIZE)
# NOTE: this rebinds the module-level name `encoder`. Up to here it referred
# to the BERTEncoder network layer (already handed to `model` above); from
# here on it is the input encoder that prepares batches for inference.
encoder = BERTInferEncoder(
    words_vocab,
    seq_len=SEQ_LEN, batch_size=BATCH_SIZE
)
decoder = BERTTagDecoder(tags_vocab)
infer = BERTMorphInfer(model, encoder, decoder)
77
async def handle(request):
    """Handle POST /: run morphology inference on a JSON chunk.

    The request body is decoded as JSON and passed as-is to ``infer``; the
    resulting markups are returned as a JSON array built from each markup's
    ``as_json``.

    Raises:
        web.HTTPBadRequest: the body is not decodable JSON. Previously this
            escaped as an unhandled exception and produced a 500 response.
    """
    try:
        chunk = await request.json()
    except ValueError as error:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses, so this covers malformed JSON and a wrongly encoded body.
        raise web.HTTPBadRequest(
            text='Invalid JSON body: %s' % error
        ) from error

    log('Post chunk size: %r', len(chunk))

    # infer() is an ordinary synchronous call made inside the coroutine, and
    # list() drains it completely before the handler continues.
    markups = list(infer(chunk))

    tokens = sum(len(markup.tokens) for markup in markups)
    log('Infer tokens: %r', tokens)

    data = [markup.as_json for markup in markups]
    return web.json_response(data)
89
# Report the request-body limit in whole megabytes, then build the application
# with that limit and a single route: POST / -> handle.
log('Max size: %r', MAX_SIZE // MB)
app = web.Application(client_max_size=MAX_SIZE)
app.add_routes([web.post('/', handle)])

# Start serving on HOST:PORT. This runs at module level, so importing this
# file starts the server.
web.run_app(app, host=HOST, port=PORT)