pytorch

Форк
0
/
test_show_pickle.py 
36 строк · 1.0 Кб
1
# Owner(s): ["oncall: mobile"]
2

3
import unittest
4
import io
5
import tempfile
6
import torch
7
import torch.utils.show_pickle
8

9
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
10

11
class TestShowPickle(TestCase):
12

13
    @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
14
    def test_scripted_model(self):
15
        class MyCoolModule(torch.nn.Module):
16
            def __init__(self, weight):
17
                super().__init__()
18
                self.weight = weight
19

20
            def forward(self, x):
21
                return x * self.weight
22

23
        m = torch.jit.script(MyCoolModule(torch.tensor([2.0])))
24

25
        with tempfile.NamedTemporaryFile() as tmp:
26
            torch.jit.save(m, tmp)
27
            tmp.flush()
28
            buf = io.StringIO()
29
            torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf)
30
            output = buf.getvalue()
31
            self.assertRegex(output, "MyCoolModule")
32
            self.assertRegex(output, "weight")
33

34

35
if __name__ == '__main__':
36
    run_tests()
37

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.