russian_art_2024

Форк
0
53 строки · 1.3 Кб
1
import os
2

3
import numpy as np
4
import torch
5
import uvicorn
6
from config import ArtClassifierConstants
7
from fastapi import FastAPI, UploadFile
8
from PIL import Image
9
from torchvision import transforms as T
10

11
from initial_model_utils import init_model
12

13
config = ArtClassifierConstants()
14

15
device = torch.device("cpu")
16

17
model = init_model(device, num_classes=len(config.id2label), pretrained=False)
18
model.load_state_dict(torch.load(config.model_weights_path, map_location=device))
19
model.eval()
20

21
trans = T.Compose(
22
    [
23
        T.Resize(size=(config.image_size, config.image_size)),
24
        T.Lambda(lambda x: np.array(x, dtype="float32") / 255),
25
        T.ToTensor(),
26
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
27
    ]
28
)
29

30
app = FastAPI()
31

32

33
@app.post("/predict")
34
def get_art_label(image: UploadFile):
35
    """Prediction endpoint
36

37
    Parameters
38
    ----------
39
    image : UploadFile
40
        Image to predict
41
    """
42
    with torch.no_grad():
43
        image = Image.open(image.file).convert("RGB")
44
        transformed_image = torch.unsqueeze(trans(image), 0)
45
        outputs = model(transformed_image)
46
        _, preds = torch.max(outputs, 1)
47
        pred_pos = preds.cpu().numpy()[0]
48

49
    return {"sub_category": config.id2label[pred_pos]}
50

51

52
if __name__ == "__main__":
53
    uvicorn.run(app, host="0.0.0.0", port=8000)
54

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.