7
from unittest.mock import patch
10
import torch.hub as hub
11
from torch.testing._internal.common_utils import IS_SANDCASTLE, retry, TestCase
14
def sum_of_state_dict(state_dict):
16
for v in state_dict.values():
21
SUM_OF_HUB_EXAMPLE = 431080
22
TORCHHUB_EXAMPLE_RELEASE_URL = (
23
"https://github.com/ailzhang/torchhub_example/releases/download/0.1/mnist_init_ones"
27
@unittest.skipIf(IS_SANDCASTLE, "Sandcastle cannot ping external")
28
class TestHub(TestCase):
31
self.previous_hub_dir = torch.hub.get_dir()
32
self.tmpdir = tempfile.TemporaryDirectory("hub_dir")
33
torch.hub.set_dir(self.tmpdir.name)
34
self.trusted_list_path = os.path.join(torch.hub.get_dir(), "trusted_list")
38
torch.hub.set_dir(self.previous_hub_dir)
41
def _assert_trusted_list_is_empty(self):
42
with open(self.trusted_list_path) as f:
43
assert not f.readlines()
45
def _assert_in_trusted_list(self, line):
46
with open(self.trusted_list_path) as f:
47
assert line in (l.strip() for l in f)
49
@retry(Exception, tries=3)
50
def test_load_from_github(self):
52
"ailzhang/torchhub_example",
58
self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)
60
@retry(Exception, tries=3)
61
def test_load_from_local_dir(self):
62
local_dir = hub._get_cache_or_reload(
63
"ailzhang/torchhub_example",
69
local_dir, "mnist", source="local", pretrained=True, verbose=False
71
self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)
73
@retry(Exception, tries=3)
74
def test_load_from_branch(self):
76
"ailzhang/torchhub_example:ci/test_slash",
81
self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)
83
@retry(Exception, tries=3)
84
def test_get_set_dir(self):
85
previous_hub_dir = torch.hub.get_dir()
86
with tempfile.TemporaryDirectory("hub_dir") as tmpdir:
87
torch.hub.set_dir(tmpdir)
88
self.assertEqual(torch.hub.get_dir(), tmpdir)
89
self.assertNotEqual(previous_hub_dir, tmpdir)
92
"ailzhang/torchhub_example", "mnist", pretrained=True, verbose=False
95
sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
97
assert os.path.exists(
98
os.path.join(tmpdir, "ailzhang_torchhub_example_master")
103
new_dir = os.path.join("~", "hub")
104
torch.hub.set_dir(new_dir)
105
self.assertEqual(torch.hub.get_dir(), os.path.expanduser(new_dir))
107
@retry(Exception, tries=3)
108
def test_list_entrypoints(self):
109
entry_lists = hub.list("ailzhang/torchhub_example", trust_repo=True)
110
self.assertObjectIn("mnist", entry_lists)
112
@retry(Exception, tries=3)
113
def test_download_url_to_file(self):
114
with tempfile.TemporaryDirectory() as tmpdir:
115
f = os.path.join(tmpdir, "temp")
116
hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, f, progress=False)
117
loaded_state = torch.load(f)
118
self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
120
f_ref = os.path.join(tmpdir, "reference")
121
open(f_ref, "w").close()
122
expected_permissions = oct(os.stat(f_ref).st_mode & 0o777)
123
actual_permissions = oct(os.stat(f).st_mode & 0o777)
124
assert actual_permissions == expected_permissions
126
@retry(Exception, tries=3)
127
def test_load_state_dict_from_url(self):
128
loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL)
129
self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
132
file_name = "the_file_name"
133
loaded_state = hub.load_state_dict_from_url(
134
TORCHHUB_EXAMPLE_RELEASE_URL, file_name=file_name
136
expected_file_path = os.path.join(torch.hub.get_dir(), "checkpoints", file_name)
137
self.assertTrue(os.path.exists(expected_file_path))
138
self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
141
loaded_state = hub.load_state_dict_from_url(
142
TORCHHUB_EXAMPLE_RELEASE_URL, weights_only=True
144
self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
146
@retry(Exception, tries=3)
147
def test_load_legacy_zip_checkpoint(self):
148
with warnings.catch_warnings(record=True) as ws:
149
warnings.simplefilter("always")
150
hub_model = hub.load(
151
"ailzhang/torchhub_example", "mnist_zip", pretrained=True, verbose=False
154
sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
157
"will be deprecated in favor of default zipfile" in str(w) for w in ws
161
@retry(Exception, tries=3)
162
def test_load_zip_1_6_checkpoint(self):
163
hub_model = hub.load(
164
"ailzhang/torchhub_example",
170
self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)
172
@retry(Exception, tries=3)
173
def test_hub_parse_repo_info(self):
175
self.assertEqual(torch.hub._parse_repo_info("a/b:c"), ("a", "b", "c"))
178
torch.hub._parse_repo_info("pytorch/vision"), ("pytorch", "vision", "main")
182
torch.hub._parse_repo_info("ailzhang/torchhub_example"),
183
("ailzhang", "torchhub_example", "master"),
186
@retry(Exception, tries=3)
187
def test_load_commit_from_forked_repo(self):
188
with self.assertRaisesRegex(ValueError, "If it's a commit from a forked repo"):
189
torch.hub.load("pytorch/vision:4e2c216", "resnet18")
191
@retry(Exception, tries=3)
192
@patch("builtins.input", return_value="")
193
def test_trust_repo_false_emptystring(self, patched_input):
194
with self.assertRaisesRegex(Exception, "Untrusted repository."):
196
"ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
198
self._assert_trusted_list_is_empty()
199
patched_input.assert_called_once()
201
patched_input.reset_mock()
202
with self.assertRaisesRegex(Exception, "Untrusted repository."):
204
"ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
206
self._assert_trusted_list_is_empty()
207
patched_input.assert_called_once()
209
@retry(Exception, tries=3)
210
@patch("builtins.input", return_value="no")
211
def test_trust_repo_false_no(self, patched_input):
212
with self.assertRaisesRegex(Exception, "Untrusted repository."):
214
"ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
216
self._assert_trusted_list_is_empty()
217
patched_input.assert_called_once()
219
patched_input.reset_mock()
220
with self.assertRaisesRegex(Exception, "Untrusted repository."):
222
"ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
224
self._assert_trusted_list_is_empty()
225
patched_input.assert_called_once()
227
@retry(Exception, tries=3)
228
@patch("builtins.input", return_value="y")
229
def test_trusted_repo_false_yes(self, patched_input):
230
torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
231
self._assert_in_trusted_list("ailzhang_torchhub_example")
232
patched_input.assert_called_once()
235
patched_input.reset_mock()
236
torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
237
patched_input.assert_not_called()
240
patched_input.reset_mock()
241
torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
242
patched_input.assert_called_once()
244
@retry(Exception, tries=3)
245
@patch("builtins.input", return_value="no")
246
def test_trust_repo_check_no(self, patched_input):
247
with self.assertRaisesRegex(Exception, "Untrusted repository."):
249
"ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
251
self._assert_trusted_list_is_empty()
252
patched_input.assert_called_once()
254
patched_input.reset_mock()
255
with self.assertRaisesRegex(Exception, "Untrusted repository."):
257
"ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
259
patched_input.assert_called_once()
261
@retry(Exception, tries=3)
262
@patch("builtins.input", return_value="y")
263
def test_trust_repo_check_yes(self, patched_input):
264
torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
265
self._assert_in_trusted_list("ailzhang_torchhub_example")
266
patched_input.assert_called_once()
269
patched_input.reset_mock()
270
torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
271
patched_input.assert_not_called()
273
@retry(Exception, tries=3)
274
def test_trust_repo_true(self):
275
torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
276
self._assert_in_trusted_list("ailzhang_torchhub_example")
278
@retry(Exception, tries=3)
279
def test_trust_repo_builtin_trusted_owners(self):
280
torch.hub.load("pytorch/vision", "resnet18", trust_repo="check")
281
self._assert_trusted_list_is_empty()
283
@retry(Exception, tries=3)
284
def test_trust_repo_none(self):
285
with warnings.catch_warnings(record=True) as w:
286
warnings.simplefilter("always")
288
"ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=None
291
assert issubclass(w[-1].category, UserWarning)
293
"You are about to download and run code from an untrusted repository"
294
in str(w[-1].message)
297
self._assert_trusted_list_is_empty()
299
@retry(Exception, tries=3)
300
def test_trust_repo_legacy(self):
304
torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
305
os.remove(self.trusted_list_path)
307
torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
309
self._assert_trusted_list_is_empty()