# pytorch-lightning — tar/packaging utilities (201 lines, 5.9 KB)
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import math
16import os
17import subprocess
18import tarfile
19from dataclasses import dataclass
20from typing import Optional, Tuple
21
22import click
23
# Maximum number of multipart-upload parts; we are limited to ~1000 parts by
# the way ListMultipartUploads is used (see `_get_split_size` below).
MAX_SPLIT_COUNT = 999
25
26
27def _get_dir_size_and_count(source_dir: str, prefix: Optional[str] = None) -> Tuple[int, int]:
28"""Get size and file count of a directory.
29
30Parameters
31----------
32source_dir: str
33Directory path
34
35Returns
36-------
37Tuple[int, int]
38Size in megabytes and file count
39
40"""
41size = 0
42count = 0
43for root, _, files in os.walk(source_dir, topdown=True):
44for f in files:
45if prefix and not f.startswith(prefix):
46continue
47
48full_path = os.path.join(root, f)
49size += os.path.getsize(full_path)
50count += 1
51
52return (size, count)
53
54
55@dataclass
56class _TarResults:
57"""This class holds the results of running tar_path.
58
59Attributes
60----------
61before_size: int
62The total size of the original directory files in bytes
63after_size: int
64The total size of the compressed and tarred split files in bytes
65
66"""
67
68before_size: int
69after_size: int
70
71
def _get_split_size(
    total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = MAX_SPLIT_COUNT
) -> int:
    """Calculate the part size for a multipart upload of an object to a bucket.

    We are limited to 1000 max parts as the way we are using ListMultipartUploads.
    More info https://github.com/gridai/grid/pull/5267
    https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html#mpu-process
    https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html
    https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests
    has a limit of 2**31 bytes for a single file upload.

    Parameters
    ----------
    total_size: int
        Total size of the file to split
    minimum_split_size: int
        The minimum split size to use
    max_split_count: int
        The maximum split count

    Returns
    -------
    int
        Split size

    """
    # Hard ceiling: at most `max_split_count` parts of at most 2**31 bytes each
    # (the per-part limit imposed by Requests/urllib, see refs above).
    max_size = max_split_count * (1 << 31)
    if total_size > max_size:
        raise click.ClickException(
            f"The size of the datastore to be uploaded is bigger than our {max_size / (1 << 40):.2f} TBytes limit"
        )

    # Start from the minimum part size and enlarge it only when it would
    # produce more parts than allowed.
    chosen_size = minimum_split_size
    if math.ceil(total_size / chosen_size) > max_split_count:
        chosen_size = math.ceil(total_size / max_split_count)

    return chosen_size
110
111
def _tar_path(source_path: str, target_file: str, compression: bool = False) -> _TarResults:
    """Create tar from directory using `tar`, falling back to Python's ``tarfile``.

    Parameters
    ----------
    source_path: str
        Source directory or file
    target_file
        Target tar file
    compression: bool, default False
        Enable compression, which is disabled by default.

    Returns
    -------
    TarResults
        Results that holds file counts and sizes

    """
    # Measure the input up front so we can report the compression ratio later.
    before_size = (
        _get_dir_size_and_count(source_path)[0]
        if os.path.isdir(source_path)
        else os.path.getsize(source_path)
    )

    try:
        _tar_path_subprocess(source_path, target_file, compression)
    except subprocess.CalledProcessError:
        # The system `tar` binary is unavailable or failed — use pure Python.
        _tar_path_python(source_path, target_file, compression)

    return _TarResults(before_size=before_size, after_size=os.stat(target_file).st_size)
142
143
144def _tar_path_python(source_path: str, target_file: str, compression: bool = False) -> None:
145"""Create tar from directory using `python`
146
147Parameters
148----------
149source_path: str
150Source directory or file
151target_file
152Target tar file
153compression: bool, default False
154Enable compression, which is disabled by default.
155
156"""
157file_mode = "w:gz" if compression else "w:"
158
159with tarfile.open(target_file, file_mode) as tar:
160if os.path.isdir(source_path):
161tar.add(str(source_path), arcname=".")
162elif os.path.isfile(source_path):
163file_info = tarfile.TarInfo(os.path.basename(str(source_path)))
164with open(source_path) as fo:
165tar.addfile(file_info, fo)
166
167
168def _tar_path_subprocess(source_path: str, target_file: str, compression: bool = False) -> None:
169"""Create tar from directory using `tar`
170
171Parameters
172----------
173source_path: str
174Source directory or file
175target_file
176Target tar file
177compression: bool, default False
178Enable compression, which is disabled by default.
179
180"""
181# Only add compression when users explicitly request it.
182# We do this because it takes too long to compress
183# large datastores.
184tar_flags = "-cvf"
185if compression:
186tar_flags = "-zcvf"
187if os.path.isdir(source_path):
188command = f"tar -C {source_path} {tar_flags} {target_file} ./"
189else:
190abs_path = os.path.abspath(source_path)
191parent_dir = os.path.dirname(abs_path)
192base_name = os.path.basename(abs_path)
193command = f"tar -C {parent_dir} {tar_flags} {target_file} {base_name}"
194
195subprocess.check_call(
196command,
197stdout=subprocess.DEVNULL,
198stderr=subprocess.DEVNULL,
199shell=True,
200env={"GZIP": "-9", "COPYFILE_DISABLE": "1"},
201)
202