russian_art_2024

Форк
0
/
train_utils.py 
210 строк · 6.1 Кб
1
"""Useful classes / functions for training process"""
2

3
import os
import time
from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from IPython.display import clear_output
from PIL import Image
from sklearn.metrics import f1_score
from torch.utils.data import Dataset
from tqdm import tqdm
17

18

19
def train_model(
    model: nn.Module,
    dataloaders: Dict[str, torch.utils.data.DataLoader],
    criterion: nn.Module,
    optimizer,
    phases: List[str],
    device: torch.device,
    sheduler: Union[None, nn.Module] = None,
    num_epochs: int = 3,
) -> Tuple[nn.Module, Dict[str, List[float]]]:
    """Train a classifier and track loss / macro-F1 per phase.

    For every epoch each phase listed in ``phases`` is run once over its
    dataloader. After all phases of an epoch have finished, the accumulated
    histories are passed to ``plot_train_process``.

    Parameters
    ----------
    model : nn.Module
        Model to train. It is updated in place and also returned.
    dataloaders : Dict[str, torch.utils.data.DataLoader]
        Mapping from phase name to the loader used in that phase. Each loader
        must yield ``(inputs, labels)`` pairs.
    criterion : nn.Module
        Loss function, called as ``criterion(outputs, labels)``.
    optimizer
        Optimizer holding the model parameters; ``zero_grad()`` is called on
        every batch and ``step()`` on every training batch.
    phases : List[str]
        Phase names to run each epoch, in the given order. The phase named
        ``"train"`` puts the model in training mode and performs backward
        passes and optimizer steps; every other phase runs in eval mode with
        gradients disabled.
    device : torch.device
        Device the input and label batches are moved to.
    sheduler : Union[None, nn.Module], optional
        Learning-rate scheduler, by default None. Only its ``step()`` method
        is used; it is called once per epoch, and only when the last entry of
        ``phases`` is ``"val"``.
    num_epochs : int, optional
        Number of epochs to train, by default 3.

    Returns
    -------
    Tuple[nn.Module, Dict[str, List[float]]]
        The trained model and the per-phase macro-F1 history (one value per
        epoch). The loss history is used for plotting only and not returned.
    """

    start_time = time.time()

    metric_history = {phase_name: [] for phase_name in phases}
    loss_history = {phase_name: [] for phase_name in phases}

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        print("-" * 10)

        # each epoch runs every requested phase in order
        for phase in phases:
            is_train = phase == "train"
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            phase_preds, phase_labels = [], []

            # iterate over data
            loader = dataloaders[phase]
            for inputs, labels in tqdm(loader, total=len(loader)):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward pass; autograd is only recorded while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # predicted class = index of the largest output along dim 1
                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in train phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics: weight the batch loss by the batch size so the
                # epoch loss below is a per-sample average
                running_loss += loss.item() * inputs.size(0)
                phase_preds.extend(preds.detach().cpu().numpy())
                phase_labels.extend(labels.detach().cpu().numpy())

            epoch_loss = running_loss / len(loader.dataset)
            epoch_f1 = f1_score(phase_labels, phase_preds, average="macro")

            print(f"{phase} Loss: {epoch_loss:.4f} f1: {epoch_f1:.4f}")
            loss_history[phase].append(epoch_loss)
            metric_history[phase].append(epoch_f1)

        # Run the scheduler once per epoch, after the validation phase.
        # Checking the last entry of `phases` (instead of the loop variable
        # left over from the loop above) keeps this safe for empty `phases`.
        if sheduler is not None and phases and phases[-1] == "val":
            sheduler.step()

        plot_train_process(
            cur_epoch_num=epoch,
            loss_history=loss_history,
            metric_history=metric_history,
        )

    time_elapsed = time.time() - start_time
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )

    return model, metric_history
128

129

130
class ArtDataset(Dataset):
    """Image dataset backed by a directory and an optional label file.

    Without ``csv_path`` every entry of ``root_dir`` is treated as an image
    and no labels are available. With ``csv_path`` the file list and labels
    are taken from the tab-separated file instead.
    """

    def __init__(self, root_dir, csv_path=None, transform=None):
        """Collect image paths and, when available, their labels.

        Parameters
        ----------
        root_dir : str
            Directory containing the image files.
        csv_path : str, optional
            Path to a tab-separated file with ``image_name`` and ``label_id``
            columns. When given, only the listed images are used, in file
            order. By default None.
        transform : callable, optional
            Applied to the RGB ``PIL.Image`` before it is returned,
            by default None.
        """
        self.transform = transform
        self.targets = None
        if csv_path:
            df = pd.read_csv(csv_path, sep="\t")
            self.targets = df["label_id"].tolist()
            names = df["image_name"].tolist()
        else:
            # Only list the directory when no label file defines the samples.
            names = os.listdir(root_dir)
        self.files = [os.path.join(root_dir, fname) for fname in names]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for position ``idx``.

        ``target`` is the label from the label file, or ``-1`` when the
        dataset was built without one.
        """
        # The context manager closes the underlying file as soon as the
        # converted copy has been created.
        with Image.open(self.files[idx]) as raw_image:
            image = raw_image.convert("RGB")
        target = self.targets[idx] if self.targets is not None else -1
        if self.transform:
            image = self.transform(image)
        return image, target
151

152

153
def plot_train_process(
    cur_epoch_num: int,
    loss_history: Dict[str, List[float]],
    metric_history: Dict[str, List[float]],
) -> None:
    """Plot loss and metric curves for the epochs completed so far.

    Clears the current notebook output and draws two side-by-side graphs:
    loss on the left and the metric on the right. Curves are drawn for the
    ``"train"`` and ``"val"`` keys; a key missing from a history is skipped.

    Parameters
    ----------
    cur_epoch_num : int
        Current epoch number. Each plotted history must hold exactly this
        many values, one per epoch starting from epoch 1.
    loss_history : Dict[str, List[float]]
        Storage of train and validation loss history
    metric_history : Dict[str, List[float]]
        Storage of train and validation metric history
    """

    # wait=True postpones clearing until the new output is ready
    clear_output(wait=True)

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(25, 15))

    marker_style = {
        "marker": "o",
        "markersize": 5,
        "markerfacecolor": "black",
    }

    # line style per phase; the validation curve is labelled "Test value"
    phase_styles = {
        "train": {"label": "Train value", "color": "b"} | marker_style,
        "val": {"label": "Test value", "color": "r"} | marker_style,
    }

    x_epoch = np.arange(1, cur_epoch_num + 1)

    # ax[0] shows the loss, ax[1] the metric; only plot phases that were run
    for axis, history in zip(ax, (loss_history, metric_history)):
        for phase, style in phase_styles.items():
            if phase in history:
                sns.lineplot(ax=axis, x=x_epoch, y=history[phase], **style)

    ax[0].set_ylabel("Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].set_title("Loss graph")
    ax[0].grid()

    ax[1].set_ylabel("Metric")
    ax[1].set_xlabel("Epoch")
    ax[1].set_title("F1-macro score")
    ax[1].grid()

    plt.show()
211

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.