pytorch-lightning
111 строк · 3.8 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import time
16
17import requests
18from requests.adapters import HTTPAdapter
19from rich.progress import BarColumn, Progress, TextColumn
20from urllib3.util.retry import Retry
21
22
class FileUploader:
    """Upload a single source file to S3 through a presigned URL.

    The whole file is read into memory and sent with one HTTP ``PUT`` request.
    Transport-level retries are delegated to ``urllib3`` (via ``HTTPAdapter``); an
    outer loop additionally opens a fresh session after a ``BrokenPipeError``.

    Attributes
    ----------
    source_file: str
        Path of the file to upload.
    presigned_url: str
        Presigned URL the file contents are ``PUT`` to.
    retries: int
        Maximum number of upload attempts made when the connection breaks
        (``BrokenPipeError``).
    disconnect_retry_wait_seconds: int
        Seconds to sleep between those attempts.
    total_size: int
        Total used for the progress task; the task is advanced by the number of
        bytes uploaded.
    name: str
        Name of this upload, stored on the progress task as its ``filename`` field.
    use_progress: bool
        Whether to report progress on the shared ``progress`` bar.

    """

    # Not referenced by the methods of this class.
    workers: int = 8
    retries: int = 10000
    disconnect_retry_wait_seconds: int = 5
    # A single Progress instance shared by every FileUploader, created at class-definition time.
    progress = Progress(
        TextColumn("[bold blue]{task.description}", justify="left"),
        BarColumn(bar_width=None),
        # ``progress.percentage`` is a style in rich's default theme; the former
        # ``[self.progress.percentage]`` tag named no existing style.
        "[progress.percentage]{task.percentage:>3.1f}%",
    )

    def __init__(self, presigned_url: str, source_file: str, total_size: int, name: str, use_progress: bool = True):
        self.presigned_url = presigned_url
        self.source_file = source_file
        self.total_size = total_size
        self.name = name
        self.use_progress = use_progress
        # Set by ``upload`` when it adds a task to the shared progress bar.
        self.task_id = None

    def upload_data(self, url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str:
        """Send ``data`` to ``url`` with an HTTP ``PUT`` and return the response ETag.

        Parameters
        ----------
        url: str
            URL to send the data to.
        data: bytes
            Bytes to send.
        retries: int
            Maximum number of attempts when an attempt fails with ``BrokenPipeError``.
            A value ``<= 0`` makes no attempt and raises immediately.
        disconnect_retry_wait_seconds: int
            Seconds to sleep after each ``BrokenPipeError`` before the next attempt.

        Returns
        -------
        str
            ETag header of the response.

        Raises
        ------
        ValueError
            If every attempt failed with ``BrokenPipeError``, or the response has no ETag.

        """
        # Per-request retries performed by urllib3 inside each session. Kept under its own
        # name so it no longer shadows ``retries`` (the number of whole-session attempts).
        http_retries = Retry(total=10)
        last_error = None
        for _ in range(retries):
            try:
                with requests.Session() as session:
                    session.mount("https://", HTTPAdapter(max_retries=http_retries))
                    return self._upload_data(session, url, data)
            except BrokenPipeError as err:
                # NOTE(review): urllib3 handles BrokenPipeError internally in many cases and
                # requests reports transport failures as requests.exceptions.ConnectionError,
                # so this branch may rarely fire — confirm the intended exception set.
                last_error = err
                time.sleep(disconnect_retry_wait_seconds)

        raise ValueError("Unable to upload file after multiple attempts") from last_error

    def _upload_data(self, s: requests.Session, url: str, data: bytes) -> str:
        """PUT ``data`` to ``url`` on session ``s`` and return the response's ETag header.

        Raises
        ------
        ValueError
            If the response carries no ``ETag`` header.

        """
        resp = s.put(url, data=data)
        if "ETag" not in resp.headers:
            raise ValueError(f"Unexpected response from {url}, response: {resp.content}")
        return resp.headers["ETag"]

    def upload(self) -> None:
        """Upload ``source_file`` to ``presigned_url``.

        When ``use_progress`` is set and no progress task exists yet, a task is added to
        the shared ``progress`` bar. The bar is started for the duration of this call and
        stopped afterwards, even if the upload fails.

        """
        # Only the call that creates the task starts and stops the shared bar.
        manages_bar = self.use_progress and self.task_id is None
        if manages_bar:
            # NOTE(review): the description column renders "upload"; ``self.name`` is only
            # stored as the ``filename`` task field, which no column displays — confirm intent.
            self.task_id = self.progress.add_task("upload", filename=self.name, total=self.total_size)
            self.progress.start()
        try:
            # Read everything up front so the file is closed before the network upload.
            with open(self.source_file, "rb") as f:
                data = f.read()
            self.upload_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds)
            if self.use_progress:
                self.progress.update(self.task_id, advance=len(data))
        finally:
            if manages_bar:
                self.progress.stop()
112