pytorch

Форк
0
/
test_hub.py 
309 строк · 12.0 Кб
1
# Owner(s): ["module: hub"]
2

3
import os
4
import tempfile
5
import unittest
6
import warnings
7
from unittest.mock import patch
8

9
import torch
10
import torch.hub as hub
11
from torch.testing._internal.common_utils import IS_SANDCASTLE, retry, TestCase
12

13

14
def sum_of_state_dict(state_dict):
    """Return the total of ``v.sum()`` over every value in *state_dict*.

    Args:
        state_dict: A mapping whose values expose a ``sum()`` method
            (the tests pass ``torch`` state dicts).

    Returns:
        The accumulated sum; the integer ``0`` when *state_dict* is empty.
    """
    # The builtin ``sum`` starts from 0 and adds left to right, exactly like
    # the explicit accumulator loop it replaces.
    return sum(v.sum() for v in state_dict.values())
19

20

21
# Expected value of sum_of_state_dict() for the example checkpoints/models
# that the tests below download.
SUM_OF_HUB_EXAMPLE = 431080
# Release asset fetched directly by the download_url_to_file and
# load_state_dict_from_url tests.
TORCHHUB_EXAMPLE_RELEASE_URL = (
    "https://github.com/ailzhang/torchhub_example/releases/download/0.1/mnist_init_ones"
)
25

26

27
@unittest.skipIf(IS_SANDCASTLE, "Sandcastle cannot ping external")
class TestHub(TestCase):
    """Network-backed tests for ``torch.hub``.

    Every test runs against a fresh temporary hub directory (see ``setUp``),
    so downloaded repos, checkpoints and the ``trusted_list`` file created by
    one test are not visible to the next one.
    """

    def setUp(self):
        """Point ``torch.hub`` at a new temporary directory for this test."""
        super().setUp()
        self.previous_hub_dir = torch.hub.get_dir()
        self.tmpdir = tempfile.TemporaryDirectory("hub_dir")
        torch.hub.set_dir(self.tmpdir.name)
        self.trusted_list_path = os.path.join(torch.hub.get_dir(), "trusted_list")

    def tearDown(self):
        """Restore the previous hub directory and delete the temporary one."""
        super().tearDown()
        torch.hub.set_dir(self.previous_hub_dir)  # probably not needed, but can't hurt
        self.tmpdir.cleanup()

    def _assert_trusted_list_is_empty(self):
        """Assert that the ``trusted_list`` file exists and contains no lines."""
        with open(self.trusted_list_path) as f:
            self.assertFalse(f.readlines(), "trusted_list should be empty")

    def _assert_in_trusted_list(self, line):
        """Assert that *line* equals one of the stripped ``trusted_list`` lines."""
        with open(self.trusted_list_path) as f:
            entries = [entry.strip() for entry in f]
        self.assertIn(line, entries)

    @retry(Exception, tries=3)
    def test_load_from_github(self):
        """Load the example model with an explicit ``source="github"``."""
        hub_model = hub.load(
            "ailzhang/torchhub_example",
            "mnist",
            source="github",
            pretrained=True,
            verbose=False,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_from_local_dir(self):
        """Fetch the repo into the cache, then load it with ``source="local"``."""
        local_dir = hub._get_cache_or_reload(
            "ailzhang/torchhub_example",
            force_reload=False,
            trust_repo=True,
            calling_fn=None,
        )
        hub_model = hub.load(
            local_dir, "mnist", source="local", pretrained=True, verbose=False
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_from_branch(self):
        """Load from a branch whose name contains a slash."""
        hub_model = hub.load(
            "ailzhang/torchhub_example:ci/test_slash",
            "mnist",
            pretrained=True,
            verbose=False,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_get_set_dir(self):
        """Check that ``set_dir`` is honoured by ``get_dir`` and by ``load``."""
        previous_hub_dir = torch.hub.get_dir()
        with tempfile.TemporaryDirectory("hub_dir") as tmpdir:
            torch.hub.set_dir(tmpdir)
            self.assertEqual(torch.hub.get_dir(), tmpdir)
            self.assertNotEqual(previous_hub_dir, tmpdir)

            hub_model = hub.load(
                "ailzhang/torchhub_example", "mnist", pretrained=True, verbose=False
            )
            self.assertEqual(
                sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
            )
            # The repo must have been downloaded into the directory we just set.
            self.assertTrue(
                os.path.exists(
                    os.path.join(tmpdir, "ailzhang_torchhub_example_master")
                )
            )

        # Test that set_dir properly calls expanduser()
        # non-regression test for https://github.com/pytorch/pytorch/issues/69761
        new_dir = os.path.join("~", "hub")
        torch.hub.set_dir(new_dir)
        self.assertEqual(torch.hub.get_dir(), os.path.expanduser(new_dir))

    @retry(Exception, tries=3)
    def test_list_entrypoints(self):
        """``hub.list`` must report the ``mnist`` entrypoint."""
        entry_lists = hub.list("ailzhang/torchhub_example", trust_repo=True)
        self.assertObjectIn("mnist", entry_lists)

    @retry(Exception, tries=3)
    def test_download_url_to_file(self):
        """Download a checkpoint to a file and check its content and mode bits."""
        with tempfile.TemporaryDirectory() as tmpdir:
            f = os.path.join(tmpdir, "temp")
            hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, f, progress=False)
            loaded_state = torch.load(f)
            self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
            # Check that the downloaded file has default file permissions:
            # create a reference file the ordinary way and compare mode bits.
            f_ref = os.path.join(tmpdir, "reference")
            with open(f_ref, "w"):
                pass
            expected_permissions = oct(os.stat(f_ref).st_mode & 0o777)
            actual_permissions = oct(os.stat(f).st_mode & 0o777)
            self.assertEqual(actual_permissions, expected_permissions)

    @retry(Exception, tries=3)
    def test_load_state_dict_from_url(self):
        """Exercise ``load_state_dict_from_url`` with its optional arguments."""
        loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL)
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

        # with name
        file_name = "the_file_name"
        loaded_state = hub.load_state_dict_from_url(
            TORCHHUB_EXAMPLE_RELEASE_URL, file_name=file_name
        )
        expected_file_path = os.path.join(torch.hub.get_dir(), "checkpoints", file_name)
        self.assertTrue(os.path.exists(expected_file_path))
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

        # with safe weight_only
        loaded_state = hub.load_state_dict_from_url(
            TORCHHUB_EXAMPLE_RELEASE_URL, weights_only=True
        )
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_legacy_zip_checkpoint(self):
        """Loading the ``mnist_zip`` entrypoint must emit a deprecation warning."""
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            hub_model = hub.load(
                "ailzhang/torchhub_example", "mnist_zip", pretrained=True, verbose=False
            )
            self.assertEqual(
                sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
            )
            self.assertTrue(
                any(
                    "will be deprecated in favor of default zipfile" in str(w)
                    for w in ws
                )
            )

    # Test the default zipfile serialization format produced by >=1.6 release.
    @retry(Exception, tries=3)
    def test_load_zip_1_6_checkpoint(self):
        hub_model = hub.load(
            "ailzhang/torchhub_example",
            "mnist_zip_1_6",
            pretrained=True,
            verbose=False,
            trust_repo=True,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_hub_parse_repo_info(self):
        """Check ``_parse_repo_info`` with and without an explicit branch."""
        # If the branch is specified we just parse the input and return
        self.assertEqual(torch.hub._parse_repo_info("a/b:c"), ("a", "b", "c"))
        # For torchvision, the default branch is main
        self.assertEqual(
            torch.hub._parse_repo_info("pytorch/vision"), ("pytorch", "vision", "main")
        )
        # For the torchhub_example repo, the default branch is still master
        self.assertEqual(
            torch.hub._parse_repo_info("ailzhang/torchhub_example"),
            ("ailzhang", "torchhub_example", "master"),
        )

    @retry(Exception, tries=3)
    def test_load_commit_from_forked_repo(self):
        """A ref that cannot be resolved in the repo must raise ``ValueError``."""
        with self.assertRaisesRegex(ValueError, "If it's a commit from a forked repo"):
            torch.hub.load("pytorch/vision:4e2c216", "resnet18")

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="")
    def test_trust_repo_false_emptystring(self, patched_input):
        """``trust_repo=False`` with an empty answer: rejected on every call."""
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        # A second attempt must prompt again and still leave the list empty.
        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="no")
    def test_trust_repo_false_no(self, patched_input):
        """``trust_repo=False`` answered with "no": rejected on every call."""
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        # A second attempt must prompt again and still leave the list empty.
        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="y")
    def test_trusted_repo_false_yes(self, patched_input):
        """``trust_repo=False`` answered with "y": repo is recorded as trusted."""
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
        self._assert_in_trusted_list("ailzhang_torchhub_example")
        patched_input.assert_called_once()

        # Loading a second time with "check", we don't ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        patched_input.assert_not_called()

        # Loading again with False, we still ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="no")
    def test_trust_repo_check_no(self, patched_input):
        """``trust_repo="check"`` answered with "no": rejected on every call."""
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
            )
        # Same post-condition as in test_trust_repo_false_no: a rejected
        # prompt must not add the repo to the trusted list.
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="y")
    def test_trust_repo_check_yes(self, patched_input):
        """``trust_repo="check"`` answered with "y": prompts only the first time."""
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        self._assert_in_trusted_list("ailzhang_torchhub_example")
        patched_input.assert_called_once()

        # Loading a second time with "check", we don't ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        patched_input.assert_not_called()

    @retry(Exception, tries=3)
    def test_trust_repo_true(self):
        """``trust_repo=True`` records the repo in the trusted list."""
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
        self._assert_in_trusted_list("ailzhang_torchhub_example")

    @retry(Exception, tries=3)
    def test_trust_repo_builtin_trusted_owners(self):
        """Loading ``pytorch/vision`` with "check" leaves the trusted list empty."""
        # No input() patch here: the load is expected to succeed without a
        # prompt (presumably because the owner is trusted by default).
        torch.hub.load("pytorch/vision", "resnet18", trust_repo="check")
        self._assert_trusted_list_is_empty()

    @retry(Exception, tries=3)
    def test_trust_repo_none(self):
        """``trust_repo=None`` only warns and does not touch the trusted list."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=None
            )
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[-1].category, UserWarning))
            self.assertIn(
                "You are about to download and run code from an untrusted repository",
                str(w[-1].message),
            )

        self._assert_trusted_list_is_empty()

    @retry(Exception, tries=3)
    def test_trust_repo_legacy(self):
        """A repo that was already downloaded is trusted without a prompt."""
        # We first download a repo and then delete the allowlist file
        # Then we check that the repo is indeed trusted without a prompt,
        # because it was already downloaded in the past.
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
        os.remove(self.trusted_list_path)

        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")

        self._assert_trusted_list_is_empty()
310

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.