transformers

Форк
0
/
custom_init_isort.py 
329 строк · 13.3 Кб
1
# coding=utf-8
2
# Copyright 2021 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""
16
Utility that sorts the imports in the custom inits of Transformers. Transformers uses init files that delay the
17
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
18
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
19
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
20
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff`
21
properly sort the second half which looks like traditional imports; the goal of this script is to sort the first half.
22

23
Use from the root of the repo with:
24

25
```bash
26
python utils/custom_init_isort.py
27
```
28

29
which will auto-sort the imports (used in `make style`).
30

31
For a check only (as used in `make quality`) run:
32

33
```bash
34
python utils/custom_init_isort.py --check_only
35
```
36
"""
37
import argparse
38
import os
39
import re
40
from typing import Any, Callable, List, Optional
41

42

43
# Relative path to the package sources: the script is meant to be launched from the root of the repo.
PATH_TO_TRANSFORMERS = "src/transformers"

# Captures the leading whitespace of a line that contains at least one non-whitespace character.
_re_indent = re.compile(r"^(\s*)\S")
# Matches a dictionary entry of the form `"key":` and captures `key` as the first group.
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Matches an access of the form `_import_structure["key"]` and captures `key` as the first group.
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Matches a line made only of `"key",` (one quoted name plus trailing comma) and captures `key` as the first group.
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Matches any `[stuff]` and captures `stuff` (everything between the brackets) as the first group.
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
56

57

58
def get_indent(line: str) -> str:
    """
    Return the leading whitespace of `line` as a string.

    An empty string is returned when `line` contains no non-whitespace character (empty or blank line), since
    `_re_indent` requires one to match.
    """
    match = _re_indent.search(line)
    if match is None:
        return ""
    return match.group(1)
62

63

64
def split_code_in_indented_blocks(
    code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None
) -> List[str]:
    """
    Split some code into its indented blocks, starting at a given level.

    Args:
        code (`str`): The code to split.
        indent_level (`str`): The indent level (as string) to use for identifying the blocks to split.
        start_prompt (`str`, *optional*): If provided, only starts splitting at the line where this text is.
        end_prompt (`str`, *optional*): If provided, stops splitting at a line where this text is.

    Warning:
        The text before `start_prompt` or after `end_prompt` (if provided) is not ignored, just not split. The input
        `code` can thus be retrieved by joining the result.

    Returns:
        `List[str]`: The list of blocks. If `start_prompt` is provided but no line starts with it, there is nothing
        to split and the result is `[code]`.
    """
    lines = code.split("\n")

    def _is_end(line: str) -> bool:
        # A line marks the end of the region to split only when an `end_prompt` was requested.
        return end_prompt is not None and line.startswith(end_prompt)

    # Move to the line where splitting starts; everything before it is kept as a single leading block.
    index = 0
    blocks = []
    if start_prompt is not None:
        while index < len(lines) and not lines[index].startswith(start_prompt):
            index += 1
        if index == len(lines):
            # The start marker is absent: return the code untouched instead of indexing past the last line.
            return [code]
        blocks.append("\n".join(lines[:index]))

    # This variable contains the block treated at a given time.
    current_block = [lines[index]]
    index += 1
    # We split into blocks until we get to the `end_prompt` (or the end of the file).
    while index < len(lines) and not _is_end(lines[index]):
        line = lines[index]
        # A non-empty line with exactly the requested indent delimits blocks.
        if line and get_indent(line) == indent_level:
            if current_block and get_indent(current_block[-1]).startswith(indent_level + " "):
                # The previous line was indented deeper, so this line closes the current block (like a closing
                # parenthesis) and belongs to it.
                current_block.append(line)
                blocks.append("\n".join(current_block))
                current_block = []
                following = index + 1
                # The line after a closing line opens the next block, unless it is the end marker (or the end of
                # the code), in which case the loop must stop there.
                if following < len(lines) and not _is_end(lines[following]):
                    current_block = [lines[following]]
                    index = following
            else:
                # The line is not part of the current block: store that block and start a new one here.
                blocks.append("\n".join(current_block))
                current_block = [line]
        else:
            # Empty line or different indent: the line just extends the current block.
            current_block.append(line)
        index += 1

    # Adds current block if it's nonempty.
    if current_block:
        blocks.append("\n".join(current_block))

    # Everything from the `end_prompt` line onward is kept as one final, unsplit block.
    if end_prompt is not None and index < len(lines):
        blocks.append("\n".join(lines[index:]))

    return blocks
129

130

131
def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]:
    """
    Build a sort key that compares names case-insensitively and without underscores.

    Args:
        key (`Callable[[Any], str]`): The key function to wrap; it must return a string.

    Returns:
        `Callable[[Any], str]`: A function returning `key(x)` lowercased, with every `_` removed.
    """

    def _normalized(item):
        raw = key(item)
        lowered = raw.lower()
        return lowered.replace("_", "")

    return _normalized
140

141

142
def sort_objects(objects: List[Any], key: Optional[Callable[[Any], str]] = None) -> List[Any]:
    """
    Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
    last).

    Args:
        objects (`List[Any]`):
            The list of objects to sort.
        key (`Callable[[Any], str]`, *optional*):
            A function taking an object as input and returning a string, used to sort them by alphabetical order.
            If not provided, will default to noop (so a `key` must be provided if the `objects` are not of type string).

    Returns:
        `List[Any]`: The sorted list with the same elements as in the inputs. Every input object appears exactly
        once in the result.
    """

    # If no key is provided, we use a noop.
    def _identity(x):
        return x

    if key is None:
        key = _identity

    # Each object goes to exactly one group. Testing the groups independently would put a name such as `_FOO` in two
    # of them (`"_FOO".isupper()` is True while its first character is not uppercase) and duplicate it in the output.
    constants, classes, functions = [], [], []
    for obj in objects:
        name = key(obj)
        if name.isupper():
            # Constants are all uppercase, they go first.
            constants.append(obj)
        elif name[:1].isupper():
            # Classes are not all uppercase but start with a capital, they go second.
            classes.append(obj)
        else:
            # Everything else (including an empty name, which the slice handles without raising) goes last.
            functions.append(obj)

    # Then we sort each group, ignoring case and underscores.
    group_key = ignore_underscore_and_lowercase(key)
    return sorted(constants, key=group_key) + sorted(classes, key=group_key) + sorted(functions, key=group_key)
174

175

176
def sort_objects_in_import(import_statement: str) -> str:
    """
    Sort the object names listed in a single import statement of the `_import_structure`.

    Three layouts are handled, chosen from the number of lines in the statement: one name per line (more than three
    lines), all names on one separate line (exactly three lines), and everything on one or two lines.

    Args:
        import_statement (`str`): The import statement in which to sort the imports.

    Returns:
        `str`: The same as the input, but with objects properly sorted.
    """

    def _sort_bracketed(match):
        # Rewrites one `[...]` group with its quoted names sorted.
        inner = match.group(1)
        # A single name has nothing to sort.
        if "," not in inner:
            return f"[{inner}]"
        names = [chunk.strip().replace('"', "") for chunk in inner.split(",")]
        # A trailing comma leaves an empty last element behind.
        if not names[-1]:
            names.pop()
        quoted = [f'"{name}"' for name in sort_objects(names)]
        return "[" + ", ".join(quoted) + "]"

    lines = import_statement.split("\n")
    line_count = len(lines)

    if line_count < 3:
        # The imports fit on one line: sort whatever is between brackets.
        return _re_bracket_content.sub(_sort_bracketed, import_statement)

    if line_count == 3:
        # The names all sit on one separate line:
        # key: [
        #     "object1", "object2", ...
        # ]
        middle = lines[1]
        if _re_bracket_content.search(middle) is None:
            names = [chunk.strip().replace('"', "") for chunk in middle.split(",")]
            # A trailing comma leaves an empty last element behind.
            if not names[-1]:
                names.pop()
            quoted = [f'"{name}"' for name in sort_objects(names)]
            lines[1] = get_indent(middle) + ", ".join(quoted)
        else:
            lines[1] = _re_bracket_content.sub(_sort_bracketed, middle)
        return "\n".join(lines)

    # One name per line:
    # key: [
    #     "object1",
    #     "object2",
    #     ...
    # ]
    # One or two lines are left untouched on each side, depending on whether `[` has a line of its own.
    skip = 2 if lines[1].strip() == "[" else 1
    name_lines = lines[skip:-skip]
    names = [_re_strip_line.search(line).groups()[0] for line in name_lines]
    # Sort (position, name) pairs by name so each original line can be picked back in its new order.
    ordered = sort_objects(list(enumerate(names)), key=lambda pair: pair[1])
    sorted_lines = [name_lines[position] for position, _ in ordered]
    return "\n".join(lines[:skip] + sorted_lines + lines[-skip:])
232

233

234
def sort_imports(file: str, check_only: bool = True):
    """
    Sort the imports defined in the `_import_structure` of a given init.

    Args:
        file (`str`): The path to the init to check/fix.
        check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.

    Returns:
        `True` when `check_only` is set and the file would be modified, `None` in every other case (including after
        the file has been rewritten).
    """
    with open(file, encoding="utf-8") as reader:
        code = reader.read()

    # A file that never mentions `_import_structure` is not a custom init.
    if "_import_structure" not in code:
        return

    # Top-level blocks between the start of the `_import_structure` dict and the `TYPE_CHECKING` half.
    main_blocks = split_code_in_indented_blocks(
        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
    )

    # Block 0 (everything until start_prompt) and the last block (everything after end_prompt) are left alone.
    for block_idx in range(1, len(main_blocks) - 1):
        block_lines = main_blocks[block_idx].split("\n")
        total = len(block_lines)

        # Look for the first line mentioning `_import_structure`; a dummy import met before it skips the block.
        start = 0
        while start < total and "_import_structure" not in block_lines[start]:
            start = total if "import dummy" in block_lines[start] else start + 1
        if start >= total:
            continue

        # The last line of the block is set aside and re-attached as is once the rest has been sorted.
        inner_code = "\n".join(block_lines[start:-1])
        # Split what remains into sub-blocks, using the indent of the block's second line as the level.
        inner_blocks = split_code_in_indented_blocks(inner_code, indent_level=get_indent(block_lines[1]))

        # Two kinds of import key: entries of the dict, or `_import_structure[key]` statements.
        key_pattern = _re_direct_key if "_import_structure = {" in block_lines[0] else _re_indirect_key
        # Sub-blocks with no key (empty lines, comments) get `None`.
        keys = []
        for inner in inner_blocks:
            found = key_pattern.search(inner)
            keys.append(None if found is None else found.groups()[0])

        # Positions of the keyed sub-blocks, in the order their keys sort.
        keyed_positions = [position for position, key in enumerate(keys) if key is not None]
        next_sorted = iter(sorted(keyed_positions, key=lambda position: keys[position]))

        # Keyless sub-blocks stay where they are; each keyed slot receives the next sub-block in sorted order, with
        # its own object names sorted as well.
        reordered_blocks = [
            inner if key is None else sort_objects_in_import(inner_blocks[next(next_sorted)])
            for inner, key in zip(inner_blocks, keys)
        ]

        # Put the main block back together with its untouched leading lines and last line.
        main_blocks[block_idx] = "\n".join(block_lines[:start] + reordered_blocks + [block_lines[-1]])

    new_code = "\n".join(main_blocks)
    if new_code == code:
        return
    if check_only:
        return True

    print(f"Overwriting {file}.")
    with open(file, "w", encoding="utf-8") as writer:
        writer.write(new_code)
305

306

307
def sort_imports_in_all_inits(check_only: bool = True):
    """
    Sort the imports defined in the `_import_structure` of all inits in the repo.

    Args:
        check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.

    Raises:
        ValueError: If `sort_imports` reports at least one init that would be overwritten.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" not in files:
            continue
        init_path = os.path.join(root, "__init__.py")
        # `sort_imports` returns a truthy value only when it is in check mode and the init would change.
        if sort_imports(init_path, check_only=check_only):
            # Accumulate every failing init (rebinding the list here would cap the reported count at one).
            failures.append(init_path)

    if failures:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
322

323

324
if __name__ == "__main__":
    # Command-line entry point: `--check_only` reports unsorted inits instead of rewriting them.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    options = cli.parse_args()

    sort_imports_in_all_inits(check_only=options.check_only)
330

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.