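# Smoke tests for the download URLs used by the torchvision datasets. The download and
# extraction helpers are mocked out below, so these tests only check that every URL is
# reachable via a HEAD request; no dataset is actually downloaded.
#
# A hypothetical invocation, assuming a torchvision checkout with this file under test/:
#
#   pytest test/test_datasets_download.py -k mnist
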
import contextlib
import itertools
import shutil
import tempfile
import time
import traceback
import unittest.mock
import warnings
from datetime import datetime
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import Request, urlopen

import pytest
from torchvision import datasets
from torchvision.datasets.utils import _get_redirect_url, USER_AGENT


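# Decorator factory: wraps a ``urlopen``-like callable so that successive requests to
# the same host (netloc) are at least ``min_secs_between_requests`` seconds apart.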
def limit_requests_per_time(min_secs_between_requests=2.0):
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)
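# Rebinding ``urlopen`` at module level means all requests issued below go through the
# rate-limited wrapper.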


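# Decorator factory: resolves up to ``max_hops`` redirects for the requested URL and,
# if the URL changed, re-issues the request against the final destination while
# preserving the original request attributes.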
def resolve_redirects(max_hops=3):
    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            initial_url = request.full_url if isinstance(request, Request) else request
            url = _get_redirect_url(initial_url, max_hops=max_hops)

            if url == initial_url:
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

            if not isinstance(request, Request):
                return fn(url, *args, **kwargs)

            request_attrs = {
                attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
            }
            # the 'method' attribute only exists if the request was created with it
            if hasattr(request, "method"):
                request_attrs["method"] = request.method

            return fn(Request(url, **request_attrs), *args, **kwargs)

        return inner_wrapper

    return outer_wrapper


urlopen = resolve_redirects()(urlopen)


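# Patches the download helpers of the given dataset module (and the shared ``utils``
# module) so that nothing is actually downloaded or extracted. The URLs and Google
# Drive file ids the dataset would have fetched are collected into ``urls``.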
@contextlib.contextmanager
def log_download_attempts(
    urls,
    *,
    dataset_module,
):
    def maybe_add_mock(*, module, name, stack, lst=None):
        patcher = unittest.mock.patch(f"torchvision.datasets.{module}.{name}")

        try:
            mock = stack.enter_context(patcher)
        except AttributeError:
            return

        if lst is not None:
            lst.append(mock)

    with contextlib.ExitStack() as stack:
        download_url_mocks = []
        download_file_from_google_drive_mocks = []
        for module in [dataset_module, "utils"]:
            maybe_add_mock(module=module, name="download_url", stack=stack, lst=download_url_mocks)
            maybe_add_mock(
                module=module,
                name="download_file_from_google_drive",
                stack=stack,
                lst=download_file_from_google_drive_mocks,
            )
            maybe_add_mock(module=module, name="extract_archive", stack=stack)

        try:
            yield
        finally:
            for download_url_mock in download_url_mocks:
                for args, kwargs in download_url_mock.call_args_list:
                    urls.append(args[0] if args else kwargs["url"])

            for download_file_from_google_drive_mock in download_file_from_google_drive_mocks:
                for args, kwargs in download_file_from_google_drive_mock.call_args_list:
                    file_id = args[0] if args else kwargs["file_id"]
                    urls.append(f"https://drive.google.com/file/d/{file_id}")


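# Calls ``fn`` up to ``times + 1`` times, swallowing ``AssertionError``s in between.
# Note the ``for``/``else``: the ``else`` block runs once the loop is exhausted, i.e.
# only if every attempt failed, since a successful attempt returns early.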
def retry(fn, times=1, wait=5.0):
    tbs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            tbs.append("".join(traceback.format_exception(type(error), error, error.__traceback__)))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    "\n",
                    *[f"{'_' * 40}  {idx:2d}  {'_' * 40}\n\n{tb}" for idx, tb in enumerate(tbs, 1)],
                    (
                        f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time. "
                        f"You can find the full tracebacks above."
                    ),
                )
            )
        )


@contextlib.contextmanager
def assert_server_response_ok():
    try:
        yield
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except URLError as error:
        raise AssertionError(
            "Connection not possible due to SSL." if "SSL" in str(error) else "The request timed out."
        ) from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error


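# A HEAD request is enough to check reachability and avoids downloading the (possibly
# large) payload.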
def assert_url_is_accessible(url, timeout=5.0):
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


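# Instantiates the dataset with the download helpers mocked out and harvests the URLs
# it would have accessed. Exceptions are suppressed: since the mocked helpers leave no
# files behind, the constructor is expected to fail once it looks for the (never
# downloaded) data.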
def collect_urls(dataset_cls, *args, **kwargs):
    urls = []
    with contextlib.suppress(Exception), log_download_attempts(
        urls, dataset_module=dataset_cls.__module__.split(".")[-1]
    ):
        dataset_cls(*args, **kwargs)

    return [(url, f"{dataset_cls.__name__}, {url}") for url in urls]


# This is a workaround since fixtures, such as the built-in tmpdir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it when all download tests are run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    yield ROOT
    shutil.rmtree(ROOT)


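# Each function below instantiates one dataset family with ``download=True`` under the
# mocks from ``log_download_attempts`` and returns ``(url, test id)`` pairs for the
# parametrization of ``test_url_is_accessible``.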
def places365():
    return itertools.chain.from_iterable(
        [
            collect_urls(
                datasets.Places365,
                ROOT,
                split=split,
                small=small,
                download=True,
            )
            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
        ]
    )


def caltech101():
    return collect_urls(datasets.Caltech101, ROOT, download=True)


def caltech256():
    return collect_urls(datasets.Caltech256, ROOT, download=True)


def cifar10():
    return collect_urls(datasets.CIFAR10, ROOT, download=True)


def cifar100():
    return collect_urls(datasets.CIFAR100, ROOT, download=True)


def voc():
    # TODO: Also test the "2007-test" key
    return itertools.chain.from_iterable(
        [
            collect_urls(datasets.VOCSegmentation, ROOT, year=year, download=True)
            for year in ("2007", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
        return collect_urls(datasets.MNIST, ROOT, download=True)


def fashion_mnist():
    return collect_urls(datasets.FashionMNIST, ROOT, download=True)


def kmnist():
    return collect_urls(datasets.KMNIST, ROOT, download=True)


def emnist():
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    return collect_urls(datasets.EMNIST, ROOT, split="byclass", download=True)


def qmnist():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.QMNIST, ROOT, what=what, download=True) for what in ("train", "test", "nist")]
    )


def moving_mnist():
    return collect_urls(datasets.MovingMNIST, ROOT, download=True)


def omniglot():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.Omniglot, ROOT, background=background, download=True) for background in (True, False)]
    )


def phototour():
    return itertools.chain.from_iterable(
        [
            collect_urls(datasets.PhotoTour, ROOT, name=name, download=True)
            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
            # requests time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_urls(datasets.SBDataset, ROOT, download=True)


def sbu():
    return collect_urls(datasets.SBU, ROOT, download=True)


def semeion():
    return collect_urls(datasets.SEMEION, ROOT, download=True)


def stl10():
    return collect_urls(datasets.STL10, ROOT, download=True)


def svhn():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.SVHN, ROOT, split=split, download=True) for split in ("train", "test", "extra")]
    )


def usps():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.USPS, ROOT, train=train, download=True) for train in (True, False)]
    )


def celeba():
    return collect_urls(datasets.CelebA, ROOT, download=True)


def widerface():
    return collect_urls(datasets.WIDERFace, ROOT, download=True)


def kinetics():
    return itertools.chain.from_iterable(
        [
            collect_urls(
                datasets.Kinetics,
                path.join(ROOT, f"Kinetics{num_classes}"),
                frames_per_clip=1,
                num_classes=num_classes,
                split=split,
                download=True,
            )
            for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
        ]
    )


def kitti():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.Kitti, ROOT, train=train, download=True) for train in (True, False)]
    )


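# Builds a ``pytest.mark.parametrize`` decorator from the collector functions above;
# ``sorted(set(...))`` deduplicates and orders the collected pairs for stable test ids.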
def url_parametrization(*dataset_urls_and_ids_fns):
    return pytest.mark.parametrize(
        "url",
        [
            pytest.param(url, id=id)
            for dataset_urls_and_ids_fn in dataset_urls_and_ids_fns
            for url, id in sorted(set(dataset_urls_and_ids_fn()))
        ],
    )


@url_parametrization(
    caltech101,
    caltech256,
    cifar10,
    cifar100,
    # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
    # voc,
    mnist,
    fashion_mnist,
    kmnist,
    emnist,
    qmnist,
    omniglot,
    phototour,
    sbdataset,
    semeion,
    stl10,
    svhn,
    usps,
    celeba,
    widerface,
    kinetics,
    kitti,
    places365,
    sbu,
)
def test_url_is_accessible(url):
    """
    If you see this test failing, find the offending dataset in the parametrization, move it to
    ``test_url_is_not_accessible``, and link an issue detailing the problem.
    """
    retry(lambda: assert_url_is_accessible(url))


# TODO: if e.g. caltech101 starts failing, remove the pytest.mark.parametrize below and use
# @url_parametrization(caltech101)
@pytest.mark.parametrize("url", ("http://url_that_doesnt_exist.com",))  # here until we actually have a failing dataset
@pytest.mark.xfail
def test_url_is_not_accessible(url):
    """
    As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
    beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
    come back up, or if we need to remove the download functionality of the dataset for good.

    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_accessible``.
    """
    assert_url_is_accessible(url)
