4
from sys import version_info
5
from textwrap import dedent
6
from unittest import skipIf
8
from torch.package import PackageExporter, PackageImporter
9
from torch.testing._internal.common_utils import run_tests
13
from .common import PackageTestCase
16
from common import PackageTestCase
19
@skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
class TestResources(PackageTestCase):
    """Tests for access APIs for packaged resources."""

    @staticmethod
    def _new_buffer():
        """Return an empty in-memory binary buffer to export a package into."""
        # Local import: keeps this class independent of the file-level imports.
        from io import BytesIO

        return BytesIO()

    @staticmethod
    def _open_importer(buffer):
        """Rewind ``buffer`` and return a ``PackageImporter`` reading from it."""
        # Rewind first so the importer reads the archive from its start.
        buffer.seek(0)
        return PackageImporter(buffer)

    def test_resource_reader(self):
        """Test compliance with the get_resource_reader importlib API."""
        buffer = self._new_buffer()
        with PackageExporter(buffer) as pe:
            # Layout written below:
            #   one/        a.txt, b.txt, c.txt
            #   one/three/  d.txt, e.txt
            #   two/        f.txt, g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        importer = self._open_importer(buffer)

        reader_one = importer.get_resource_reader("one")
        # This test pins that asking for a filesystem path raises.
        with self.assertRaises(FileNotFoundError):
            reader_one.resource_path("a.txt")

        self.assertTrue(reader_one.is_resource("a.txt"))
        self.assertEqual(reader_one.open_resource("a.txt").getbuffer(), b"hello, a!")
        # "three" is a sub-package of "one": it is not a resource, but it is
        # still expected to be listed by contents().
        self.assertFalse(reader_one.is_resource("three"))
        reader_one_contents = list(reader_one.contents())
        self.assertSequenceEqual(
            reader_one_contents, ["a.txt", "b.txt", "c.txt", "three"]
        )

        reader_two = importer.get_resource_reader("two")
        self.assertTrue(reader_two.is_resource("f.txt"))
        self.assertEqual(reader_two.open_resource("f.txt").getbuffer(), b"hello, f!")
        reader_two_contents = list(reader_two.contents())
        self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"])

        reader_one_three = importer.get_resource_reader("one.three")
        self.assertTrue(reader_one_three.is_resource("d.txt"))
        self.assertEqual(
            reader_one_three.open_resource("d.txt").getbuffer(), b"hello, d!"
        )
        reader_one_three_contents = list(reader_one_three.contents())
        self.assertSequenceEqual(reader_one_three_contents, ["d.txt", "e.txt"])

        self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        # Source of the packaged module; it reads a resource from a sibling
        # package through importlib.resources.
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        buffer = self._new_buffer()
        with PackageExporter(buffer) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        importer = self._open_importer(buffer)
        self.assertEqual(
            importer.import_module("foo.bar").secret_message(), "my sekrit plays"
        )

    def test_importer_access(self):
        """Packaged code can load text and binary resources via the
        ``torch_package_importer`` module it imports.
        """
        buffer = self._new_buffer()
        with PackageExporter(buffer) as he:
            he.save_text("main", "main", "my string")
            he.save_binary("main", "main_binary", b"my string")
            # Module-level code of the packaged "main" package: it loads both
            # resources at import time and exposes them as ``t`` and ``b``.
            src = dedent(
                """\
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            he.save_source_string("main", src, is_package=True)

        hi = self._open_importer(buffer)
        m = hi.import_module("main")
        self.assertEqual(m.t, "my string")
        self.assertEqual(m.b, b"my string")

    def test_resource_access_by_path(self):
        """Tests that packaged code can use importlib.resources.path."""
        buffer = self._new_buffer()
        with PackageExporter(buffer) as he:
            he.save_binary("string_module", "my_string", b"my string")
            # Module-level code of the packaged "main" package: it opens the
            # resource through a path and exposes the decoded text as ``s``.
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            he.save_source_string("main", src, is_package=True)

        hi = self._open_importer(buffer)
        m = hi.import_module("main")
        self.assertEqual(m.s, "my string")
149
if __name__ == "__main__":