# russian_art_2024
# Source listing header: 53 lines · 1.3 KB
import os

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms as T

from config import ArtClassifierConstants
from initial_model_utils import init_model

config = ArtClassifierConstants()

# All inference in this service runs on the CPU.
device = torch.device("cpu")

# Build the network with pretrained=False; its parameters are expected to come
# entirely from the checkpoint loaded on the next line.
model = init_model(device, num_classes=len(config.id2label), pretrained=False)
# NOTE(review): torch.load unpickles the checkpoint file, so model_weights_path
# must only ever point at a trusted file (consider weights_only=True on torch
# versions that support it).
model.load_state_dict(torch.load(config.model_weights_path, map_location=device))
# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval()


def _pil_to_unit_float_array(img):
    """Return *img* as a float32 NumPy array with pixel values divided by 255."""
    return np.array(img, dtype="float32") / 255


# Per-channel mean / std passed to Normalize (the widely used ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every incoming image:
#   square resize -> float32 array in [0, 1] for 8-bit input
#   -> CHW tensor -> channel normalisation.
# ToTensor only rescales uint8 input, so the already-scaled float32 array is
# transposed to CHW without being divided by 255 a second time.
trans = T.Compose(
    [
        T.Resize(size=(config.image_size, config.image_size)),
        T.Lambda(_pil_to_unit_float_array),
        T.ToTensor(),
        T.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

app = FastAPI()


@app.post("/predict")
def get_art_label(image: UploadFile):
    """Predict the art sub-category of an uploaded image.

    Parameters
    ----------
    image : UploadFile
        Uploaded image file. Anything PIL can decode is accepted; it is
        converted to RGB before preprocessing.

    Returns
    -------
    dict
        ``{"sub_category": label}``, where ``label`` is
        ``config.id2label[i]`` for the index ``i`` of the largest model
        output.

    Raises
    ------
    HTTPException
        Status 400 when the upload cannot be decoded as an image.
    """
    try:
        # convert() forces a full decode, so truncated or corrupt data also
        # fails inside this try rather than later in the pipeline.
        with Image.open(image.file) as uploaded:
            rgb_image = uploaded.convert("RGB")
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        # An unreadable upload is a client error; previously PIL's exception
        # escaped and the client got an opaque 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # trans yields a (C, H, W) tensor; prepend a batch axis -> (1, C, H, W).
    batch = trans(rgb_image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(batch)

    # Class index with the largest output for the single batch item.
    # .item() gives a plain Python int for the label lookup.
    pred_pos = outputs.argmax(dim=1)[0].item()

    return {"sub_category": config.id2label[pred_pos]}


def _serve() -> None:
    """Serve the FastAPI app with uvicorn on host 0.0.0.0 (all IPv4 interfaces), port 8000."""
    uvicorn.run(app, port=8000, host="0.0.0.0")


if __name__ == "__main__":
    _serve()