russian_art_2024

Форк
0
/
initial_model_utils.py 
76 строк · 1.7 Кб
1
"""Useful functions to initialize model and process"""
2

3
import random
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8
import torchvision
9

10

11
def set_seed(seed: int) -> None:
12
    """Function to freeze all libs random states
13

14
    Parameters
15
    ----------
16
    seed : int
17
        Random state
18
    """
19
    random.seed(seed)
20
    np.random.seed(seed)
21
    torch.manual_seed(seed)
22

23

24
def set_requires_grad(model: nn.Module, value: bool = False) -> None:
25
    """Function to turn on/off all model's weight's gradients
26

27
    Parameters
28
    ----------
29
    model : torch.nn.Module
30
        Model to change grads
31
    value : bool, optional
32
        Turn on/off gradients of weights, by default False
33
    """
34
    for param in model.parameters():
35
        param.requires_grad = value
36

37

38
def init_model(
39
    device: torch.device, num_classes: int, pretrained: bool = False
40
) -> nn.Module:
41
    """Function to initialize classifier by using backbone of pretrained ResNet50
42

43
    Parameters
44
    ----------
45
    device : torch.device
46
        User's device (cpu or gpu)
47
    num_classes : int
48
        Number of classed to classify
49
    pretrained : bool
50
        Download pretrained model
51

52
    Returns
53
    -------
54
    nn.Module
55
        New model
56
    """
57

58
    weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
59
    model = torchvision.models.resnet50(weights=weights)
60

61
    # freeze all weight's gradients
62
    set_requires_grad(model, False)
63

64
    # new head classifier
65
    model.fc = nn.Sequential(
66
        *[
67
            nn.Dropout(0.735),
68
            nn.Linear(model.fc.in_features, 256),
69
            nn.ReLU(),
70
            nn.Dropout(0.4),
71
            nn.Linear(256, num_classes),
72
        ]
73
    )
74

75
    model = model.to(device)
76
    return model
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.