onnxruntime
133 строки · 4.1 Кб
1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import argparse
5import logging
6import os
7import pathlib
8import shutil
9import tempfile
10
11from util import logger, run
12
# Logger for this script, created through the project's util.logger helper (logging.INFO is presumably the initial
# level). main() switches it to DEBUG when --debug is passed.
_log = logger.get_logger("fix_long_lines", logging.INFO)
14
15
# Look for long lines in each file, and if any are found run clang-format on just those lines.
def _process_files(filenames, clang_exe, tmpdir):
    """Reformat only the over-long lines of each file using clang-format.

    For every file containing a line longer than 120 characters (excluding the line terminator), the file is
    copied into ``tmpdir``. clang-format is then run in place on that copy, restricted to the offending lines via
    ``--lines``, and the result is copied back over the original file.

    :param filenames: Iterable of paths to check.
    :param clang_exe: clang-format executable name or path.
    :param tmpdir: Scratch directory that clang-format is run in; expected to contain the ``.clang-format``
        style file that should be applied.
    """
    for path in filenames:
        _log.debug(f"Checking {path}")

        with open(path, encoding="UTF8") as f:
            # clang-format line numbers start at 1. Strip the line terminator before measuring so that a line of
            # exactly 120 characters is not reported as too long.
            bad_lines = [
                line_num for line_num, line in enumerate(f, start=1) if len(line.rstrip("\r\n")) > 120
            ]

        if not bad_lines:
            continue

        _log.info(f"Updating {path}")
        target = os.path.join(tmpdir, os.path.basename(path))
        shutil.copy(path, target)

        # run clang-format to update just the long lines in the file
        cmd = [clang_exe, "-i", *(f"--lines={line}:{line}" for line in bad_lines), target]

        # Execute the argument list directly instead of through a shell. With shell=True on POSIX, only the first
        # list element is used as the command string, so the options and target file would never reach
        # clang-format.
        run(*cmd, cwd=tmpdir, check=True, shell=False)

        # copy updated file back to original location
        shutil.copy(target, path)
48
49
50# file extensions we process
51_EXTENSIONS = [".cc", ".h"]
52
53
def _get_branch_diffs(ort_root, branch):
    """Yield paths of existing source files under ``ort_root`` that differ from ``branch``.

    Uses ``git diff <branch> --name-only``, whose output paths are relative to the repository top level. Only
    files whose lower-cased extension is in ``_EXTENSIONS`` are yielded.

    :param ort_root: Repository root directory; git is run there and the reported paths are joined onto it.
    :param branch: Branch, commit or other revision to diff the working tree against.
    """
    command = ["git", "diff", branch, "--name-only"]
    # Run from the repository root so the command works regardless of the caller's current working directory.
    result = run(*command, cwd=ort_root, capture_stdout=True, check=True)

    # stdout is bytes. one filename per line. decode, split, and filter to the extensions we are looking at
    for f in result.stdout.decode("utf-8").splitlines():
        if os.path.splitext(f.lower())[-1] not in _EXTENSIONS:
            continue

        path = os.path.join(ort_root, f)
        # Files deleted relative to `branch` are still listed by `git diff --name-only`. Skip them, because
        # opening them later would fail.
        if os.path.isfile(path):
            yield path
62
63
def _get_file_list(path):
    """Recursively yield every file below ``path`` whose lower-cased extension is listed in ``_EXTENSIONS``.

    Directories are traversed in ``os.walk`` (top-down) order and the results are produced lazily.
    """
    for dirpath, _subdirs, names in os.walk(path):
        yield from (
            os.path.join(dirpath, name)
            for name in names
            if os.path.splitext(name.lower())[-1] in _EXTENSIONS
        )
69
70
def main():
    """Parse command-line arguments and fix over-long lines in the selected source files.

    Writes a ``.clang-format`` config (Google style, 120 column limit) into a temporary directory. It then runs
    ``_process_files`` on one of two file sets:

    - with ``--all_files``, every .cc/.h file under ``include/onnxruntime`` and ``onnxruntime/core``;
    - otherwise, the .cc/.h files that differ from ``--branch``.
    """
    argparser = argparse.ArgumentParser(
        # The first positional parameter of ArgumentParser is `prog` (the program name shown in the usage line),
        # so this text must be passed by keyword to be used as the description.
        description="Script to fix long lines in the source using clang-format. "
        "Only lines that exceed the 120 character maximum are altered in order to minimize the impact. "
        "Checks .cc and .h files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    argparser.add_argument(
        "--branch",
        type=str,
        default="origin/main",
        help="Limit changes to files that differ from this branch. Use origin/main when preparing a PR.",
    )

    argparser.add_argument(
        "--all_files",
        action="store_true",
        help="Process all files under /include/onnxruntime and /onnxruntime/core. Ignores --branch value.",
    )

    argparser.add_argument(
        "--clang-format",
        type=pathlib.Path,
        required=False,
        default="clang-format",
        help="Path to clang-format executable",
    )

    argparser.add_argument("--debug", action="store_true", help="Set log level to DEBUG.")

    args = argparser.parse_args()

    if args.debug:
        _log.setLevel(logging.DEBUG)

    script_dir = os.path.dirname(os.path.realpath(__file__))
    # Presumably the repository root (two directory levels above this script); file paths are resolved against it.
    ort_root = os.path.abspath(os.path.join(script_dir, "..", ".."))

    with tempfile.TemporaryDirectory() as tmpdir:
        # Create the config in tmpdir. _process_files formats copies of the files placed in tmpdir, and
        # clang-format looks for .clang-format next to the file being formatted, so this config is applied.
        with open(os.path.join(tmpdir, ".clang-format"), "w", encoding="utf-8") as f:
            f.write(
                """
BasedOnStyle: Google
ColumnLimit: 120
DerivePointerAlignment: false
"""
            )

        clang_format = str(args.clang_format)

        if args.all_files:
            include_path = os.path.join(ort_root, "include", "onnxruntime")
            src_path = os.path.join(ort_root, "onnxruntime", "core")
            _process_files(_get_file_list(include_path), clang_format, tmpdir)
            _process_files(_get_file_list(src_path), clang_format, tmpdir)
        else:
            _process_files(_get_branch_diffs(ort_root, args.branch), clang_format, tmpdir)


if __name__ == "__main__":
    main()
134