# pytorch — generates TorchScript model assets for Android instrumentation tests.
# (scraped file metadata: 151 lines · 4.1 KB)
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor
6
# Destination directory (relative to the Android project root) where the
# generated TorchScript models are written as instrumentation-test assets.
OUTPUT_DIR = "src/androidTest/assets/"
9
def scriptAndSave(module, fileName):
    """Compile *module* with TorchScript and save it under OUTPUT_DIR.

    The scripted graph is printed for inspection, and the model is exported
    in lite-interpreter format to ``OUTPUT_DIR + fileName``.
    """
    print("-" * 80)
    scripted = torch.jit.script(module)
    print(scripted.graph)
    out_path = OUTPUT_DIR + fileName
    # note that the lite interpreter model can also be used in full JIT
    scripted._save_for_lite_interpreter(out_path)
    print("Saved to " + out_path)
    print("=" * 80)
20
class Test(torch.jit.ScriptModule):
    """TorchScript module whose methods exercise the Java <-> TorchScript type
    conversions covered by the Android instrumentation tests.

    Method names, parameter names, and signatures are part of the test
    contract (they are invoked by name from the Java side) and must not change.
    """

    @torch.jit.script_method
    def forward(self, input):
        # Intentionally a no-op; individual eq*/list*/... methods are the tests.
        return None

    @torch.jit.script_method
    def eqBool(self, input: bool) -> bool:
        """Round-trip a bool."""
        return input

    @torch.jit.script_method
    def eqInt(self, input: int) -> int:
        """Round-trip an int."""
        return input

    @torch.jit.script_method
    def eqFloat(self, input: float) -> float:
        """Round-trip a float."""
        return input

    @torch.jit.script_method
    def eqStr(self, input: str) -> str:
        """Round-trip a str."""
        return input

    @torch.jit.script_method
    def eqTensor(self, input: Tensor) -> Tensor:
        """Round-trip a Tensor."""
        return input

    @torch.jit.script_method
    def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
        """Round-trip a str->int dict."""
        return input

    @torch.jit.script_method
    def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
        """Round-trip an int->int dict."""
        return input

    @torch.jit.script_method
    def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
        """Round-trip a float->int dict."""
        return input

    @torch.jit.script_method
    def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
        """Return the input list together with the sum of its elements."""
        total = 0  # renamed from `sum` to avoid shadowing the builtin
        for x in input:
            total += x
        return (input, total)

    @torch.jit.script_method
    def listBoolConjunction(self, input: List[bool]) -> bool:
        """Logical AND over all elements (True for an empty list)."""
        res = True
        for x in input:
            res = res and x
        return res

    @torch.jit.script_method
    def listBoolDisjunction(self, input: List[bool]) -> bool:
        """Logical OR over all elements (False for an empty list)."""
        res = False
        for x in input:
            res = res or x
        return res

    @torch.jit.script_method
    def tupleIntSumReturnTuple(
        self, input: Tuple[int, int, int]
    ) -> Tuple[Tuple[int, int, int], int]:
        """Return the input tuple together with the sum of its elements."""
        total = 0  # renamed from `sum` to avoid shadowing the builtin
        for x in input:
            total += x
        return (input, total)

    @torch.jit.script_method
    def optionalIntIsNone(self, input: Optional[int]) -> bool:
        """True iff the optional int is None."""
        return input is None

    @torch.jit.script_method
    def intEq0None(self, input: int) -> Optional[int]:
        """Map 0 to None; pass every other int through."""
        if input == 0:
            return None
        return input

    @torch.jit.script_method
    def str3Concat(self, input: str) -> str:
        """Concatenate the input string with itself three times."""
        return input + input + input

    @torch.jit.script_method
    def newEmptyShapeWithItem(self, input):
        # Produces a 0-dim tensor by indexing a freshly built 1-element tensor.
        return torch.tensor([int(input.item())])[0]

    @torch.jit.script_method
    def testAliasWithOffset(self) -> List[Tensor]:
        """Return tensor views that alias the same storage at different offsets."""
        x = torch.tensor([100, 200])
        a = [x[0], x[1]]
        return a

    @torch.jit.script_method
    def testNonContiguous(self):
        """Return a strided (non-contiguous) view to test stride handling."""
        x = torch.tensor([100, 200, 300])[::2]
        assert not x.is_contiguous()
        assert x[0] == 100
        assert x[1] == 300
        return x

    @torch.jit.script_method
    def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        """2D convolution, optionally converting the result to channels-last."""
        r = torch.nn.functional.conv2d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        """3D convolution, optionally converting the result to channels-last-3d."""
        r = torch.nn.functional.conv3d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last_3d)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def contiguous(self, x: Tensor) -> Tensor:
        """Return a contiguous copy/view of x."""
        return x.contiguous()

    @torch.jit.script_method
    def contiguousChannelsLast(self, x: Tensor) -> Tensor:
        """Return x in channels-last (NHWC) memory format."""
        return x.contiguous(memory_format=torch.channels_last)

    @torch.jit.script_method
    def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
        """Return x in channels-last-3d (NDHWC) memory format."""
        return x.contiguous(memory_format=torch.channels_last_3d)
150
# Script the test module and write it to OUTPUT_DIR as "test.pt".
# Note: runs on import as well — this file is intended to be executed
# as a standalone asset-generation script.
scriptAndSave(Test(), "test.pt")