russian_art_2024
/
training.py
197 строк · 4.9 Кб
1"""Script for transfer learning ResNet50 with IMAGENET1K_V2 weights on Russian art dataset"""
2
3"""
41. Import libraries, data, constants
51.1 Import libs
6"""
7
8import os
9import shutil
10
11import numpy as np
12import pandas as pd
13import torch
14from sklearn.model_selection import train_test_split
15from sklearn.utils.class_weight import compute_class_weight
16from torch import nn
17from torch.utils.data import WeightedRandomSampler
18from torchvision import transforms
19
20from src.initial_model_utils import init_model, set_seed
21from src.train_utils import ArtDataset, train_model
22
23"""
241.2 Define constants
25"""
26
# Single seed used everywhere so that results can be reproduced.
RANDOM_STATE: int = 139

# Side length of the square images passed to ResNet.
IMG_SIZE: int = 224

# Number of art classes in the dataset.
NUM_CLASSES: int = 35

# Data loader parameters.
BATCH_SIZE: int = 32
NUM_WORKERS: int = 4

# Switch to True on the first run to split the data into train and test parts.
SPLIT_DATA_FOLDERS: bool = False

# Location of the original data.
TRAIN_DATASET: str = "./data/train/"
ORIGIN_TRAIN_CSV: str = "./data/private_info/train.csv"

# Location of the data and metadata produced by the split
# (the train folder itself stays where it is).
TEST_DATASET: str = "./data/test/"
TRAIN_CSV: str = "./data/private_info/new_train.csv"
TEST_CSV: str = "./data/private_info/new_test.csv"

# File the trained model weights are written to.
MODEL_WEIGHTS: str = "./data/weights/resnet50_tl_68.pt"
55
56"""
1.3 Configure library settings
58"""
# Fix the random state through the project helper before any randomized
# step (data split, augmentations, sampling, training) runs.
set_seed(RANDOM_STATE)
61
62"""
631.4 Split initial data to train and test parts
64"""
65
# One-off preparation step: carve a stratified test part out of the original
# training data. Test images are moved into their own folder and the metadata
# of both parts is written to new tab-separated CSV files.
if SPLIT_DATA_FOLDERS:
    # metadata of the original dataset
    train_data_info = pd.read_csv(ORIGIN_TRAIN_CSV, sep="\t")

    # row ids and class labels
    indices = train_data_info.index
    labels = train_data_info.label_id

    # stratified split because the classes are imbalanced
    ind_train, ind_test, _, _ = train_test_split(
        indices, labels, test_size=0.2, random_state=RANDOM_STATE, stratify=labels
    )

    # metadata of the new train and test parts
    new_train, new_test = (
        train_data_info.loc[ind_train].reset_index(drop=True),
        train_data_info.loc[ind_test].reset_index(drop=True),
    )

    # move test photos from the original train folder to the new test folder
    source_dir = TRAIN_DATASET
    target_dir = TEST_DATASET
    file_names = new_test["image_name"].values

    # The test folder may not exist yet on the first run; create it so that
    # shutil.move always has a directory to move the files into.
    os.makedirs(target_dir, exist_ok=True)

    for file_name in file_names:
        shutil.move(os.path.join(source_dir, file_name), target_dir)

    new_train.to_csv(TRAIN_CSV, index=False, sep="\t")
    new_test.to_csv(TEST_CSV, index=False, sep="\t")
96
97
98"""
991.5 Import images data
100"""
101
def _scale_to_unit_float(image):
    """Convert *image* to a float32 array and divide it by 255."""
    return np.array(image, dtype="float32") / 255


# Base transforms (normalization constants taken from the ResNet50 docs).
# The scaling step is a named module-level function rather than a lambda:
# unlike a lambda it can be pickled, which DataLoader worker processes need
# on platforms that start workers with "spawn".
base_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.Lambda(_scale_to_unit_float),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Random augmentations to reduce overfitting; RandomChoice applies exactly
# one of the listed options per sample.
augmentations = transforms.RandomChoice(
    [
        transforms.Compose(
            [
                transforms.Resize(size=300, max_size=301),
                transforms.CenterCrop(size=300),
                transforms.RandomCrop(250),
            ]
        ),
        transforms.RandomRotation(degrees=(-25, 25)),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
)

# Final pipelines: augmentations are applied to the train part only.
train_transform = transforms.Compose([augmentations, base_transforms])
test_transform = base_transforms

train_dataset = ArtDataset(TRAIN_DATASET, TRAIN_CSV, train_transform)
test_dataset = ArtDataset(TEST_DATASET, TEST_CSV, test_transform)
133
134
135"""
1362. Training process
2.1 Resolve class imbalance by using a weighted sampler
138"""
139
# Per-class "balanced" weights. `classes` holds the sorted distinct labels of
# the train part and `class_positions[i]` is the position of sample i's label
# inside `classes`.
targets = np.array(train_dataset.targets)
classes, class_positions = np.unique(targets, return_inverse=True)
weight_mapper = compute_class_weight("balanced", classes=classes, y=targets)

# compute_class_weight returns one weight per entry of `classes`, so each
# sample's weight is looked up by position rather than by the raw label value.
# This stays correct even when the labels are not exactly 0..K-1.
samples_weight = torch.from_numpy(weight_mapper[class_positions])

# Draw len(samples_weight) samples per epoch with probability proportional
# to the per-sample weight.
sampler = WeightedRandomSampler(
    samples_weight.type("torch.DoubleTensor"), len(samples_weight)
)
151
152"""
1532.2 Loaders, model and other pretraining objects
154"""
# DataLoader keyword arguments are lowercase (`batch_size`, `num_workers`);
# passing them in upper case raises TypeError for an unexpected keyword.
# The train loader draws samples through the weighted sampler.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    sampler=sampler,
)

# The validation loader iterates the test part in order.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

# Loaders keyed by the phase names used in the training call.
loaders = {"train": train_loader, "val": test_loader}
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = init_model(device, num_classes=NUM_CLASSES)

# Only parameters that still require gradients are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params=trainable_params, lr=1e-3, weight_decay=5e-4)

# Step decay: the learning rate is halved every 10 scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# against the version this project pins.
sheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=10, gamma=0.5, verbose=True
)

criterion = nn.CrossEntropyLoss()
183
184"""
1852.3 Training and results
186"""
# Make sure the weights folder exists *before* the long training run, so the
# final torch.save cannot fail on a missing directory after all epochs.
weights_dir = os.path.dirname(MODEL_WEIGHTS)
if weights_dir:
    os.makedirs(weights_dir, exist_ok=True)

# Train and validate for 80 epochs.
# NOTE(review): `sheduler` is the keyword spelling used by this call; keep it
# unless train_model's signature is renamed too - confirm in src.train_utils.
train_results = train_model(
    model,
    loaders,
    criterion,
    optimizer,
    phases=["train", "val"],
    sheduler=sheduler,
    num_epochs=80,
)

# Persist only the model's state dict (weights), not the whole module.
torch.save(model.state_dict(), MODEL_WEIGHTS)
198