russian_art_2024
76 строк · 1.7 Кб
1"""Useful functions to initialize model and process"""
2
3import random4
5import numpy as np6import torch7import torch.nn as nn8import torchvision9
10
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    After this call, draws from ``random``, ``numpy.random`` and ``torch``
    repeat for the same ``seed`` value.

    Parameters
    ----------
    seed : int
        Value passed to every generator's seeding function.
    """
    # Seeding order matches the original: stdlib, NumPy, then PyTorch.
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for seeder in seeders:
        seeder(seed)
23
def set_requires_grad(model: nn.Module, value: bool = False) -> None:
    """Enable or disable gradient tracking for every parameter of ``model``.

    The module is modified in place; nothing is returned.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose parameters get the new ``requires_grad`` flag.
    value : bool, optional
        Flag to assign: ``False`` (the default) freezes the weights,
        ``True`` makes them trainable again.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = value
37
def init_model(
    device: torch.device,
    num_classes: int,
    pretrained: bool = False,
    hidden_features: int = 256,
    input_dropout: float = 0.735,
    hidden_dropout: float = 0.4,
    freeze_backbone: bool = True,
) -> nn.Module:
    """Build a ResNet50-based classifier with a new two-layer head.

    A torchvision ResNet50 is created (optionally with ImageNet
    ``IMAGENET1K_V2`` weights), its parameters are optionally frozen, and its
    final ``fc`` layer is replaced by
    ``Dropout -> Linear -> ReLU -> Dropout -> Linear``.
    The defaults reproduce the original fixed configuration.

    Parameters
    ----------
    device : torch.device
        Device the model is moved to (e.g. CPU or GPU).
    num_classes : int
        Number of classes to predict; must be at least 1.
    pretrained : bool, optional
        Load the ImageNet ``IMAGENET1K_V2`` backbone weights (downloaded by
        torchvision if not cached), by default False.
    hidden_features : int, optional
        Width of the hidden layer of the new head, by default 256.
    input_dropout : float, optional
        Dropout probability applied to the backbone features, by default 0.735.
    hidden_dropout : float, optional
        Dropout probability applied after the hidden layer, by default 0.4.
    freeze_backbone : bool, optional
        Disable gradients for all backbone parameters, by default True.
        With ``pretrained=False`` a frozen backbone keeps its random
        initialization, so only the new head will learn.

    Returns
    -------
    nn.Module
        The classifier, moved to ``device``.

    Raises
    ------
    ValueError
        If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        # nn.Linear(..., 0) would silently build a head with no outputs.
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = (
        torchvision.models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
    )
    model = torchvision.models.resnet50(weights=weights)

    # Freeze *before* replacing ``fc`` so the freshly created head (whose
    # parameters require grad by default) stays trainable.
    if freeze_backbone:
        set_requires_grad(model, False)

    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(input_dropout),
        nn.Linear(in_features, hidden_features),
        nn.ReLU(),
        nn.Dropout(hidden_dropout),
        nn.Linear(hidden_features, num_classes),
    )

    return model.to(device)