DL_Research
121 строка · 3.7 Кб
1import torch2import torch.nn as nn3
4
class Flattener(nn.Module):
    """Flatten every dimension except the leading (batch) one.

    A tensor of shape ``(batch, d1, d2, ...)`` becomes ``(batch, d1 * d2 * ...)``.
    """

    def forward(self, x):
        batch_size, *_ = x.shape
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such as
        # the output of ``permute``/``transpose``; it only copies when it has to.
        return x.reshape(batch_size, -1)
11
def do_epoch(model, train_loader, loss_function, optimizer, device):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: network to train; it is switched to train mode here.
        train_loader: iterable of ``(inputs, targets)`` batches.
        loss_function: callable ``loss_function(predictions, targets)`` returning
            a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(mean_batch_loss, accuracy)``: the loss averaged over batches, and the
        fraction of samples whose arg-max prediction equals the target.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()

    total_loss, correct_samples, total_samples, num_batches = 0.0, 0, 0, 0
    for x, y in train_loader:
        x_dev, y_dev = x.to(device), y.to(device)

        # Forward pass and loss.
        prediction = model(x_dev)
        loss = loss_function(prediction, y_dev)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, indices = torch.max(prediction, 1)
        correct_samples += int(torch.sum(indices == y_dev))
        total_samples += y.shape[0]
        total_loss += float(loss)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    # Average over the number of batches, not over the last 0-based batch index.
    return total_loss / num_batches, correct_samples / total_samples
34
def train_model(model, train_loader, val_loader, loss_function, optimizer, num_epochs, device, scheduler=None,
                scheduler_loss=False):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: network to train.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        loss_function: callable ``loss_function(predictions, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of epochs to run.
        device: device batches are moved to.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
        scheduler_loss: if True, the scheduler is stepped with the epoch's
            validation loss (``scheduler.step(validate_loss)``); otherwise it is
            stepped without arguments.

    Returns:
        ``(train_losses, train_accuracy_history, validate_losses,
        validate_accuracy_history)``, with one entry per epoch in each list.
    """
    train_losses, train_accuracy_history = [], []
    validate_losses, validate_accuracy_history = [], []

    for epoch in range(num_epochs):
        train_loss, train_accuracy = do_epoch(model, train_loader, loss_function, optimizer, device)
        # Validate BEFORE stepping the scheduler so that a metric-driven scheduler
        # (scheduler_loss=True, e.g. ReduceLROnPlateau) receives this epoch's real
        # validation loss rather than a placeholder.
        validate_loss, validate_accuracy = compute_loss_accuracy(model, val_loader, loss_function, device)

        if scheduler is not None:
            if scheduler_loss:
                scheduler.step(validate_loss)
            else:
                scheduler.step()

        train_losses.append(train_loss)
        train_accuracy_history.append(train_accuracy)
        validate_losses.append(validate_loss)
        validate_accuracy_history.append(validate_accuracy)

        print("Epoch #%s - train loss: %f, accuracy: %f | val loss: %f, accuracy: %f" % (
            epoch, train_loss, train_accuracy, validate_loss, validate_accuracy))

    return train_losses, train_accuracy_history, validate_losses, validate_accuracy_history
62
def compute_loss_accuracy(model, loader, loss_function, device):
    """Evaluate ``model`` on ``loader`` without updating it.

    The model is switched to eval mode and left in it.

    Args:
        model: network to evaluate.
        loader: iterable of ``(inputs, targets)`` batches.
        loss_function: callable ``loss_function(predictions, targets)``.
        device: device every batch is moved to.

    Returns:
        ``(mean_batch_loss, accuracy)``: the loss averaged over batches, and the
        fraction of samples whose arg-max prediction equals the target.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()

    total_loss, correct, total, num_batches = 0.0, 0.0, 0.0, 0
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for x, y in loader:
            x_dev, y_dev = x.to(device), y.to(device)

            scores = model(x_dev)
            y_hat = torch.argmax(scores, 1)

            total_loss += float(loss_function(scores, y_dev))
            correct += float(torch.sum(y_hat == y_dev))
            total += y_dev.shape[0]
            num_batches += 1

    if num_batches == 0:
        raise ValueError("loader yielded no batches")
    return total_loss / num_batches, correct / total
81
def accuracy_number(model, loader, number, device):
    """Return the accuracy of ``model`` on the samples whose target equals ``number``.

    Only the matching samples of each batch are passed through the model.
    The model is switched to eval mode and left in it.

    Args:
        model: network to evaluate.
        loader: iterable of ``(inputs, targets)`` batches.
        number: target label to restrict the evaluation to.
        device: device every batch is moved to.

    Returns:
        Fraction of samples with target ``number`` whose arg-max prediction
        equals ``number``.

    Raises:
        ValueError: if no sample in ``loader`` has target ``number``.
    """
    model.eval()

    correct, total = 0.0, 0.0
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for x, y in loader:
            x_dev, y_dev = x.to(device), y.to(device)

            mask = y_dev == number  # computed once, reused for inputs and targets
            y_number = y_dev[mask]
            if y_number.numel() == 0:
                continue
            x_number = x_dev[mask]

            y_hat = torch.argmax(model(x_number), 1)
            correct += float(torch.sum(y_hat == y_number))
            total += x_number.shape[0]

    if total == 0:
        raise ValueError(f"no samples with target {number!r} in loader")
    return correct / total
105
def compute_error_model(model, loader, device):
    """Collect every sample that ``model`` misclassifies.

    The model is switched to eval mode and left in it.

    Args:
        model: network to evaluate.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device every batch is moved to.

    Returns:
        ``(false_imgs, false_labels, true_labels)``: plain Python (nested) lists
        holding the misclassified inputs, the predicted (arg-max) labels and
        the true targets, in loader order.
    """
    model.eval()

    false_imgs, false_labels, true_labels = [], [], []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for x, y in loader:
            x_dev, y_dev = x.to(device), y.to(device)
            y_hat = torch.argmax(model(x_dev), 1)

            wrong = y_hat != y_dev  # computed once, reused for all three outputs
            false_imgs += x_dev[wrong].cpu().tolist()
            false_labels += y_hat[wrong].cpu().tolist()
            true_labels += y_dev[wrong].cpu().tolist()

    return false_imgs, false_labels, true_labels