pytorch

Форк
0
/
test_importer.py 
162 строки · 5.5 Кб
1
# Owner(s): ["oncall: package/deploy"]
2

3
from io import BytesIO
4

5
import torch
6
from torch.package import (
7
    Importer,
8
    OrderedImporter,
9
    PackageExporter,
10
    PackageImporter,
11
    sys_importer,
12
)
13
from torch.testing._internal.common_utils import run_tests
14

15

16
try:
17
    from .common import PackageTestCase
18
except ImportError:
19
    # Support the case where we run this file directly.
20
    from common import PackageTestCase
21

22

23
class TestImporter(PackageTestCase):
    """Exercise ``Importer`` and the classes derived from it.

    Covers the process-wide ``sys_importer``, how ``OrderedImporter`` chains
    several importers (lookup order and ``whichmodule`` fallback), and
    ``PackageImporter`` handling of objects that lack a ``__module__``.
    """

    def _importer_for_packaged_module(self, module_name):
        """Package ``module_name`` into an in-memory archive and reopen it.

        Returns a ``PackageImporter`` that reads from the freshly written
        archive.
        """
        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.save_module(module_name)
        archive.seek(0)
        return PackageImporter(archive)

    def _pickle_roundtrip(self, obj, **exporter_kwargs):
        """Pickle ``obj`` into a fresh in-memory package and load it back.

        The object is stored as resource ``"foo.pkl"`` in package ``"foo"``.
        ``exporter_kwargs`` are forwarded unchanged to ``PackageExporter``.

        Returns ``(reader, loaded)``:
        - ``reader`` is the ``PackageImporter`` that opened the package.
        - ``loaded`` is the object it unpickled.
        """
        archive = BytesIO()
        with PackageExporter(archive, **exporter_kwargs) as exporter:
            exporter.save_pickle("foo", "foo.pkl", obj)
        archive.seek(0)
        reader = PackageImporter(archive)
        return reader, reader.load_pickle("foo", "foo.pkl")

    def test_sys_importer(self):
        import package_a
        import package_a.subpackage

        # sys_importer must hand back the very module objects already imported.
        expectations = (
            ("package_a", package_a),
            ("package_a.subpackage", package_a.subpackage),
        )
        for dotted_name, expected_module in expectations:
            self.assertIs(sys_importer.import_module(dotted_name), expected_module)

    def test_sys_importer_roundtrip(self):
        import package_a
        import package_a.subpackage

        target_cls = package_a.subpackage.PackageASubpackageObject
        # get_name -> import_module -> getattr must lead back to the same class.
        mod_name, attr_name = sys_importer.get_name(target_cls)
        resolved_module = sys_importer.import_module(mod_name)
        self.assertIs(getattr(resolved_module, attr_name), target_cls)

    def test_single_ordered_importer(self):
        import module_a  # noqa: F401
        import package_a

        pkg_importer = self._importer_for_packaged_module(package_a.__name__)

        # An environment that can only see what was put into the package.
        isolated = OrderedImporter(pkg_importer)

        # It must yield exactly the module object the package importer yields...
        self.assertIs(
            isolated.import_module("package_a"),
            pkg_importer.import_module("package_a"),
        )
        # ...and not the copy living in the outer Python environment.
        self.assertIsNot(isolated.import_module("package_a"), package_a)

        # module_a was never packaged, so the isolated view cannot resolve it.
        with self.assertRaises(ModuleNotFoundError):
            isolated.import_module("module_a")

    def test_ordered_importer_basic(self):
        import package_a

        pkg_importer = self._importer_for_packaged_module(package_a.__name__)

        # With sys_importer first, the ambient module wins.
        sys_first = OrderedImporter(sys_importer, pkg_importer)
        self.assertIs(sys_first.import_module("package_a"), package_a)

        # With the package importer first, the packaged module wins.
        pkg_first = OrderedImporter(pkg_importer, sys_importer)
        self.assertIs(
            pkg_first.import_module("package_a"),
            pkg_importer.import_module("package_a"),
        )

    def test_ordered_importer_whichmodule(self):
        """OrderedImporter.whichmodule should consult each underlying
        importer's whichmodule in order, moving past "not found" answers.
        """

        class FixedAnswerImporter(Importer):
            """Importer whose whichmodule always reports the same module name."""

            def __init__(self, whichmodule_return):
                self._answer = whichmodule_return

            def import_module(self, module_name):
                raise NotImplementedError

            def whichmodule(self, obj, name):
                return self._answer

        class Placeholder:
            pass

        foo = FixedAnswerImporter("foo")
        bar = FixedAnswerImporter("bar")
        # CPython's pickle uses "__main__" as its stand-in for "not found".
        not_found = FixedAnswerImporter("__main__")

        scenarios = [
            ((foo, bar), "foo"),
            ((bar, foo), "bar"),
            ((not_found, foo), "foo"),
        ]
        for chain, expected in scenarios:
            ordered = OrderedImporter(*chain)
            self.assertEqual(ordered.whichmodule(Placeholder(), ""), expected)

    def test_package_importer_whichmodule_no_dunder_module(self):
        """Exercise corner case where we try to pickle an object whose
        __module__ doesn't exist because it's from a C extension.
        """
        # torch.float16 is such an object: a C-extension value with no
        # __module__. The stock pickler locates it by scanning sys.modules for
        # an attribute named `float16` (see pickle.py:whichmodule), and
        # PackageImporter has to reproduce that lookup.
        original = torch.float16

        # First hop: pickle via the default importer, then load it back.
        first_reader, reloaded = self._pickle_roundtrip(original)

        # Second hop: re-export using only the first PackageImporter.
        _, reloaded_again = self._pickle_roundtrip(reloaded, importer=first_reader)

        self.assertIs(original, reloaded)
        self.assertIs(original, reloaded_again)
159

160

161
if __name__ == "__main__":
    # Entry point when this file is executed directly as a script rather than
    # being collected by an external test runner.
    run_tests()
163

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.