# russian_art_2024 — training utilities (210 lines · 6.1 KB)
1"""Useful classes / functions for training process"""
2
import os
import time
from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from IPython.display import clear_output
from PIL import Image
from sklearn.metrics import f1_score
from torch.utils.data import Dataset
from tqdm import tqdm
17
18
def train_model(
    model: nn.Module,
    dataloaders: Dict[str, torch.utils.data.DataLoader],
    criterion: nn.Module,
    optimizer,
    phases: List[str],
    device: torch.device,
    sheduler: Union[None, nn.Module] = None,
    num_epochs: int = 3,
) -> Tuple[nn.Module, Dict[str, List[float]]]:
    """Train a custom image classifier, tracking per-phase loss and F1.

    Parameters
    ----------
    model : nn.Module
        Model with a new classifier head; backbone typically pretrained.
    dataloaders : Dict[str, torch.utils.data.DataLoader]
        Mapping from phase name (e.g. "train", "val") to its DataLoader.
    criterion : nn.Module
        Loss function to optimize.
    optimizer : torch.optim.Optimizer
        Optimizer updating the model parameters. Adam is preferred.
    phases : List[str]
        Phase names to run each epoch, e.g. ["train", "val"].
    device : torch.device
        Device where inputs, labels and the model are placed.
    sheduler : Union[None, nn.Module], optional
        Learning-rate scheduler, stepped once after the "val" phase,
        by default None.  NOTE: parameter name is misspelled
        ("scheduler") but kept for backward compatibility with callers.
    num_epochs : int, optional
        Number of epochs to train, by default 3.

    Returns
    -------
    Tuple[nn.Module, Dict[str, List[float]]]
        The trained model and the macro-F1 history per phase.
    """
    start_time = time.time()

    metric_history = {phase: [] for phase in phases}
    loss_history = {phase: [] for phase in phases}

    for epoch in range(1, num_epochs + 1):
        print("Epoch {}/{}".format(epoch, num_epochs))
        print("-" * 10)

        # each epoch has a training and validation phase
        for phase in phases:
            if phase == "train":
                # set model to training mode (dropout / batch-norm active)
                model.train()
            else:
                # set model to evaluate mode
                model.eval()

            running_loss = 0.0
            phase_preds, phase_labels = [], []

            # iterate over data
            n_batches = len(dataloaders[phase])
            for inputs, labels in tqdm(dataloaders[phase], total=n_batches):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward pass; gradients only recorded in train phase
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in train phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # statistics: loss weighted by batch size for a true mean
                running_loss += loss.item() * inputs.size(0)
                phase_preds.extend(preds.detach().cpu().numpy())
                phase_labels.extend(labels.detach().cpu().numpy())

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_f1 = f1_score(phase_labels, phase_preds, average="macro")

            print("{} Loss: {:.4f} f1: {:.4f}".format(phase, epoch_loss, epoch_f1))
            loss_history[phase].append(epoch_loss)
            metric_history[phase].append(epoch_f1)

            # run scheduler after validation phase
            if sheduler is not None and phase == "val":
                sheduler.step()

        plot_train_process(
            cur_epoch_num=epoch,
            loss_history=loss_history,
            metric_history=metric_history,
        )

    time_elapsed = time.time() - start_time
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )

    return model, metric_history
128
129
class ArtDataset(Dataset):
    """Image dataset over a flat directory, with optional TSV labels.

    Without ``csv_path`` every file in ``root_dir`` is an unlabeled
    sample (target -1).  With ``csv_path`` the TSV (columns
    ``image_name`` and ``label_id``) defines both the file list and
    the integer targets.
    """

    def __init__(self, root_dir, csv_path=None, transform=None):
        """
        Parameters
        ----------
        root_dir : str
            Directory containing the image files.
        csv_path : str, optional
            Path to a tab-separated file with "image_name" and
            "label_id" columns, by default None (unlabeled mode).
        transform : callable, optional
            Transform applied to each PIL image, by default None.
        """
        self.transform = transform
        self.targets = None
        if csv_path:
            df = pd.read_csv(csv_path, sep="\t")
            self.targets = df["label_id"].tolist()
            self.files = [
                os.path.join(root_dir, fname) for fname in df["image_name"].tolist()
            ]
        else:
            self.files = [
                os.path.join(root_dir, fname) for fname in os.listdir(root_dir)
            ]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        image = Image.open(self.files[idx]).convert("RGB")
        # Fix: explicit "is not None" — the original truthiness test
        # would also hit the -1 fallback for an empty label list.
        target = self.targets[idx] if self.targets is not None else -1
        if self.transform:
            image = self.transform(image)
        return image, target
151
152
def plot_train_process(
    cur_epoch_num: int,
    loss_history: Dict[str, List[float]],
    metric_history: Dict[str, List[float]],
) -> None:
    """Plot loss and metric curves for the training process.

    Clears the current notebook cell output and draws two side-by-side
    graphs — loss and F1-macro — each with train and validation curves
    over epochs 1..cur_epoch_num.

    Parameters
    ----------
    cur_epoch_num : int
        Number of completed epochs (upper bound of the x-axis).
    loss_history : Dict[str, List[float]]
        Per-phase ("train" / "val") loss history.
    metric_history : Dict[str, List[float]]
        Per-phase ("train" / "val") metric history.
    """
    clear_output(wait=True)

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(25, 15))

    marker_style = {
        "marker": "o",
        "markersize": 5,
        "markerfacecolor": "black",
    }

    train_style = {
        "label": "Train value",
        "color": "b",
    } | marker_style

    val_style = {
        "label": "Test value",
        "color": "r",
    } | marker_style

    # Fix: use a plain range — the original called np.arange although
    # numpy was never imported in this module (NameError at runtime).
    x_epoch = list(range(1, cur_epoch_num + 1))

    sns.lineplot(ax=ax[0], x=x_epoch, y=loss_history["train"], **train_style)
    sns.lineplot(ax=ax[0], x=x_epoch, y=loss_history["val"], **val_style)

    sns.lineplot(ax=ax[1], x=x_epoch, y=metric_history["train"], **train_style)
    sns.lineplot(ax=ax[1], x=x_epoch, y=metric_history["val"], **val_style)

    ax[0].set_ylabel("Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].set_title("Loss graph")
    ax[0].grid()

    ax[1].set_ylabel("Metric")
    ax[1].set_xlabel("Epoch")
    ax[1].set_title("F1-macro score")
    ax[1].grid()

    plt.show()
211