vision

Форк
0
/
download_model_urls.py 
41 строка · 1.3 Кб
1
import asyncio
2
import sys
3
from pathlib import Path
4
from time import perf_counter
5
from urllib.parse import urlsplit
6

7
import aiofiles
8
import aiohttp
9
from torchvision import models
10
from tqdm.asyncio import tqdm
11

12

13
async def main(download_root):
14
    download_root.mkdir(parents=True, exist_ok=True)
15
    urls = {weight.url for name in models.list_models() for weight in iter(models.get_model_weights(name))}
16

17
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
18
        await tqdm.gather(*[download(download_root, session, url) for url in urls])
19

20

21
async def download(download_root, session, url):
22
    response = await session.get(url, params=dict(source="ci"))
23

24
    assert response.ok
25

26
    file_name = Path(urlsplit(url).path).name
27
    async with aiofiles.open(download_root / file_name, "wb") as f:
28
        async for data in response.content.iter_any():
29
            await f.write(data)
30

31

32
if __name__ == "__main__":
33
    download_root = (
34
        (Path(sys.argv[1]) if len(sys.argv) > 1 else Path("~/.cache/torch/hub/checkpoints")).expanduser().resolve()
35
    )
36
    print(f"Downloading model weights to {download_root}")
37
    start = perf_counter()
38
    asyncio.get_event_loop().run_until_complete(main(download_root))
39
    stop = perf_counter()
40
    minutes, seconds = divmod(stop - start, 60)
41
    print(f"Download took {minutes:2.0f}m {seconds:2.0f}s")
42

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.