russian_art_2024

Форк
0
/
make_submission.py 
52 строки · 1.7 Кб
1
"""Script for generation competition's submission.
2
All constants remains with competition rules (cpu usage)
3
"""
4

5
import numpy as np
6
import torch
7
from torchvision import transforms
8

9
from src.initial_model_utils import init_model
10
from src.train_utils import ArtDataset
11

12
MODEL_WEIGHTS = "./data/weights/resnet50_tl_68.pt"
13
TEST_DATASET = "./data/test/"
14
SUBMISSION_PATH = "./data/submission.csv"
15

16
if __name__ == "__main__":
17
    device = torch.device("cpu")
18
    model = init_model(device, num_classes=35)
19
    model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location=device))
20
    model.eval()
21

22
    img_size = 224
23
    trans = transforms.Compose(
24
        [
25
            transforms.Resize(size=(img_size, img_size)),
26
            transforms.Lambda(lambda x: np.array(x, dtype="float32") / 255),
27
            transforms.ToTensor(),
28
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
29
        ]
30
    )
31

32
    ard_data = ArtDataset(TEST_DATASET, transform=trans)
33
    batch_size = 16
34
    num_workers = 4
35
    testloader = torch.utils.data.DataLoader(
36
        ard_data, batch_size=batch_size, shuffle=False, num_workers=num_workers
37
    )
38

39
    all_image_names = [item.split("/")[-1] for item in ard_data.files]
40
    all_preds = []
41

42
    with torch.no_grad():
43
        for idx, (images, _) in enumerate(testloader, 0):
44
            images = images.to(device)
45
            outputs = model(images)
46
            _, preds = torch.max(outputs, 1)
47
            all_preds.extend(preds.cpu().numpy().tolist())
48

49
    with open(SUBMISSION_PATH, "w") as f:
50
        f.write("image_name\tlabel_id\n")
51
        for name, cl_id in zip(all_image_names, all_preds):
52
            f.write(f"{name}\t{cl_id}\n")
53

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.