pytorch

Форк
0
/
generate_test_torchscripts.py 
151 строка · 4.1 Кб
1
from typing import Dict, List, Optional, Tuple
2

3
import torch
4
from torch import Tensor
5

6

7
OUTPUT_DIR = "src/androidTest/assets/"
8

9

10
def scriptAndSave(module, fileName):
11
    print("-" * 80)
12
    script_module = torch.jit.script(module)
13
    print(script_module.graph)
14
    outputFileName = OUTPUT_DIR + fileName
15
    # note that the lite interpreter model can also be used in full JIT
16
    script_module._save_for_lite_interpreter(outputFileName)
17
    print("Saved to " + outputFileName)
18
    print("=" * 80)
19

20

21
class Test(torch.jit.ScriptModule):
22
    @torch.jit.script_method
23
    def forward(self, input):
24
        return None
25

26
    @torch.jit.script_method
27
    def eqBool(self, input: bool) -> bool:
28
        return input
29

30
    @torch.jit.script_method
31
    def eqInt(self, input: int) -> int:
32
        return input
33

34
    @torch.jit.script_method
35
    def eqFloat(self, input: float) -> float:
36
        return input
37

38
    @torch.jit.script_method
39
    def eqStr(self, input: str) -> str:
40
        return input
41

42
    @torch.jit.script_method
43
    def eqTensor(self, input: Tensor) -> Tensor:
44
        return input
45

46
    @torch.jit.script_method
47
    def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
48
        return input
49

50
    @torch.jit.script_method
51
    def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
52
        return input
53

54
    @torch.jit.script_method
55
    def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
56
        return input
57

58
    @torch.jit.script_method
59
    def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
60
        sum = 0
61
        for x in input:
62
            sum += x
63
        return (input, sum)
64

65
    @torch.jit.script_method
66
    def listBoolConjunction(self, input: List[bool]) -> bool:
67
        res = True
68
        for x in input:
69
            res = res and x
70
        return res
71

72
    @torch.jit.script_method
73
    def listBoolDisjunction(self, input: List[bool]) -> bool:
74
        res = False
75
        for x in input:
76
            res = res or x
77
        return res
78

79
    @torch.jit.script_method
80
    def tupleIntSumReturnTuple(
81
        self, input: Tuple[int, int, int]
82
    ) -> Tuple[Tuple[int, int, int], int]:
83
        sum = 0
84
        for x in input:
85
            sum += x
86
        return (input, sum)
87

88
    @torch.jit.script_method
89
    def optionalIntIsNone(self, input: Optional[int]) -> bool:
90
        return input is None
91

92
    @torch.jit.script_method
93
    def intEq0None(self, input: int) -> Optional[int]:
94
        if input == 0:
95
            return None
96
        return input
97

98
    @torch.jit.script_method
99
    def str3Concat(self, input: str) -> str:
100
        return input + input + input
101

102
    @torch.jit.script_method
103
    def newEmptyShapeWithItem(self, input):
104
        return torch.tensor([int(input.item())])[0]
105

106
    @torch.jit.script_method
107
    def testAliasWithOffset(self) -> List[Tensor]:
108
        x = torch.tensor([100, 200])
109
        a = [x[0], x[1]]
110
        return a
111

112
    @torch.jit.script_method
113
    def testNonContiguous(self):
114
        x = torch.tensor([100, 200, 300])[::2]
115
        assert not x.is_contiguous()
116
        assert x[0] == 100
117
        assert x[1] == 300
118
        return x
119

120
    @torch.jit.script_method
121
    def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
122
        r = torch.nn.functional.conv2d(x, w)
123
        if toChannelsLast:
124
            r = r.contiguous(memory_format=torch.channels_last)
125
        else:
126
            r = r.contiguous()
127
        return r
128

129
    @torch.jit.script_method
130
    def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
131
        r = torch.nn.functional.conv3d(x, w)
132
        if toChannelsLast:
133
            r = r.contiguous(memory_format=torch.channels_last_3d)
134
        else:
135
            r = r.contiguous()
136
        return r
137

138
    @torch.jit.script_method
139
    def contiguous(self, x: Tensor) -> Tensor:
140
        return x.contiguous()
141

142
    @torch.jit.script_method
143
    def contiguousChannelsLast(self, x: Tensor) -> Tensor:
144
        return x.contiguous(memory_format=torch.channels_last)
145

146
    @torch.jit.script_method
147
    def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
148
        return x.contiguous(memory_format=torch.channels_last_3d)
149

150

151
scriptAndSave(Test(), "test.pt")
152

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.