transformers
124 строки · 4.4 Кб
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that sorts the names in the auto mappings defined in the auto modules in alphabetical order.

Use from the root of the repo with:

```bash
python utils/sort_auto_mappings.py
```

to auto-fix all the auto mappings (used in `make style`).

To only check if the mappings are properly sorted (as used in `make quality`), do:

```bash
python utils/sort_auto_mappings.py --check_only
```
"""
import argparse
import os
import re
from typing import Optional


# Paths are relative to the repository root: the script is meant to be launched from there.
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"


# Matches the line that opens a mapping, for instance
#     SUPER_MODEL_MAPPING_NAMES = OrderedDict
#     SUPER_MODEL_MAPPING = OrderedDict
_re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
# Captures the first double-quoted string right after an opening parenthesis (the sort key of an entry), e.g. `bert`
# in `("bert", "BertConfig"),`. The capture needs at least two characters and may span newlines before the quote.
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
48
def sort_auto_mapping(fname: str, overwrite: bool = False) -> Optional[bool]:
    """
    Sort all auto mappings in a file.

    A mapping starts on a line matched by `_re_intro_mapping` (e.g. `MODEL_MAPPING_NAMES = OrderedDict(`). Its entries
    are expected 8 columns deeper than that line and end at the first line that is `]` once stripped. An entry is
    either a single line, or a lone `(` line followed by everything up to the next line starting with `)` at the entry
    indentation. Entries are sorted by the first quoted string following their opening parenthesis (see
    `_re_identifier`). Every other line of the file, including the lines between a mapping introduction and its first
    entry, is kept as is.

    Args:
        fname (`str`): The name of the file where we want to sort auto-mappings.
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.

    Returns:
        `Optional[bool]`: Returns `None` if `overwrite=True`. Otherwise returns `True` if the file has an auto-mapping
        improperly sorted, `False` if the file is okay.

    Raises:
        `ValueError`: If a mapping is malformed, meaning one of the following:
            - its closing `]` is missing;
            - the closing `)` of a multi-line entry is missing;
            - an entry has no quoted identifier to sort on (e.g. a comment or blank line placed after the first entry).
    """
    with open(fname, "r", encoding="utf-8") as f:
        content = f.read()

    lines = content.split("\n")
    new_lines = []
    line_idx = 0
    while line_idx < len(lines):
        if _re_intro_mapping.search(lines[line_idx]) is None:
            new_lines.append(lines[line_idx])
            line_idx += 1
            continue

        # Start of a new mapping! Entries are expected 8 columns deeper than the introduction line
        # (assumes the `X = OrderedDict(` / `    [` / `        (...)` layout used by the auto modules).
        indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
        entry_start = " " * indent + "("
        # Copy the introduction and whatever precedes the first entry (the `[` line, leading comments) unchanged.
        # Stopping at `]` or at the end of the file lets empty mappings pass through instead of crashing.
        while (
            line_idx < len(lines)
            and not lines[line_idx].startswith(entry_start)
            and lines[line_idx].strip() != "]"
        ):
            new_lines.append(lines[line_idx])
            line_idx += 1
        if line_idx == len(lines) or lines[line_idx].strip() == "]":
            # Nothing to sort; the outer loop copies the closing `]` (if any) as a regular line.
            continue

        entries, line_idx = _collect_mapping_entries(lines, line_idx, indent, fname)
        # Sort on the identifier only: `sort` is stable, so entries sharing an identifier keep their relative order.
        entries.sort(key=lambda entry: entry[0])
        new_lines.extend(text for _, text in entries)
        # `line_idx` now points at the closing `]`, which the outer loop copies as a regular line.

    new_content = "\n".join(new_lines)
    if overwrite:
        with open(fname, "w", encoding="utf-8") as f:
            f.write(new_content)
    else:
        return new_content != content


def _collect_mapping_entries(lines: list, line_idx: int, indent: int, fname: str) -> tuple:
    """
    Group the lines of one auto mapping into sortable entries, stopping at the mapping's closing `]`.

    Args:
        lines (`list`): All the lines of the file.
        line_idx (`int`): Index of the first entry of the mapping.
        indent (`int`): Indentation (in spaces) of the mapping entries.
        fname (`str`): The name of the file, only used in error messages.

    Returns:
        `tuple`: A list of `(identifier, text)` pairs in file order (where `text` is the full, possibly multi-line,
        entry) and the index of the closing `]` line.

    Raises:
        `ValueError`: If the mapping or a multi-line entry is never closed, or an entry has no quoted identifier.
    """
    entry_end = " " * indent + ")"
    entries = []
    while True:
        if line_idx >= len(lines):
            raise ValueError(f"{fname}: reached the end of the file before the closing `]` of an auto mapping.")
        if lines[line_idx].strip() == "]":
            return entries, line_idx

        start_idx = line_idx
        # Entries either fit in one line or open with a lone `(` and close on a line starting with `)`.
        if lines[line_idx].strip() == "(":
            while line_idx < len(lines) and not lines[line_idx].startswith(entry_end):
                line_idx += 1
            if line_idx == len(lines):
                raise ValueError(f"{fname}, line {start_idx + 1}: multi-line auto mapping entry is never closed.")
        text = "\n".join(lines[start_idx : line_idx + 1])

        match = _re_identifier.search(text)
        if match is None:
            raise ValueError(
                f"{fname}, line {start_idx + 1}: could not find the quoted identifier to sort this auto mapping entry "
                f"on: {lines[start_idx].strip()!r}. Only single-line or `(`...`)` entries may appear after the first "
                "entry of a mapping."
            )
        entries.append((match.groups()[0], text))
        line_idx += 1
100
def sort_all_auto_mappings(overwrite: bool = False):
    """
    Sort all auto mappings in the library.

    Every `.py` file directly inside `PATH_TO_AUTO_MODULE` is processed with `sort_auto_mapping`, in alphabetical
    order of file name.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.

    Raises:
        `ValueError`: If `overwrite=False` and at least one file has auto mappings that are not properly sorted.
    """
    # `os.listdir` returns entries in arbitrary order; sorting makes the processing order and the list of failing
    # files in the error message deterministic.
    fnames = [
        os.path.join(PATH_TO_AUTO_MODULE, f) for f in sorted(os.listdir(PATH_TO_AUTO_MODULE)) if f.endswith(".py")
    ]
    diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]

    if not overwrite and any(diffs):
        failures = [f for f, d in zip(fnames, diffs) if d]
        raise ValueError(
            f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
            " this."
        )
118
if __name__ == "__main__":
    # Without flags the mappings are rewritten in place (`make style`); with `--check_only` nothing is written and
    # `sort_all_auto_mappings` raises if some file is unsorted (`make quality`).
    cli = argparse.ArgumentParser()
    cli.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    options = cli.parse_args()

    fix_in_place = not options.check_only
    sort_all_auto_mappings(fix_in_place)