# (scraped page header) pytorch · Форк 0 / test_directory_reader.py · 290 строк · 10.3 Кб
# Owner(s): ["oncall: package/deploy"]

import os
import zipfile
from contextlib import contextmanager
from pathlib import Path
from sys import version_info
from tempfile import TemporaryDirectory
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
)


# torchvision is an optional dependency: record whether it imported so tests
# that need ``resnet18`` can be skipped via ``skipIfNoTorchVision``.
try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")


try:
    # Normal case: this file is imported as part of the test package.
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase

# Directory containing this test module.
packaging_directory = Path(__file__).parent


@skipIf(
    IS_FBCODE or IS_SANDCASTLE or IS_WINDOWS,
    "Tests that use temporary files are disabled in fbcode",
)
class DirectoryReaderTest(PackageTestCase):
    """Tests use of DirectoryReader as accessor for opened packages."""

    @contextmanager
    def _directory_importer(self, filename):
        """Unzip the package archive ``filename`` and open it as a directory.

        The archive is extracted into a fresh temporary directory and a
        ``PackageImporter`` is pointed at ``<temp_dir>/<archive file name>``,
        which is where the extracted package contents are expected to live.

        Yields:
            ``(temp_dir, importer)``. The zip handle is closed as soon as
            extraction finishes; the temporary directory is removed when the
            ``with`` block exits (also on exceptions, so tests that expect a
            raise still clean up).
        """
        with TemporaryDirectory() as temp_dir:
            # Close the archive right after extraction instead of leaking the
            # handle for the rest of the test.
            with zipfile.ZipFile(filename, "r") as zip_file:
                zip_file.extractall(path=temp_dir)
            yield temp_dir, PackageImporter(Path(temp_dir) / Path(filename).name)

    @skipIfNoTorchVision
    @skipIf(
        True,
        "Does not work with latest TorchVision, see https://github.com/pytorch/pytorch/issues/81115",
    )
    def test_loading_pickle(self):
        """
        Test basic saving and loading of modules and pickles from a DirectoryReader.
        """
        resnet = resnet18()

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        with self._directory_importer(filename) as (_, importer):
            dir_mod = importer.load_pickle("model", "model.pkl")
            inp = torch.rand(1, 3, 224, 224)
            self.assertEqual(dir_mod(inp), resnet(inp))

    def test_loading_module(self):
        """
        Test basic saving and loading of a package from a DirectoryReader.
        """
        import package_a

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        with self._directory_importer(filename) as (_, dir_importer):
            dir_mod = dir_importer.import_module("package_a")
            self.assertEqual(dir_mod.result, package_a.result)

    def test_loading_has_record(self):
        """
        Test DirectoryReader's has_record().
        """
        import package_a  # noqa: F401

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        with self._directory_importer(filename) as (_, dir_importer):
            self.assertTrue(dir_importer.zip_reader.has_record("package_a/__init__.py"))
            # A bare directory name is not itself a record.
            self.assertFalse(dir_importer.zip_reader.has_record("package_a"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_reader(self):
        """Tests DirectoryReader as the base for get_resource_reader."""
        filename = self.temp()
        with PackageExporter(filename) as pe:
            # Layout looks like:
            #    package
            #    |-- one/
            #    |   |-- a.txt
            #    |   |-- b.txt
            #    |   |-- c.txt
            #    |   +-- three/
            #    |       |-- d.txt
            #    |       +-- e.txt
            #    +-- two/
            #       |-- f.txt
            #       +-- g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        with self._directory_importer(filename) as (temp_dir, importer):
            reader_one = importer.get_resource_reader("one")

            # Different behavior from still zipped archives: resource_path()
            # is expected to be a path inside the extracted directory.
            resource_path = os.path.join(
                Path(temp_dir), Path(filename).name, "one", "a.txt"
            )
            self.assertEqual(reader_one.resource_path("a.txt"), resource_path)

            self.assertTrue(reader_one.is_resource("a.txt"))
            self.assertEqual(
                reader_one.open_resource("a.txt").getbuffer(), b"hello, a!"
            )
            # Subpackages are listed in contents() but are not resources.
            self.assertFalse(reader_one.is_resource("three"))
            self.assertSequenceEqual(
                sorted(reader_one.contents()), ["a.txt", "b.txt", "c.txt", "three"]
            )

            reader_two = importer.get_resource_reader("two")
            self.assertTrue(reader_two.is_resource("f.txt"))
            self.assertEqual(
                reader_two.open_resource("f.txt").getbuffer(), b"hello, f!"
            )
            self.assertSequenceEqual(sorted(reader_two.contents()), ["f.txt", "g.txt"])

            reader_one_three = importer.get_resource_reader("one.three")
            self.assertTrue(reader_one_three.is_resource("d.txt"))
            self.assertEqual(
                reader_one_three.open_resource("d.txt").getbuffer(), b"hello, d!"
            )
            self.assertSequenceEqual(
                sorted(reader_one_three.contents()), ["d.txt", "e.txt"]
            )

            self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        filename = self.temp()
        with PackageExporter(filename) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        with self._directory_importer(filename) as (_, dir_importer):
            self.assertEqual(
                dir_importer.import_module("foo.bar").secret_message(),
                "my sekrit plays",
            )

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_importer_access(self):
        """Packaged code can read text/binary resources via ``torch_package_importer``."""
        filename = self.temp()
        with PackageExporter(filename) as he:
            he.save_text("main", "main", "my string")
            he.save_binary("main", "main_binary", b"my string")
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            he.save_source_string("main", src, is_package=True)

        with self._directory_importer(filename) as (_, dir_importer):
            m = dir_importer.import_module("main")
            self.assertEqual(m.t, "my string")
            self.assertEqual(m.b, b"my string")

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path.
        """
        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_binary("string_module", "my_string", b"my string")
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            e.save_source_string("main", src, is_package=True)

        with self._directory_importer(filename) as (_, dir_importer):
            m = dir_importer.import_module("main")
            self.assertEqual(m.s, "my string")

    def test_scriptobject_failure_message(self):
        """
        Test basic saving and loading of a ScriptModule in a directory.
        Currently not supported.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        with self.assertRaisesRegex(
            RuntimeError,
            "Loading ScriptObjects from a PackageImporter created from a "
            "directory is not supported. Use a package archive file instead.",
        ):
            with self._directory_importer(filename) as (_, dir_importer):
                dir_importer.load_pickle("res", "mod.pkl")


# Run the tests when this file is executed directly.
if __name__ == "__main__":
    run_tests()

# --- scraped page footer (not part of the test module) ---
# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.