# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import subprocess
import tarfile
from dataclasses import dataclass
from typing import Optional, Tuple

import click

MAX_SPLIT_COUNT = 999


def _get_dir_size_and_count(source_dir: str, prefix: Optional[str] = None) -> Tuple[int, int]:
    """Get the size and file count of a directory.

    Parameters
    ----------
    source_dir: str
        Directory path
    prefix: str, optional
        If given, only files whose names start with this prefix are counted.

    Returns
    -------
    Tuple[int, int]
        Size in bytes and file count

    """
    size = 0
    count = 0
    for root, _, files in os.walk(source_dir, topdown=True):
        for f in files:
            if prefix and not f.startswith(prefix):
                continue

            full_path = os.path.join(root, f)
            size += os.path.getsize(full_path)
            count += 1

    return (size, count)

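# Illustrative usage (not part of the original module; the directory path and
# prefix below are hypothetical). The return value is a plain tuple, so the
# caller typically unpacks it:
#
#   size_bytes, n_files = _get_dir_size_and_count("./checkpoints", prefix="ckpt_")
#   print(f"{n_files} files, {size_bytes / (1 << 20):.1f} MiB")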

@dataclass
class _TarResults:
    """Holds the results of running `_tar_path`.

    Attributes
    ----------
    before_size: int
        The total size of the original directory files in bytes
    after_size: int
        The total size of the compressed and tarred split files in bytes

    """

    before_size: int
    after_size: int

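# Illustrative note (assumption, not in the original module): the two sizes can
# be combined into a compression ratio for logging, e.g.:
#
#   results = _TarResults(before_size=1_000_000, after_size=250_000)
#   ratio = results.before_size / results.after_size  # 4.0x smaller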

def _get_split_size(
    total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = MAX_SPLIT_COUNT
) -> int:
    """Calculate the split size to use for a multipart upload of an object to a bucket.

    We are limited to a maximum of 1000 parts because of the way we use ListMultipartUploads. More info:
    https://github.com/gridai/grid/pull/5267
    https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html#mpu-process
    https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html
    Python (or requests) has a limit of 2**31 bytes for a single file upload:
    https://github.com/psf/requests/issues/2717#issuecomment-724725392

    Parameters
    ----------
    total_size: int
        Total size of the file to split
    minimum_split_size: int
        The minimum split size to use
    max_split_count: int
        The maximum split count

    Returns
    -------
    int
        Split size in bytes

    """
    max_size = max_split_count * (1 << 31)  # max size per part, limited by requests/urllib as per the refs above
    if total_size > max_size:
        raise click.ClickException(
            f"The size of the datastore to be uploaded is bigger than our {max_size / (1 << 40):.2f} TBytes limit"
        )

    split_size = minimum_split_size
    split_count = math.ceil(total_size / split_size)
    if split_count > max_split_count:
        # Adjust the split size based on the max split count
        split_size = math.ceil(total_size / max_split_count)

    return split_size

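# Worked example (illustrative numbers, not from the original module): with the
# default minimum split size of 1024 * 1000 * 20 = 20,480,000 bytes (~20 MB), a
# 100 GB upload would need ceil(100_000_000_000 / 20_480_000) = 4883 parts,
# which exceeds MAX_SPLIT_COUNT. The split size is therefore raised to
# ceil(100_000_000_000 / 999) = 100_100_101 bytes (~100 MB), keeping the part
# count at 999:
#
#   assert _get_split_size(100_000_000_000) == 100_100_101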

def _tar_path(source_path: str, target_file: str, compression: bool = False) -> _TarResults:
    """Create a tar archive from a directory or file, preferring the system `tar` binary and falling back to
    Python's `tarfile` module if the subprocess call fails.

    Parameters
    ----------
    source_path: str
        Source directory or file
    target_file: str
        Target tar file
    compression: bool, default False
        Enable compression, which is disabled by default.

    Returns
    -------
    _TarResults
        Holds the before and after sizes in bytes

    """
    if os.path.isdir(source_path):
        before_size, _ = _get_dir_size_and_count(source_path)
    else:
        before_size = os.path.getsize(source_path)

    try:
        _tar_path_subprocess(source_path, target_file, compression)
    except subprocess.CalledProcessError:
        _tar_path_python(source_path, target_file, compression)

    after_size = os.stat(target_file).st_size
    return _TarResults(before_size=before_size, after_size=after_size)

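# Illustrative usage (hypothetical paths, not part of the original module):
# archive a directory without compression and report the size change.
#
#   results = _tar_path("/data/my_datastore", "/tmp/my_datastore.tar")
#   print(f"{results.before_size} -> {results.after_size} bytes")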

def _tar_path_python(source_path: str, target_file: str, compression: bool = False) -> None:
    """Create a tar archive from a directory or file using Python's `tarfile` module.

    Parameters
    ----------
    source_path: str
        Source directory or file
    target_file: str
        Target tar file
    compression: bool, default False
        Enable compression, which is disabled by default.

    """
    file_mode = "w:gz" if compression else "w:"

    with tarfile.open(target_file, file_mode) as tar:
        if os.path.isdir(source_path):
            tar.add(str(source_path), arcname=".")
        elif os.path.isfile(source_path):
            file_info = tarfile.TarInfo(os.path.basename(str(source_path)))
            # TarInfo defaults to size 0, so record the real file size before adding,
            # and open the file in binary mode since `addfile` reads a bytes stream.
            file_info.size = os.path.getsize(source_path)
            with open(source_path, "rb") as fo:
                tar.addfile(file_info, fo)

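# Quick sanity check (illustrative, hypothetical paths): after archiving a
# single file, the archive should list exactly that member under its basename.
#
#   _tar_path_python("/tmp/model.ckpt", "/tmp/model.tar")
#   with tarfile.open("/tmp/model.tar") as tar:
#       assert tar.getnames() == ["model.ckpt"]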

def _tar_path_subprocess(source_path: str, target_file: str, compression: bool = False) -> None:
    """Create a tar archive from a directory or file using the system `tar` binary.

    Parameters
    ----------
    source_path: str
        Source directory or file
    target_file: str
        Target tar file
    compression: bool, default False
        Enable compression, which is disabled by default.

    """
    # Only add compression when users explicitly request it,
    # because compressing large datastores takes too long.
    tar_flags = "-cvf"
    if compression:
        tar_flags = "-zcvf"
    if os.path.isdir(source_path):
        command = f"tar -C {source_path} {tar_flags} {target_file} ./"
    else:
        abs_path = os.path.abspath(source_path)
        parent_dir = os.path.dirname(abs_path)
        base_name = os.path.basename(abs_path)
        command = f"tar -C {parent_dir} {tar_flags} {target_file} {base_name}"

    subprocess.check_call(
        command,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        shell=True,
        env={"GZIP": "-9", "COPYFILE_DISABLE": "1"},
    )

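# Note on the constructed command (describing the code above, with a
# hypothetical path): for a directory source with compression enabled, the
# subprocess runs, via the shell:
#
#   tar -C /data/my_datastore -zcvf /tmp/out.tar.gz ./
#
# GZIP=-9 asks gzip for maximum compression, and COPYFILE_DISABLE=1 stops
# macOS `tar` from embedding AppleDouble ("._*") metadata files.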