Amazing-Python-Scripts

Форк
0
74 строки · 2.7 Кб
1
### 1. Imports and class names setup ###
2
import gradio as gr
3
import os
4
import torch
5

6
from model import create_effnetb2_model
7
from timeit import default_timer as timer
8
from typing import Tuple, Dict
9

10
# Setup class names
11
with open("class_names.txt", "r") as f:
12
    class_names = [food_name.strip() for food_name in f.readlines()]
13

14
### 2. Model and transforms preparation ###
15
# Create model and transforms
16
effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=101)
17

18
# Load saved weights
19
effnetb2.load_state_dict(
20
    torch.load(f="09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth",
21
               map_location=torch.device("cpu"))  # load to CPU
22
)
23

24
### 3. Predict function ###
25

26

27
def predict(img) -> Tuple[Dict, float]:
28
    # Start a timer
29
    start_time = timer()
30

31
    # Transform the input image for use with EffNetB2
32
    # unsqueeze = add batch dimension on 0th index
33
    img = effnetb2_transforms(img).unsqueeze(0)
34

35
    # Put model into eval mode, make prediction
36
    effnetb2.eval()
37
    with torch.inference_mode():
38
        # Pass transformed image through the model and turn the prediction logits into probaiblities
39
        pred_probs = torch.softmax(effnetb2(img), dim=1)
40

41
    # Create a prediction label and prediction probability dictionary
42
    pred_labels_and_probs = {class_names[i]: float(
43
        pred_probs[0][i]) for i in range(len(class_names))}
44

45
    # Calculate pred time
46
    end_time = timer()
47
    pred_time = round(end_time - start_time, 4)
48

49
    # Return pred dict and pred time
50
    return pred_labels_and_probs, pred_time
51

52
### 4. Gradio app ###
53

54

55
# Create title, description and article
56
title = "FoodVision BIG 🍔👁💪"
57
description = "An [EfficientNetB2 feature extractor](https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_b2.html#torchvision.models.efficientnet_b2) computer vision model to classify images [101 classes of food from the Food101 dataset](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/extras/food101_class_names.txt)."
58
article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/#11-turning-our-foodvision-big-model-into-a-deployable-app)."
59

60
# Create example list
61
example_list = [["examples/" + example] for example in os.listdir("examples")]
62

63
# Create the Gradio demo
64
demo = gr.Interface(fn=predict,  # maps inputs to outputs
65
                    inputs=gr.Image(type="pil"),
66
                    outputs=[gr.Label(num_top_classes=5, label="Predictions"),
67
                             gr.Number(label="Prediction time (s)")],
68
                    examples=example_list,
69
                    title=title,
70
                    description=description,
71
                    article=article)
72

73
# Launch the demo!
74
demo.launch()
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.