pytorch

Форк
0
/
create_dummy_torchscript_model.py 
36 строк · 900.0 Байт
1
# Usage: python create_dummy_model.py <name_of_the_file>
2
import sys
3

4
import torch
5
from torch import nn
6

7

8
class NeuralNetwork(nn.Module):
9
    def __init__(self) -> None:
10
        super().__init__()
11
        self.flatten = nn.Flatten()
12
        self.linear_relu_stack = nn.Sequential(
13
            nn.Linear(28 * 28, 512),
14
            nn.ReLU(),
15
            nn.Linear(512, 512),
16
            nn.ReLU(),
17
            nn.Linear(512, 10),
18
        )
19

20
    def forward(self, x):
21
        x = self.flatten(x)
22
        logits = self.linear_relu_stack(x)
23
        return logits
24

25

26
if __name__ == "__main__":
27
    jit_module = torch.jit.script(NeuralNetwork())
28
    torch.jit.save(jit_module, sys.argv[1])
29
    orig_module = nn.Sequential(
30
        nn.Linear(28 * 28, 512),
31
        nn.ReLU(),
32
        nn.Linear(512, 512),
33
        nn.ReLU(),
34
        nn.Linear(512, 10),
35
    )
36
    torch.save(orig_module, sys.argv[1] + ".orig")
37

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.