onnxruntime

Форк
0
/
fix_long_lines.py 
133 строки · 4.1 Кб
1
# Copyright (c) Microsoft Corporation. All rights reserved.
2
# Licensed under the MIT License.
3

4
import argparse
5
import logging
6
import os
7
import pathlib
8
import shutil
9
import tempfile
10

11
from util import logger, run
12

13
_log = logger.get_logger("fix_long_lines", logging.INFO)
14

15

16
# look for long lines in the file, and if found run clang-format on those lines
17
def _process_files(filenames, clang_exe, tmpdir):
18
    for path in filenames:
19
        _log.debug(f"Checking {path}")
20
        bad_lines = []
21

22
        with open(path, encoding="UTF8") as f:
23
            for i, line in enumerate(f):
24
                line_num = i + 1  # clang-format line numbers start at 1
25
                if len(line) > 120:
26
                    bad_lines.append(line_num)
27

28
        if bad_lines:
29
            _log.info(f"Updating {path}")
30
            filename = os.path.basename(path)
31
            target = os.path.join(tmpdir, filename)
32
            shutil.copy(path, target)
33

34
            # run clang-format to update just the long lines in the file
35
            cmd = [
36
                clang_exe,
37
                "-i",
38
            ]
39
            for line in bad_lines:
40
                cmd.append(f"--lines={line}:{line}")
41

42
            cmd.append(target)
43

44
            run(*cmd, cwd=tmpdir, check=True, shell=True)
45

46
            # copy updated file back to original location
47
            shutil.copy(target, path)
48

49

50
# file extensions we process
51
_EXTENSIONS = [".cc", ".h"]
52

53

54
def _get_branch_diffs(ort_root, branch):
55
    command = ["git", "diff", branch, "--name-only"]
56
    result = run(*command, capture_stdout=True, check=True)
57

58
    # stdout is bytes. one filename per line. decode, split, and filter to the extensions we are looking at
59
    for f in result.stdout.decode("utf-8").splitlines():
60
        if os.path.splitext(f.lower())[-1] in _EXTENSIONS:
61
            yield os.path.join(ort_root, f)
62

63

64
def _get_file_list(path):
65
    for root, _, files in os.walk(path):
66
        for file in files:
67
            if os.path.splitext(file.lower())[-1] in _EXTENSIONS:
68
                yield os.path.join(root, file)
69

70

71
def main():
72
    argparser = argparse.ArgumentParser(
73
        "Script to fix long lines in the source using clang-format. "
74
        "Only lines that exceed the 120 character maximum are altered in order to minimize the impact. "
75
        "Checks .cc and .h files",
76
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
77
    )
78

79
    argparser.add_argument(
80
        "--branch",
81
        type=str,
82
        default="origin/main",
83
        help="Limit changes to files that differ from this branch. Use origin/main when preparing a PR.",
84
    )
85

86
    argparser.add_argument(
87
        "--all_files",
88
        action="store_true",
89
        help="Process all files under /include/onnxruntime and /onnxruntime/core. Ignores --branch value.",
90
    )
91

92
    argparser.add_argument(
93
        "--clang-format",
94
        type=pathlib.Path,
95
        required=False,
96
        default="clang-format",
97
        help="Path to clang-format executable",
98
    )
99

100
    argparser.add_argument("--debug", action="store_true", help="Set log level to DEBUG.")
101

102
    args = argparser.parse_args()
103

104
    if args.debug:
105
        _log.setLevel(logging.DEBUG)
106

107
    script_dir = os.path.dirname(os.path.realpath(__file__))
108
    ort_root = os.path.abspath(os.path.join(script_dir, "..", ".."))
109

110
    with tempfile.TemporaryDirectory() as tmpdir:
111
        # create config in tmpdir
112
        with open(os.path.join(tmpdir, ".clang-format"), "w") as f:
113
            f.write(
114
                """
115
            BasedOnStyle: Google
116
            ColumnLimit: 120
117
            DerivePointerAlignment: false
118
            """
119
            )
120

121
        clang_format = str(args.clang_format)
122

123
        if args.all_files:
124
            include_path = os.path.join(ort_root, "include", "onnxruntime")
125
            src_path = os.path.join(ort_root, "onnxruntime", "core")
126
            _process_files(_get_file_list(include_path), clang_format, tmpdir)
127
            _process_files(_get_file_list(src_path), clang_format, tmpdir)
128
        else:
129
            _process_files(_get_branch_diffs(ort_root, args.branch), clang_format, tmpdir)
130

131

132
if __name__ == "__main__":
133
    main()
134

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.