DL_Research

Форк
0
121 строка · 3.7 Кб
1
import torch
2
import torch.nn as nn
3

4

5
class Flattener(nn.Module):
6
    """Слой переводит многомерный тензор в одномерный"""
7
    def forward(self, x):
8
        batch_size, *_ = x.shape
9
        return x.view(batch_size, -1)
10

11

12
def do_epoch(model, train_loader, loss_function, optimizer, device):
13
    # Enter train mode
14
    model.train()
15

16
    total_loss, correct_samples, total_samples = 0, 0, 0
17
    for i_step, (x, y) in enumerate(train_loader):
18
        x_gpu, y_gpu = x.to(device), y.to(device)
19

20
        # Prediction
21
        prediction = model(x_gpu)
22
        loss = loss_function(prediction, y_gpu)
23

24
        optimizer.zero_grad()
25
        loss.backward()
26
        optimizer.step()
27

28
        _, indices = torch.max(prediction, 1)
29
        correct_samples += torch.sum(indices == y_gpu)
30
        total_samples += y.shape[0]
31
        total_loss += float(loss)
32
    return total_loss / i_step, float(correct_samples) / total_samples
33

34

35
def train_model(model, train_loader, val_loader, loss_function, optimizer, num_epochs, device, scheduler=None,
36
                scheduler_loss=False):
37
    train_losses, train_accuracy_history = [], []
38
    validate_losses, validate_accuracy_history = [], []
39

40
    for epoch in range(num_epochs):
41

42
        train_loss, train_accuracy = do_epoch(model, train_loader, loss_function, optimizer, device)
43
        validate_loss, validate_accuracy = 0, 0
44
        if scheduler is not None:
45
            if scheduler_loss:
46
                scheduler.step(validate_loss)
47
            else:
48
                scheduler.step()
49

50
        validate_loss, validate_accuracy = compute_loss_accuracy(model, val_loader, loss_function, device)
51

52
        train_losses.append(train_loss)
53
        train_accuracy_history.append(train_accuracy)
54
        validate_losses.append(validate_loss)
55
        validate_accuracy_history.append(validate_accuracy)
56

57
        print("Epoch #%s - train loss: %f, accuracy: %f | val loss: %f, accuracy: %f" % (
58
            epoch, train_losses[-1], train_accuracy_history[-1], validate_loss, validate_accuracy))
59

60
    return train_losses, train_accuracy_history, validate_losses, validate_accuracy_history
61

62

63
def compute_loss_accuracy(model, loader, loss_function, device):
64
    # Evaluation mode
65
    model.eval()
66

67
    total_loss, correct, total = 0.0, 0.0, 0.0
68
    for i, (x, y) in enumerate(loader):
69
        x_gpu, y_gpu = x.to(device), y.to(device)
70

71
        y_probs = model(x_gpu)
72
        y_hat = torch.argmax(y_probs, 1)
73

74
        loss = loss_function(y_probs, y_gpu)
75
        total_loss += float(loss)
76
        correct += float(torch.sum(y_hat == y_gpu))
77
        total += y_gpu.shape[0]
78

79
    return total_loss / (i + 1), correct / total
80

81

82
def accuracy_number(model, loader, number, device):
83
    # Evaluation mode
84
    model.eval()
85

86
    correct, total = 0.0, 0.0
87

88
    for i, (x, y) in enumerate(loader):
89
        x_gpu, y_gpu = x.to(device), y.to(device)
90

91
        y_number = y_gpu[y_gpu == number]
92
        x_number = x_gpu[y_gpu == number]
93

94
        if len(y_number) == 0:
95
            continue
96

97
        y_probs = model(x_number)
98
        y_hat = torch.argmax(y_probs, 1)
99

100
        correct += float(torch.sum(y_hat == y_number))
101
        total += x_number.shape[0]
102

103
    return correct / total
104

105

106
def compute_error_model(model, loader, device):
107
    # Evaluation mode
108
    model.eval()
109

110
    false_imgs, false_labels, true_labels = [], [], []
111
    for i, (x, y) in enumerate(loader):
112
        x_gpu, y_gpu = x.to(device), y.to(device)
113

114
        y_probs = model(x_gpu)
115
        y_hat = torch.argmax(y_probs, 1)
116

117
        false_imgs += x_gpu[y_hat != y_gpu].cpu().numpy().tolist()
118
        false_labels += y_hat[y_hat != y_gpu].cpu().numpy().tolist()
119
        true_labels += y_gpu[y_hat != y_gpu].cpu().numpy().tolist()
120

121
    return false_imgs, false_labels, true_labels
122

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.