paddlenlp

Форк
0
123 строки · 3.7 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
import shutil
17
import time
18

19
import paddle
20
import requests
21
from ppfleetx.utils.log import logger
22
from tqdm import tqdm
23

24
DOWNLOAD_RETRY_LIMIT = 3
25

26

27
def is_url(path):
28
    """
29
    Whether path is URL.
30
    Args:
31
        path (string): URL string or not.
32
    """
33
    return path.startswith("http://") or path.startswith("https://")
34

35

36
def _map_path(url, root_dir):
37
    # parse path after download under root_dir
38
    fname = os.path.split(url)[-1]
39
    fpath = fname
40
    return os.path.join(root_dir, fpath)
41

42

43
def cached_path(url_or_path, cache_dir=None):
44
    if cache_dir is None:
45
        cache_dir = "~/.cache/ppfleetx/"
46

47
    cache_dir = os.path.expanduser(cache_dir)
48

49
    if not os.path.exists(cache_dir):
50
        os.makedirs(cache_dir, exist_ok=True)
51

52
    if is_url(url_or_path):
53
        path = _map_path(url_or_path, cache_dir)
54
        url = url_or_path
55
    else:
56
        path = url_or_path
57
        url = None
58

59
    if os.path.exists(path):
60
        logger.info(f"Found {os.path.split(path)[-1]} in cache_dir: {cache_dir}.")
61
        return path
62

63
    download(url, path)
64
    return path
65

66

67
def _download(url, fullname):
68
    """
69
    Download from url, save to path.
70
    url (str): download url
71
    path (str): download to given path
72
    """
73
    retry_cnt = 0
74

75
    while not os.path.exists(fullname):
76
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
77
            retry_cnt += 1
78
        else:
79
            raise RuntimeError("Download from {} failed. " "Retry limit reached".format(url))
80

81
        logger.info("Downloading {}".format(url))
82

83
        try:
84
            req = requests.get(url, stream=True)
85
        except Exception as e:  # requests.exceptions.ConnectionError
86
            logger.info("Downloading {} failed {} times with exception {}".format(url, retry_cnt + 1, str(e)))
87
            time.sleep(1)
88
            continue
89

90
        if req.status_code != 200:
91
            raise RuntimeError("Downloading from {} failed with code " "{}!".format(url, req.status_code))
92

93
        # For protecting download interupted, download to
94
        # tmp_fullname firstly, move tmp_fullname to fullname
95
        # after download finished
96
        tmp_fullname = fullname + "_tmp"
97
        total_size = req.headers.get("content-length")
98
        with open(tmp_fullname, "wb") as f:
99
            if total_size:
100
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
101
                    for chunk in req.iter_content(chunk_size=1024):
102
                        f.write(chunk)
103
                        pbar.update(1)
104
            else:
105
                for chunk in req.iter_content(chunk_size=1024):
106
                    if chunk:
107
                        f.write(chunk)
108
        shutil.move(tmp_fullname, fullname)
109

110
    return fullname
111

112

113
def download(url, path):
114
    local_rank = 0
115
    world_size = 1
116
    if paddle.base.core.is_compiled_with_dist() and paddle.distributed.get_world_size() > 1:
117
        local_rank = paddle.distributed.ParallelEnv().dev_id
118
        world_size = paddle.distributed.get_world_size()
119
    if world_size > 1 and local_rank != 0:
120
        while not os.path.exists(path):
121
            time.sleep(1)
122
    else:
123
        _download(url, path)
124

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.