import contextlib
import itertools
import tempfile
import time
import traceback
import unittest.mock
import warnings
from datetime import datetime
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import Request, urlopen

import pytest

from torchvision import datasets
from torchvision.datasets.utils import _get_redirect_url, USER_AGENT
def limit_requests_per_time(min_secs_between_requests=2.0):
    """Decorator factory that rate-limits a ``urlopen``-like callable per host.

    The wrapped function's first argument may be a URL string or a
    ``urllib.request.Request``. Requests are throttled per network location
    (netloc): if the previous request to the same host happened less than
    ``min_secs_between_requests`` seconds ago, the wrapper sleeps for the
    remaining time before dispatching.

    Args:
        min_secs_between_requests (float): minimum delay in seconds between
            two consecutive requests to the same host.

    Returns:
        A decorator for functions taking a URL or ``Request`` first argument.
    """
    # Maps netloc -> timestamp of the most recent request sent to that host.
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            # The first argument may be a plain URL string or a Request object.
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)
def resolve_redirects(max_hops=3):
    """Decorator factory that resolves HTTP redirects before dispatching.

    The wrapped function's first argument may be a URL string or a
    ``urllib.request.Request``. The final URL is resolved with
    ``_get_redirect_url`` (following at most ``max_hops`` redirects); if it
    differs from the initial one, a warning is emitted and the request is
    re-issued against the final URL, preserving the original request's
    attributes.

    Args:
        max_hops (int): maximum number of redirects to follow.

    Returns:
        A decorator for functions taking a URL or ``Request`` first argument.
    """

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            initial_url = request.full_url if isinstance(request, Request) else request
            url = _get_redirect_url(initial_url, max_hops=max_hops)

            if url == initial_url:
                # No redirect happened; dispatch the original request as-is.
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

            if not isinstance(request, Request):
                return fn(url, *args, **kwargs)

            # Rebuild the Request against the final URL, carrying over its
            # user-visible attributes.
            request_attrs = {
                attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
            }

            # the 'method' attribute does only exist if the request was created with it
            if hasattr(request, "method"):
                request_attrs["method"] = request.method

            return fn(Request(url, **request_attrs), *args, **kwargs)

        return inner_wrapper

    return outer_wrapper


urlopen = resolve_redirects()(urlopen)
@contextlib.contextmanager
def log_download_attempts(
    urls,
    *,
    dataset_module,
):
    """Context manager that records download attempts made by a dataset.

    Patches the download helpers in ``torchvision.datasets.<dataset_module>``
    and ``torchvision.datasets.utils`` so no real downloads happen, and on
    exit appends every requested URL (including synthesized Google Drive file
    URLs) to ``urls``.

    Args:
        urls (list): output list the attempted URLs are appended to.
        dataset_module (str): name of the ``torchvision.datasets`` submodule
            the dataset under test lives in.
    """

    def maybe_add_mock(*, module, name, stack, lst=None):
        # Patch `name` in `module` if present, optionally collecting the mock.
        patcher = unittest.mock.patch(f"torchvision.datasets.{module}.{name}")

        try:
            mock = stack.enter_context(patcher)
        except AttributeError:
            # The helper is not referenced from this module; nothing to patch.
            return

        if lst is not None:
            lst.append(mock)

    with contextlib.ExitStack() as stack:
        download_url_mocks = []
        download_file_from_google_drive_mocks = []
        for module in [dataset_module, "utils"]:
            maybe_add_mock(module=module, name="download_url", stack=stack, lst=download_url_mocks)
            maybe_add_mock(
                module=module,
                name="download_file_from_google_drive",
                stack=stack,
                lst=download_file_from_google_drive_mocks,
            )
            maybe_add_mock(module=module, name="extract_archive", stack=stack)

        try:
            yield
        finally:
            # Harvest every URL the patched helpers were called with.
            for download_url_mock in download_url_mocks:
                for args, kwargs in download_url_mock.call_args_list:
                    urls.append(args[0] if args else kwargs["url"])

            for download_file_from_google_drive_mock in download_file_from_google_drive_mocks:
                for args, kwargs in download_file_from_google_drive_mock.call_args_list:
                    file_id = args[0] if args else kwargs["file_id"]
                    urls.append(f"https://drive.google.com/file/d/{file_id}")
def retry(fn, times=1, wait=5.0):
    """Call ``fn`` and retry on ``AssertionError`` up to ``times`` extra times.

    Args:
        fn: zero-argument callable whose return value is passed through.
        times (int): number of additional attempts after the first failure.
        wait (float): seconds to sleep after each failed attempt.

    Returns:
        Whatever ``fn`` returns on the first successful attempt.

    Raises:
        AssertionError: if every attempt failed; the message contains the
            full traceback of each attempt followed by a summary line.
    """
    tbs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            tbs.append("".join(traceback.format_exception(type(error), error, error.__traceback__)))
            time.sleep(wait)

    # All attempts failed; report every traceback plus a summary.
    raise AssertionError(
        "\n".join(
            (
                *[f"{'_' * 40} {idx:2d} {'_' * 40}\n\n{tb}" for idx, tb in enumerate(tbs, 1)],
                f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time. "
                # fixed duplicated word: "the the" -> "the"
                f"You can find the full tracebacks above.",
            )
        )
    )
@contextlib.contextmanager
def assert_server_response_ok():
    """Context manager translating network failures into ``AssertionError``.

    ``HTTPError`` becomes an assertion naming the status code and reason;
    ``URLError`` is reported as either an SSL problem or a timeout depending
    on the error text; ``RecursionError`` (e.g. from pathological redirect
    loops) is re-raised with its message.
    """
    try:
        yield
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except URLError as error:
        raise AssertionError(
            "Connection not possible due to SSL." if "SSL" in str(error) else "The request timed out."
        ) from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error
def assert_url_is_accessible(url, timeout=5.0):
    """Assert that a HEAD request to ``url`` succeeds within ``timeout`` seconds.

    Args:
        url (str): URL to probe.
        timeout (float): socket timeout in seconds.

    Raises:
        AssertionError: if the server answers with an error or the request
            fails (translated by ``assert_server_response_ok``).
    """
    # HEAD avoids downloading the payload; USER_AGENT makes the request look
    # like a regular browser, which some servers require.
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)
def collect_urls(dataset_cls, *args, **kwargs):
    """Instantiate ``dataset_cls`` and collect the URLs it tried to download.

    Construction errors are deliberately suppressed: only the attempted URLs
    matter here, not whether the dataset initialized successfully.

    Args:
        dataset_cls: a ``torchvision.datasets`` class.
        *args, **kwargs: forwarded to the dataset constructor.

    Returns:
        list of ``(url, id)`` tuples where ``id`` is a human-readable
        parametrization id of the form ``"<DatasetName>, <url>"``.
    """
    urls = []
    with contextlib.suppress(Exception), log_download_attempts(
        urls, dataset_module=dataset_cls.__module__.split(".")[-1]
    ):
        dataset_cls(*args, **kwargs)

    return [(url, f"{dataset_cls.__name__}, {url}") for url in urls]
# This is a workaround since fixtures, such as the built-in tmp_dir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it when all download tests are run.
ROOT = tempfile.mkdtemp()
@pytest.fixture(scope="module", autouse=True)
187
return itertools.chain.from_iterable(
196
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
202
return collect_urls(datasets.Caltech101, ROOT, download=True)
206
return collect_urls(datasets.Caltech256, ROOT, download=True)
210
return collect_urls(datasets.CIFAR10, ROOT, download=True)
214
return collect_urls(datasets.CIFAR100, ROOT, download=True)
218
# TODO: Also test the "2007-test" key
219
return itertools.chain.from_iterable(
221
collect_urls(datasets.VOCSegmentation, ROOT, year=year, download=True)
222
for year in ("2007", "2008", "2009", "2010", "2011", "2012")
228
with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
229
return collect_urls(datasets.MNIST, ROOT, download=True)
233
return collect_urls(datasets.FashionMNIST, ROOT, download=True)
237
return collect_urls(datasets.KMNIST, ROOT, download=True)
241
# the 'split' argument can be any valid one, since everything is downloaded anyway
242
return collect_urls(datasets.EMNIST, ROOT, split="byclass", download=True)
246
return itertools.chain.from_iterable(
247
[collect_urls(datasets.QMNIST, ROOT, what=what, download=True) for what in ("train", "test", "nist")]
252
return collect_urls(datasets.MovingMNIST, ROOT, download=True)
256
return itertools.chain.from_iterable(
257
[collect_urls(datasets.Omniglot, ROOT, background=background, download=True) for background in (True, False)]
262
return itertools.chain.from_iterable(
264
collect_urls(datasets.PhotoTour, ROOT, name=name, download=True)
265
# The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
266
# requests timeout from within CI. They are disabled until this is resolved.
267
for name in ("notredame", "yosemite", "liberty") # "notredame_harris", "yosemite_harris", "liberty_harris"
273
return collect_urls(datasets.SBDataset, ROOT, download=True)
277
return collect_urls(datasets.SBU, ROOT, download=True)
281
return collect_urls(datasets.SEMEION, ROOT, download=True)
285
return collect_urls(datasets.STL10, ROOT, download=True)
289
return itertools.chain.from_iterable(
290
[collect_urls(datasets.SVHN, ROOT, split=split, download=True) for split in ("train", "test", "extra")]
295
return itertools.chain.from_iterable(
296
[collect_urls(datasets.USPS, ROOT, train=train, download=True) for train in (True, False)]
301
return collect_urls(datasets.CelebA, ROOT, download=True)
305
return collect_urls(datasets.WIDERFace, ROOT, download=True)
309
return itertools.chain.from_iterable(
313
path.join(ROOT, f"Kinetics{num_classes}"),
315
num_classes=num_classes,
319
for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
325
return itertools.chain.from_iterable(
326
[collect_urls(datasets.Kitti, ROOT, train=train, download=True) for train in (True, False)]
330
def url_parametrization(*dataset_urls_and_ids_fns):
    """Build a ``pytest.mark.parametrize`` decorator from URL collectors.

    Each argument is a zero-argument callable returning ``(url, id)`` pairs
    (e.g. the functions built on top of ``collect_urls``). Duplicates are
    removed and the parameters are sorted for a stable test ordering.
    """
    return pytest.mark.parametrize(
        "url",
        [
            pytest.param(url, id=id)
            for dataset_urls_and_ids_fn in dataset_urls_and_ids_fns
            for url, id in sorted(set(dataset_urls_and_ids_fn()))
        ],
    )
# The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
367
def test_url_is_accessible(url):
    """
    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_not_accessible`` and link an issue detailing the problem.
    """
    # Servers are flaky; retry once before declaring the URL inaccessible.
    retry(lambda: assert_url_is_accessible(url))
# TODO: if e.g. caltech101 starts failing, remove the pytest.mark.parametrize below and use
# @url_parametrization(caltech101)
@pytest.mark.parametrize("url", ("http://url_that_doesnt_exist.com",))  # here until we actually have a failing dataset
def test_url_is_not_accessible(url):
    """
    As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
    beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
    come back up, or if we need to remove the download functionality of the dataset for good.

    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_accessible``.
    """
    # NOTE(review): reconstructed from a garbled source — the original body is
    # presumed to assert that the accessibility check fails; confirm upstream.
    with pytest.raises(AssertionError):
        assert_url_is_accessible(url)