russian_art_2024

training.py
"""Script for transfer learning ResNet50 with IMAGENET1K_V2 weights on Russian art dataset"""
2

3
"""
4
1. Import libraries, data, constants
5
1.1 Import libs
6
"""
7

8
import os
9
import shutil
10

11
import numpy as np
12
import pandas as pd
13
import torch
14
from sklearn.model_selection import train_test_split
15
from sklearn.utils.class_weight import compute_class_weight
16
from torch import nn
17
from torch.utils.data import WeightedRandomSampler
18
from torchvision import transforms
19

20
from src.initial_model_utils import init_model, set_seed
21
from src.train_utils import ArtDataset, train_model
22

"""
1.2 Define constants
"""

# fixed seed to make results reproducible
RANDOM_STATE = 139

# ResNet input image size
IMG_SIZE = 224

# number of art classes in the dataset
NUM_CLASSES = 35

# loader params
BATCH_SIZE = 32
NUM_WORKERS = 4

# set to True on the first run to split the data
SPLIT_DATA_FOLDERS = False

# paths to the original data
TRAIN_DATASET = "./data/train/"
ORIGIN_TRAIN_CSV = "./data/private_info/train.csv"

# paths to the new split data and its metadata
# (the train path stays the same)
TEST_DATASET = "./data/test/"
TRAIN_CSV = "./data/private_info/new_train.csv"
TEST_CSV = "./data/private_info/new_test.csv"

# path to save the model's weights
MODEL_WEIGHTS = "./data/weights/resnet50_tl_68.pt"

"""
1.3 Configure library settings
"""
# fix all random states
set_seed(RANDOM_STATE)
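
# NOTE: set_seed comes from src.initial_model_utils (not shown here). A minimal
# sketch of what such a helper usually does, assuming the standard PyTorch
# reproducibility recipe:
#
#   import random
#   random.seed(seed)
#   np.random.seed(seed)
#   torch.manual_seed(seed)
#   torch.cuda.manual_seed_all(seed)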

"""
1.4 Split the initial data into train and test parts
"""

# split the default data into train and test files
if SPLIT_DATA_FOLDERS:
    # original dataset
    train_data_info = pd.read_csv(ORIGIN_TRAIN_CSV, sep="\t")

    # extract ids and classes
    indices = train_data_info.index
    labels = train_data_info.label_id

    # stratified split due to class imbalance
    ind_train, ind_test, _, _ = train_test_split(
        indices, labels, test_size=0.2, random_state=RANDOM_STATE, stratify=labels
    )

    # new dataset info
    new_train, new_test = (
        train_data_info.loc[ind_train].reset_index(drop=True),
        train_data_info.loc[ind_test].reset_index(drop=True),
    )

    # move test photos from the original train folder to the new test folder
    source_dir = TRAIN_DATASET
    target_dir = TEST_DATASET
    file_names = new_test["image_name"].values

    # make sure the target folder exists before moving files into it
    os.makedirs(target_dir, exist_ok=True)
    for file_name in file_names:
        shutil.move(os.path.join(source_dir, file_name), target_dir)

    new_train.to_csv(TRAIN_CSV, index=False, sep="\t")
    new_test.to_csv(TEST_CSV, index=False, sep="\t")

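# NOTE: the split is destructive (test images are physically moved out of
# ./data/train/), so SPLIT_DATA_FOLDERS should stay False on subsequent runs;
# re-running the block would look for files that are already gone.
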
"""
99
1.5 Import images data
100
"""
101

102
# define base transforms (from ResNet50 docs)
103
base_transforms = transforms.Compose(
104
    [
105
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
106
        transforms.Lambda(lambda x: np.array(x, dtype="float32") / 255),
107
        transforms.ToTensor(),
108
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
109
    ]
110
)
111
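
# The Lambda converts the PIL image to float32 in [0, 1], so ToTensor only
# transposes HxWxC to CxHxW without rescaling (it rescales uint8 input only);
# the Normalize constants are the standard ImageNet channel statistics.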

# add some random augmentations to prevent overfitting
augmentations = transforms.RandomChoice(
    [
        transforms.Compose(
            [
                transforms.Resize(size=300, max_size=301),
                transforms.CenterCrop(size=300),
                transforms.RandomCrop(250),
            ]
        ),
        transforms.RandomRotation(degrees=(-25, 25)),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
)

# final composition and datasets
train_transform = transforms.Compose([augmentations, base_transforms])
test_transform = base_transforms

train_dataset = ArtDataset(TRAIN_DATASET, TRAIN_CSV, train_transform)
test_dataset = ArtDataset(TEST_DATASET, TEST_CSV, test_transform)

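# NOTE: ArtDataset is defined in src.train_utils. As used here it presumably
# reads the TSV metadata, loads images from the given folder, applies the
# transform, and must expose a .targets sequence of integer class ids
# (consumed in section 2.1 below).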

"""
2. Training process
2.1 Resolve class imbalance with a weighted sampler
"""

# class -> weight mapping
weight_mapper = compute_class_weight(
    "balanced", classes=np.unique(train_dataset.targets), y=train_dataset.targets
)

samples_weight = np.array([weight_mapper[t] for t in train_dataset.targets])
samples_weight = torch.from_numpy(samples_weight)

sampler = WeightedRandomSampler(samples_weight.double(), len(samples_weight))
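
# sklearn's "balanced" mode assigns weight n_samples / (n_classes * count(c))
# to class c, so rare classes get larger weights; WeightedRandomSampler then
# draws len(samples_weight) indices per epoch with replacement, proportionally
# to those weights, which makes training batches roughly class-balanced.
# Note that weight_mapper is indexed by position in np.unique(targets), which
# matches the label ids only if they are the consecutive integers
# 0..NUM_CLASSES-1.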

"""
2.2 Loaders, model and other pre-training objects
"""
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    sampler=sampler,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

loaders = {"train": train_loader, "val": test_loader}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = init_model(device, num_classes=NUM_CLASSES)

# optimize only the parameters left trainable by init_model
optimizer = torch.optim.Adam(
    params=[param for param in model.parameters() if param.requires_grad],
    lr=1e-3,
    weight_decay=5e-4,
)

# simple, but it actually works: halve the learning rate every 10 epochs
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=10, gamma=0.5, verbose=True
)

criterion = nn.CrossEntropyLoss()
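
# NOTE: init_model lives in src.initial_model_utils. Given the script docstring
# and the requires_grad filter above, it presumably loads the pretrained
# backbone, freezes it, and swaps in a fresh classification head. A minimal
# sketch under those assumptions (not the repo's actual implementation):
#
#   from torchvision.models import resnet50, ResNet50_Weights
#
#   def init_model(device, num_classes):
#       model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
#       for param in model.parameters():
#           param.requires_grad = False  # freeze the pretrained backbone
#       model.fc = nn.Linear(model.fc.in_features, num_classes)  # trainable head
#       return model.to(device)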

"""
2.3 Training and results
"""
train_results = train_model(
    model,
    loaders,
    criterion,
    optimizer,
    phases=["train", "val"],
    sheduler=scheduler,  # train_model's parameter name keeps the original spelling
    num_epochs=80,
)

# make sure the weights folder exists, then save the trained weights
os.makedirs(os.path.dirname(MODEL_WEIGHTS), exist_ok=True)
torch.save(model.state_dict(), MODEL_WEIGHTS)
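
# To reuse the trained weights for inference later, rebuild the model and load
# the saved state dict (standard PyTorch pattern):
#
#   model = init_model(device, num_classes=NUM_CLASSES)
#   model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location=device))
#   model.eval()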