pytorch-lightning

Форк
0
111 строк · 3.8 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import time
16

17
import requests
18
from requests.adapters import HTTPAdapter
19
from rich.progress import BarColumn, Progress, TextColumn
20
from urllib3.util.retry import Retry
21

22

23
class FileUploader:
    """This class uploads a source file with presigned url to S3.

    Attributes
    ----------
    source_file: str
        Source file to upload
    presigned_url: str
        Presigned url the content of the source file is sent to
    total_size: int
        Size of all files to upload, used as the total of the progress task
    name: str
        Name of this upload to display progress
    use_progress: bool
        Whether the upload is reported through the shared ``progress`` display
    task_id:
        Id of the progress task; ``None`` until ``upload`` creates one
    workers: int
        Class-level setting; not read anywhere inside this class
    retries: int
        Amount of retries when the connection is broken during an upload
    disconnect_retry_wait_seconds: int
        Amount of seconds to wait between two disconnect retries
    progress: Progress
        Progress display; a class attribute, hence shared by all instances

    """

    workers: int = 8
    retries: int = 10000
    disconnect_retry_wait_seconds: int = 5
    progress = Progress(
        TextColumn("[bold blue]{task.description}", justify="left"),
        BarColumn(bar_width=None),
        # ``progress.percentage`` is the Rich style name; the former
        # ``self.progress.percentage`` tag does not name any style.
        "[progress.percentage]{task.percentage:>3.1f}%",
    )

    def __init__(self, presigned_url: str, source_file: str, total_size: int, name: str, use_progress: bool = True):
        self.presigned_url = presigned_url
        self.source_file = source_file
        self.total_size = total_size
        self.name = name
        self.use_progress = use_progress
        # Set by ``upload`` when it registers a progress task.
        self.task_id = None

    def upload_data(self, url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str:
        """Send data to url.

        Parameters
        ----------
        url: str
            url string to send data to
        data: bytes
             Bytes of data to send to url
        retries: int
            Amount of attempts made when the connection is broken
            (``BrokenPipeError``); no attempt is made when it is not positive
        disconnect_retry_wait_seconds: int
            Amount of seconds between disconnect retry

        Returns
        -------
        str
            ETag from response

        Raises
        ------
        ValueError
            When every attempt ended with a broken connection, or when the
            response carries no ``ETag`` header

        """
        # Transport-level retry policy of the HTTP adapter. Kept under its own
        # name so it does not shadow the ``retries`` argument, and built once
        # because it does not change between attempts.
        http_retry = Retry(total=10)
        last_error = None
        disconnect_retries = retries
        while disconnect_retries > 0:
            try:
                # A fresh session per attempt, closed by the context manager.
                with requests.Session() as s:
                    s.mount("https://", HTTPAdapter(max_retries=http_retry))
                    return self._upload_data(s, url, data)
            except BrokenPipeError as err:
                last_error = err
                disconnect_retries -= 1
                # Only wait when another attempt is going to follow.
                if disconnect_retries > 0:
                    time.sleep(disconnect_retry_wait_seconds)

        # Chain the last disconnect so the root cause stays visible.
        raise ValueError("Unable to upload file after multiple attempts") from last_error

    def _upload_data(self, s: requests.Session, url: str, data: bytes) -> str:
        """PUT data to url through the given session and return the ``ETag`` response header.

        Raises
        ------
        ValueError
            When the response has no ``ETag`` header

        """
        resp = s.put(url, data=data)
        if "ETag" not in resp.headers:
            raise ValueError(f"Unexpected response from {url}, response: {resp.content}")
        return resp.headers["ETag"]

    def upload(self) -> None:
        """Read the source file and upload its content to the presigned url.

        When progress reporting is enabled and this instance has no task yet, a
        progress task is created and the display is started here; in that case
        the display is also stopped here, whether or not the upload succeeded.

        """
        # Remember whether this call owns the progress display lifecycle.
        no_task = self.task_id is None
        if self.use_progress and no_task:
            self.task_id = self.progress.add_task("upload", filename=self.name, total=self.total_size)
            self.progress.start()
        try:
            # The whole file is read into memory and sent in a single request.
            with open(self.source_file, "rb") as f:
                data = f.read()
            self.upload_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds)
            if self.use_progress:
                self.progress.update(self.task_id, advance=len(data))
        finally:
            if self.use_progress and no_task:
                self.progress.stop()
112

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.