# Owner(s): ["oncall: mobile"]

import io
import tempfile
import unittest

import torch
import torch.utils.show_pickle
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase

12class TestShowPickle(TestCase):13@unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")14def test_scripted_model(self):15class MyCoolModule(torch.nn.Module):16def __init__(self, weight):17super().__init__()18self.weight = weight19
20def forward(self, x):21return x * self.weight22
23m = torch.jit.script(MyCoolModule(torch.tensor([2.0])))24
25with tempfile.NamedTemporaryFile() as tmp:26torch.jit.save(m, tmp)27tmp.flush()28buf = io.StringIO()29torch.utils.show_pickle.main(30["", tmp.name + "@*/data.pkl"], output_stream=buf31)32output = buf.getvalue()33self.assertRegex(output, "MyCoolModule")34self.assertRegex(output, "weight")35
36
if __name__ == "__main__":
    # When executed directly, hand off to PyTorch's shared test runner.
    run_tests()