transformers

Форк
0
/
sort_auto_mappings.py 
124 строки · 4.4 Кб
1
# coding=utf-8
2
# Copyright 2022 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""
16
Utility that sorts the names in the auto mappings defined in the auto modules in alphabetical order.
17

18
Use from the root of the repo with:
19

20
```bash
21
python utils/sort_auto_mappings.py
22
```
23

24
to auto-fix all the auto mappings (used in `make style`).
25

26
To only check if the mappings are properly sorted (as used in `make quality`), do:
27

28
```bash
29
python utils/sort_auto_mappings.py --check_only
30
```
31
"""
32
import argparse
33
import os
34
import re
35
from typing import Optional
36

37

38
# Paths are set with the intent that you run this script from the root of the repo.
39
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"


# Regex matching the line that opens a mapping definition, i.e. either of:
#    SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict
_re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
# Regex capturing the first double-quoted string that follows an opening parenthesis (the key entries are sorted on)
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
47

48

49
def sort_auto_mapping(fname: str, overwrite: bool = False) -> Optional[bool]:
    """
    Sort all auto mappings in a file.

    A mapping starts on a line matched by `_re_intro_mapping`. Its entries are expected to sit 8 spaces deeper than
    that line and to end at a line containing only `]`. Each entry (a one-line tuple, or a multi-line tuple opened by
    a lone `(`) is reordered alphabetically by the first quoted string it contains. Every other line of the file is
    left untouched.

    Args:
        fname (`str`): The name of the file where we want to sort auto-mappings.
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.

    Returns:
        `Optional[bool]`: Returns `None` if `overwrite=True`. Otherwise returns `True` if the file has an auto-mapping
        improperly sorted, `False` if the file is okay.

    Raises:
        `ValueError`: If a line inside a mapping has no quoted identifier to sort on (for instance a comment placed
        between two entries).
    """

    def _identifier(block: str) -> str:
        # Sorting key of an entry; fail with a readable message instead of an AttributeError on `None`.
        match = _re_identifier.search(block)
        if match is None:
            raise ValueError(f"Could not find a quoted identifier to sort on in a mapping of {fname}:\n{block}")
        return match.groups()[0]

    with open(fname, "r", encoding="utf-8") as f:
        content = f.read()

    lines = content.split("\n")
    new_lines = []
    line_idx = 0
    while line_idx < len(lines):
        if _re_intro_mapping.search(lines[line_idx]) is None:
            # Not the start of a mapping: copy the line as is.
            new_lines.append(lines[line_idx])
            line_idx += 1
            continue

        # Start of a new mapping! Entries are indented 8 spaces deeper than the introduction line.
        indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
        # Everything before the first entry (the intro line, the opening bracket, leading comments) is kept in place.
        while not lines[line_idx].startswith(" " * indent + "("):
            new_lines.append(lines[line_idx])
            line_idx += 1

        blocks = []
        while lines[line_idx].strip() != "]":
            # Blocks either fit in one line or not
            if lines[line_idx].strip() == "(":
                # Multi-line entry: gather everything up to the closing parenthesis at the entry indentation.
                start_idx = line_idx
                while not lines[line_idx].startswith(" " * indent + ")"):
                    line_idx += 1
                blocks.append("\n".join(lines[start_idx : line_idx + 1]))
            else:
                blocks.append(lines[line_idx])
            line_idx += 1

        # Sort blocks by their identifiers. The closing "]" line is copied by the next iteration of the outer loop.
        new_lines += sorted(blocks, key=_identifier)

    new_content = "\n".join(new_lines)
    if not overwrite:
        return new_content != content

    # Only touch the file when sorting actually changed something.
    if new_content != content:
        with open(fname, "w", encoding="utf-8") as f:
            f.write(new_content)
    return None
99

100

101
def sort_all_auto_mappings(overwrite: bool = False):
    """
    Sort all auto mappings in the library.

    Every `.py` file directly inside `PATH_TO_AUTO_MODULE` is processed with `sort_auto_mapping`.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.

    Raises:
        `ValueError`: If `overwrite=False` and at least one file has an auto mapping that is not sorted.
    """
    # `os.listdir` returns entries in arbitrary order: sort them so that the processing order and the failure
    # message are the same from one run to the next.
    fnames = [
        os.path.join(PATH_TO_AUTO_MODULE, f) for f in sorted(os.listdir(PATH_TO_AUTO_MODULE)) if f.endswith(".py")
    ]
    diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]

    if not overwrite and any(diffs):
        failures = [f for f, d in zip(fnames, diffs) if d]
        raise ValueError(
            f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
            " this."
        )
117

118

119
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    cli_args = arg_parser.parse_args()

    # Files are rewritten in place unless `--check_only` was passed.
    fix_files = not cli_args.check_only
    sort_all_auto_mappings(fix_files)
125

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.