pytorch

Форк
0
/
test_show_pickle.py 
38 строк · 1.1 Кб
1
# Owner(s): ["oncall: mobile"]
2

3
import io
4
import tempfile
5
import unittest
6

7
import torch
8
import torch.utils.show_pickle
9
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
10

11

12
class TestShowPickle(TestCase):
13
    @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
14
    def test_scripted_model(self):
15
        class MyCoolModule(torch.nn.Module):
16
            def __init__(self, weight):
17
                super().__init__()
18
                self.weight = weight
19

20
            def forward(self, x):
21
                return x * self.weight
22

23
        m = torch.jit.script(MyCoolModule(torch.tensor([2.0])))
24

25
        with tempfile.NamedTemporaryFile() as tmp:
26
            torch.jit.save(m, tmp)
27
            tmp.flush()
28
            buf = io.StringIO()
29
            torch.utils.show_pickle.main(
30
                ["", tmp.name + "@*/data.pkl"], output_stream=buf
31
            )
32
            output = buf.getvalue()
33
            self.assertRegex(output, "MyCoolModule")
34
            self.assertRegex(output, "weight")
35

36

37
if __name__ == "__main__":
38
    run_tests()
39

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.