dream

Форк
0
77 строк · 2.6 Кб
1
import logging
2
import time
3
import os
4

5
from transformers import BertTokenizer, BertForMaskedLM
6
import torch
7
from flask import Flask, request, jsonify
8
from healthcheck import HealthCheck
9
import sentry_sdk
10
from sentry_sdk.integrations.flask import FlaskIntegration
11

12
# Error reporting: forward unhandled exceptions from Flask requests to Sentry.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model id or local path of the BERT masked-LM checkpoint to serve.
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
# Use the module logger (not the root logger via logging.info) for consistency.
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
# Token id of [MASK] in the standard BERT vocabulary.
# NOTE(review): hard-coded; ideally read from tokenizer.mask_token_id after the
# tokenizer is loaded, in case a checkpoint with a non-standard vocab is used.
MASK_ID = 103
21
try:
    # Pick the compute device: single-GPU setup, CPU fallback.
    cuda = torch.cuda.is_available()
    if cuda:
        torch.cuda.set_device(0)  # single gpu
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    logger.info(f"masked_lm is set to run on {device}")

    # Load tokenizer and model; eval() disables dropout for deterministic inference.
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = BertForMaskedLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model.eval()
    if cuda:
        model.cuda()

    logger.info("masked_lm model is ready")
except Exception as e:
    # Report and re-raise: the service must not start without a working model.
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise  # bare raise preserves the original traceback
43

44
app = Flask(__name__)
45
health = HealthCheck(app, "/healthcheck")
46
logging.getLogger("werkzeug").setLevel("WARNING")
47

48

49
@app.route("/respond", methods=["POST"])
def respond():
    """Predict top-10 candidate tokens for every [MASK] position in each input.

    Expects JSON of the form {"text": ["sentence with [MASK] ...", ...]}.
    Returns JSON {"predicted_tokens": [...]} with one inner list per input
    sentence, containing one {token: probability} dict per masked position.
    On any inference failure, falls back to an empty prediction list per input.
    """
    st_time = time.time()

    text = request.json.get("text", [])
    try:
        inputs = tokenizer(text, return_tensors="pt", padding=True)
        if cuda:
            inputs = {k: v.cuda() for k, v in inputs.items()}
        # Inference only: no_grad avoids building the autograd graph,
        # saving memory and time.
        with torch.no_grad():
            logits = model(**inputs).logits.cpu()
        probs = torch.nn.functional.softmax(logits, dim=2)

        batch_predicted_tokens = []
        for batch_i in range(probs.shape[0]):
            # Boolean mask selecting rows of [MASK] positions. .cpu() keeps the
            # index on the same device as `probs` (no-op when running on CPU).
            mask = inputs["input_ids"][batch_i].cpu() == MASK_ID
            masked_tokens = probs[batch_i][mask]
            predicted_tokens = []
            for token_id in range(masked_tokens.shape[0]):
                token_probs, token_ids = masked_tokens[token_id].topk(10)
                token_probs = token_probs.tolist()
                # `tid` (not `id`) to avoid shadowing the builtin.
                tokens = [tokenizer.decode([tid]) for tid in token_ids.tolist()]
                predicted_tokens.append(dict(zip(tokens, token_probs)))
            batch_predicted_tokens.append(predicted_tokens)
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        # One independent empty list per input (not [[]] * n, which aliases
        # a single list object).
        batch_predicted_tokens = [[] for _ in text]

    total_time = time.time() - st_time
    logger.info(f"masked_lm exec time: {total_time:.3f}s")
    return jsonify({"predicted_tokens": batch_predicted_tokens})
78

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.