import logging
import os
import time

import sentry_sdk
import torch
from flask import Flask, request, jsonify
from healthcheck import HealthCheck
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import BertTokenizer, BertForMaskedLM
# Report unhandled exceptions (including Flask request errors) to Sentry.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# HF hub name or local path of the BERT checkpoint; None if the env var is unset.
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")

# Choose the device once at startup; the /respond handler reads `cuda` and `device`.
cuda = torch.cuda.is_available()
if cuda:
    torch.cuda.set_device(0)  # single gpu
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"masked_lm is set to run on {device}")

try:
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = BertForMaskedLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model.eval()  # inference only
    # Must match the device the handler moves inputs to, or the forward pass fails.
    model.to(device)
    # Id of the [MASK] token; the handler selects the positions equal to this id.
    MASK_ID = tokenizer.mask_token_id
    logger.info("masked_lm model is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)
health = HealthCheck(app, "/healthcheck")
# Silence werkzeug's per-request INFO logging.
logging.getLogger("werkzeug").setLevel("WARNING")
@app.route("/respond", methods=["POST"])
def respond():
    """Return top-10 fill-in candidates for every [MASK] token in each input text.

    Expects JSON ``{"text": [str, ...]}``; responds with
    ``{"predicted_tokens": [[{token: prob, ...} per mask] per input]}``.
    On any failure, logs/reports the error and returns an empty list per input.
    """
    st_time = time.time()
    text = request.json.get("text", [])
    try:
        inputs = tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.cuda() for k, v in inputs.items()} if cuda else inputs
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            logits = model(**inputs).logits.cpu()
        # Per-position distribution over the vocabulary: (batch, seq_len, vocab).
        probs = torch.nn.functional.softmax(logits, dim=2)

        batch_predicted_tokens = []
        for batch_i in range(probs.shape[0]):
            # Boolean mask of [MASK] positions, moved to CPU because `probs` is on CPU.
            mask_positions = (inputs["input_ids"][batch_i] == MASK_ID).cpu()
            masked_tokens = probs[batch_i][mask_positions]
            predicted_tokens = []  # reset per input — otherwise predictions leak across inputs
            for mask_i in range(masked_tokens.shape[0]):
                token_probs, token_ids = masked_tokens[mask_i].topk(10)
                token_probs = token_probs.tolist()
                tokens = [tokenizer.decode([token_id]) for token_id in token_ids.tolist()]
                predicted_tokens.append(dict(zip(tokens, token_probs)))
            batch_predicted_tokens.append(predicted_tokens)
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        # Best-effort fallback: one independent empty list per input instead of a 500.
        batch_predicted_tokens = [[] for _ in text]
    total_time = time.time() - st_time
    logger.info(f"masked_lm exec time: {total_time:.3f}s")
    return jsonify({"predicted_tokens": batch_predicted_tokens})