import asyncio
import sys
from pathlib import Path
from time import perf_counter
from urllib.parse import urlsplit

import aiofiles
import aiohttp
from torchvision import models
from tqdm.asyncio import tqdm


async def main(download_root):
    """Download every pretrained-weight file listed by torchvision into ``download_root``.

    Args:
        download_root: destination directory (a ``pathlib.Path``); it is
            created, including any missing parents, if it does not exist.
    """
    download_root.mkdir(parents=True, exist_ok=True)

    # Collect the URL of every weight entry of every registered model. A set
    # is used so a URL shared by several entries is only fetched once.
    urls = {weight.url for name in models.list_models() for weight in models.get_model_weights(name)}

    # total=None disables aiohttp's overall time limit per request so large
    # checkpoint files are not cut off mid-transfer.
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
        # tqdm.gather runs the downloads concurrently (like asyncio.gather)
        # while displaying a progress bar.
        await tqdm.gather(*[download(download_root, session, url) for url in urls])


async def download(download_root, session, url):
    """Stream the file at ``url`` into ``download_root``.

    The file is saved under the last path component of ``url``.

    Args:
        download_root: destination directory (a ``pathlib.Path``).
        session: open ``aiohttp.ClientSession`` used for the request.
        url: URL of the file to fetch.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP error
            status, so an error page is never written to disk as a weight file.
    """
    file_name = Path(urlsplit(url).path).name

    # The context manager guarantees the connection is released even if the
    # status check or a write fails part-way through.
    async with session.get(url, params=dict(source="ci")) as response:
        response.raise_for_status()
        async with aiofiles.open(download_root / file_name, "wb") as f:
            # Write each chunk as soon as it arrives instead of buffering the
            # whole file in memory.
            async for data in response.content.iter_any():
                await f.write(data)


if __name__ == "__main__":
34
(Path(sys.argv[1]) if len(sys.argv) > 1 else Path("~/.cache/torch/hub/checkpoints")).expanduser().resolve()
36
print(f"Downloading model weights to {download_root}")
37
start = perf_counter()
38
asyncio.get_event_loop().run_until_complete(main(download_root))
40
minutes, seconds = divmod(stop - start, 60)
41
print(f"Download took {minutes:2.0f}m {seconds:2.0f}s")