# Source: pytorch (fork count: 0) / black_linter.py — 227 lines · 6.1 KB
1
import argparse
2
import concurrent.futures
3
import json
4
import logging
5
import os
6
import subprocess
7
import sys
8
import time
9
from enum import Enum
10
from typing import Any, BinaryIO, List, NamedTuple, Optional
11

12

13
# Host-platform flag: True only when running under Windows (os.name == "nt").
# It switches shell execution for subprocesses and backslash-to-slash
# normalization of paths in messages.
IS_WINDOWS: bool = os.name == "nt"


def eprint(*args: Any, **kwargs: Any) -> None:
    """Print ``args`` to standard error and flush immediately.

    Works like :func:`print`: extra keyword arguments such as ``sep`` and
    ``end`` are forwarded, while ``file`` and ``flush`` are fixed by this helper.
    """
    err_stream = sys.stderr
    print(*args, file=err_stream, flush=True, **kwargs)


class LintSeverity(str, Enum):
    """Severity level attached to every LintMessage.

    Members also subclass ``str``, so each compares equal to, and
    JSON-serializes as, its plain lowercase string value.
    """

    # Members are listed from most to least severe; DISABLED comes last.
    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    """One lint result; main() prints each as a single-line JSON object."""

    path: Optional[str]  # file the message is about
    line: Optional[int]  # may be None; this linter only emits whole-file messages
    char: Optional[int]  # may be None; this linter only emits whole-file messages
    code: str  # linter identifier; this script always uses "BLACK"
    severity: LintSeverity
    name: str  # short reason tag, e.g. "format", "timeout", "command-failed"
    original: Optional[str]  # full original file text for patch messages, else None
    replacement: Optional[str]  # full formatted text for patch messages, else None
    description: Optional[str]  # human-readable explanation


def as_posix(name: str) -> str:
    """Return ``name`` with backslashes turned into forward slashes on Windows.

    On any other platform the name is returned untouched.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")


def _run_command(
    args: List[str],
    *,
    stdin: BinaryIO,
    timeout: int,
) -> "subprocess.CompletedProcess[bytes]":
    """Run ``args`` once with captured output and log how long it took.

    Args:
        args: command line to execute.
        stdin: file object connected to the child's standard input.
        timeout: seconds to wait before subprocess.run gives up.

    Returns:
        The CompletedProcess, with stdout/stderr captured as bytes.

    Raises:
        subprocess.CalledProcessError: the command exited non-zero (check=True).
        subprocess.TimeoutExpired: the command exceeded ``timeout``.
    """
    logging.debug("$ %s", " ".join(args))
    started = time.monotonic()
    try:
        completed = subprocess.run(
            args,
            stdin=stdin,
            capture_output=True,
            # On Windows go through the shell so batch-script wrappers resolve.
            shell=IS_WINDOWS,
            timeout=timeout,
            check=True,
        )
    finally:
        # Logged on success and on failure alike.
        elapsed_ms = (time.monotonic() - started) * 1000
        logging.debug("took %dms", elapsed_ms)
    return completed


def run_command(
    args: List[str],
    *,
    stdin: BinaryIO,
    retries: int,
    timeout: int,
) -> "subprocess.CompletedProcess[bytes]":
    """Run ``args`` via _run_command, retrying only when an attempt times out.

    Args:
        args: command line to execute.
        stdin: file object whose contents are fed to the command's stdin.
            If it is seekable it is rewound to its starting offset before
            every retry.
        retries: number of extra attempts after a timeout; values <= 0 mean
            a single attempt.
        timeout: seconds allowed for each attempt.

    Returns:
        The CompletedProcess of the first attempt that finishes in time.

    Raises:
        subprocess.TimeoutExpired: the last permitted attempt timed out.
        Any other exception from _run_command (e.g. CalledProcessError)
        propagates immediately without a retry.
    """
    # The child reads stdin through the same open file description, so a
    # timed-out attempt leaves the offset wherever the child stopped (often
    # EOF). Without rewinding, a retry would run on empty or partial input.
    start_offset = stdin.tell() if stdin.seekable() else None
    remaining_retries = retries
    while True:
        try:
            return _run_command(args, stdin=stdin, timeout=timeout)
        except subprocess.TimeoutExpired as err:
            # `<= 0` rather than `== 0` so a negative retry count cannot loop forever.
            if remaining_retries <= 0:
                raise
            remaining_retries -= 1
            logging.warning(
                "(%s/%s) Retrying because command failed with: %r",
                retries - remaining_retries,
                retries,
                err,
            )
            time.sleep(1)
            if start_offset is not None:
                stdin.seek(start_offset)


def check_file(
    filename: str,
    retries: int,
    timeout: int,
) -> List[LintMessage]:
    """Run black on ``filename`` and report whether it would change the file.

    Args:
        filename: path of the file to check; also passed to black via
            ``--stdin-filename`` while the contents are piped through stdin.
        retries: how many times to retry black after a timeout.
        timeout: seconds to wait for each black invocation.

    Returns:
        An empty list when black's output equals the file's bytes. Otherwise
        a single LintMessage, which is one of:
        - WARNING "format": carries the original and formatted text.
        - ERROR "timeout": black timed out on every attempt.
        - ADVICE "command-failed": the file could not be read or black exited
          non-zero.
    """
    try:
        with open(filename, "rb") as f:
            original = f.read()
        with open(filename, "rb") as f:
            # "-" makes black read source from stdin and write the result to stdout.
            proc = run_command(
                [sys.executable, "-mblack", "--stdin-filename", filename, "-"],
                stdin=f,
                retries=retries,
                timeout=timeout,
            )
    except subprocess.TimeoutExpired:
        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="BLACK",
                severity=LintSeverity.ERROR,
                name="timeout",
                original=None,
                replacement=None,
                description=(
                    "black timed out while trying to process a file. "
                    "Please report an issue in pytorch/pytorch with the "
                    "label 'module: lint'"
                ),
            )
        ]
    except (OSError, subprocess.CalledProcessError) as err:
        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="BLACK",
                severity=LintSeverity.ADVICE,
                name="command-failed",
                original=None,
                replacement=None,
                description=_describe_failure(err),
            )
        ]

    replacement = proc.stdout
    if original == replacement:
        return []

    return [
        LintMessage(
            path=filename,
            line=None,
            char=None,
            code="BLACK",
            severity=LintSeverity.WARNING,
            name="format",
            original=original.decode("utf-8"),
            replacement=replacement.decode("utf-8"),
            description="Run `lintrunner -a` to apply this patch.",
        )
    ]


def _describe_failure(err: Exception) -> str:
    """Build the "command-failed" description for an OSError or CalledProcessError.

    Captured output is decoded with ``errors="replace"``. It is only displayed
    to the user, and a UnicodeDecodeError here (e.g. from a non-UTF-8 console
    code page) would otherwise escape check_file and abort the whole run.
    """
    if not isinstance(err, subprocess.CalledProcessError):
        return f"Failed due to {err.__class__.__name__}:\n{err}"
    return (
        "COMMAND (exit code {returncode})\n"
        "{command}\n\n"
        "STDERR\n{stderr}\n\n"
        "STDOUT\n{stdout}"
    ).format(
        returncode=err.returncode,
        command=" ".join(as_posix(x) for x in err.cmd),
        stderr=err.stderr.decode("utf-8", errors="replace").strip() or "(empty)",
        stdout=err.stdout.decode("utf-8", errors="replace").strip() or "(empty)",
    )


def main() -> None:
    """Command-line entry point: check every given path with black.

    Files are checked concurrently on a thread pool. Each resulting
    LintMessage is printed to stdout as one JSON object per line. If a
    worker raises, the offending path is logged at CRITICAL level and the
    exception is re-raised.
    """
    parser = argparse.ArgumentParser(
        description="Format files with black.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--retries", default=3, type=int, help="times to retry timed out black"
    )
    parser.add_argument(
        "--timeout", default=90, type=int, help="seconds to wait for black"
    )
    parser.add_argument("--verbose", action="store_true", help="verbose logging")
    parser.add_argument("filenames", nargs="+", help="paths to lint")
    args = parser.parse_args()

    # --verbose logs everything. Without it, DEBUG is still on for runs of
    # fewer than 1000 files, and larger runs drop to INFO.
    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count(),
        thread_name_prefix="Thread",
    ) as executor:
        filename_by_future = {
            executor.submit(check_file, filename, args.retries, args.timeout): filename
            for filename in args.filenames
        }
        for future in concurrent.futures.as_completed(filename_by_future):
            try:
                messages = future.result()
                for lint_message in messages:
                    print(json.dumps(lint_message._asdict()), flush=True)
            except Exception:
                logging.critical('Failed at "%s".', filename_by_future[future])
                raise


if __name__ == "__main__":
    # Run the linter only when executed directly, not when imported.
    main()

# NOTE: the text below is a website cookie notice captured along with this
# file when it was copied from the web; it is not part of the script.
# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.