pytorch
1# Owner(s): ["oncall: mobile"]
2
3import unittest
4import io
5import tempfile
6import torch
7import torch.utils.show_pickle
8
9from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
10
11class TestShowPickle(TestCase):
12
13@unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
14def test_scripted_model(self):
15class MyCoolModule(torch.nn.Module):
16def __init__(self, weight):
17super().__init__()
18self.weight = weight
19
20def forward(self, x):
21return x * self.weight
22
23m = torch.jit.script(MyCoolModule(torch.tensor([2.0])))
24
25with tempfile.NamedTemporaryFile() as tmp:
26torch.jit.save(m, tmp)
27tmp.flush()
28buf = io.StringIO()
29torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf)
30output = buf.getvalue()
31self.assertRegex(output, "MyCoolModule")
32self.assertRegex(output, "weight")
33
34
if __name__ == '__main__':
    # When executed directly, hand control to PyTorch's shared test runner.
    run_tests()
37