transformers

Форк
0
/
check_dummies.py 
236 строк · 8.2 Кб
1
# coding=utf-8
2
# Copyright 2020 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""
16
This script makes sure the dummies in utils/dummy_xxx_objects.py are up to date with the main init.
17

18
Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't
19
have the necessary extra libs installed. Those objects will then raise a helpful error message whenever the user tries
20
to access one of their methods.
21

22
Usage (from the root of the repo):
23

24
Check that the dummy files are up to date (used in `make repo-consistency`):
25

26
```bash
27
python utils/check_dummies.py
28
```
29

30
Update the dummy files if needed (used in `make fix-copies`):
31

32
```bash
33
python utils/check_dummies.py --fix_and_overwrite
34
```
35
"""
36
import argparse
37
import os
38
import re
39
from typing import Dict, List, Optional
40

41

42
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"

# Captures the backend name `xxx` in `is_xxx_available`. NOTE: the trailing `()` of this pattern is an empty capture
# group (not escaped literal parentheses), so `findall` yields `(name, "")` tuples rather than plain strings.
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Matches if not is_xxx_available()
_re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)")


# Templates for the dummy objects: `{0}` is the object name, `{1}` the backend list rendered as source code.
DUMMY_CONSTANT = """
{0} = None
"""


DUMMY_CLASS = """
class {0}(metaclass=DummyObject):
    _backends = {1}

    def __init__(self, *args, **kwargs):
        requires_backends(self, {1})
"""


DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
    requires_backends({0}, {1})
"""
73

74

75
def find_backend(line: str) -> Optional[str]:
    """
    Find one (or multiple) backend in a code line of the init.

    Args:
        line (`str`): A code line in an init file.

    Returns:
        Optional[`str`]: `None` when the line is not an indented backend test of the form
        `if not is_xxx_available()`. Otherwise the backend name; when the line tests several backends (for instance
        `if not (is_xxx_available() and is_yyy_available())`), the names are sorted and joined on `_and_` (so
        `xxx_and_yyy` for instance).
    """
    if _re_test_backend.search(line) is None:
        return None
    # `_re_backend` defines two capture groups, so `findall` returns tuples; the backend name is the first item.
    backends = sorted(match[0] for match in _re_backend.findall(line))
    return "_and_".join(backends)
92

93

94
def read_init() -> Dict[str, List[str]]:
    """
    Read the init and extract backend-specific objects.

    Returns:
        Dict[str, List[str]]: A dictionary mapping backend name to the list of object names requiring that backend.

    Raises:
        ValueError: If the init contains no `if TYPE_CHECKING` line, or if a backend test is never followed by an
        `    else:` line.
    """
    init_path = os.path.join(PATH_TO_TRANSFORMERS, "__init__.py")
    with open(init_path, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    num_lines = len(lines)

    # Get to the point we do the actual imports for type checking
    line_index = 0
    while line_index < num_lines and not lines[line_index].startswith("if TYPE_CHECKING"):
        line_index += 1
    if line_index >= num_lines:
        # Fail with an explicit message instead of running past the end of the file.
        raise ValueError(f"Could not find an `if TYPE_CHECKING` block in {init_path}.")

    backend_specific_objects = {}
    # Go through the end of the file
    while line_index < num_lines:
        # If the line is an if is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        if backend is None:
            line_index += 1
            continue

        # The backend-specific imports live in the `else:` branch that follows the test.
        while line_index < num_lines and not lines[line_index].startswith("    else:"):
            line_index += 1
        if line_index >= num_lines:
            raise ValueError(f"Could not find the `else:` branch for backend `{backend}` in {init_path}.")
        line_index += 1

        objects = []
        # Until we unindent (or reach the end of the file), add backend objects to the list
        while line_index < num_lines and (len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8)):
            line = lines[line_index]
            single_line_import_search = _re_single_line_import.search(line)
            if single_line_import_search is not None:
                # Single-line imports
                objects.extend(single_line_import_search.groups()[0].split(", "))
            elif line.startswith(" " * 12):
                # Multiple-line imports (with 3 indent level): drop the indent and the trailing ",\n"
                objects.append(line[12:-2])
            line_index += 1

        backend_specific_objects[backend] = objects

    return backend_specific_objects
137

138

139
def create_dummy_object(name: str, backend_name: str) -> str:
    """
    Create the code for a dummy object.

    The kind of dummy is picked from the casing of `name`: all uppercase gives a constant, all lowercase gives a
    function, anything else gives a class.

    Args:
        name (`str`): The name of the object.
        backend_name (`str`): The name of the backend required for that object.

    Returns:
        `str`: The code of the dummy object.
    """
    if name.isupper():
        # Constants do not need the backend: the dummy is just `NAME = None`.
        return DUMMY_CONSTANT.format(name)
    if name.islower():
        return DUMMY_FUNCTION.format(name, backend_name)
    return DUMMY_CLASS.format(name, backend_name)
156

157

158
def create_dummy_files(backend_specific_objects: Optional[Dict[str, List[str]]] = None) -> Dict[str, str]:
    """
    Create the content of the dummy files.

    Args:
        backend_specific_objects (`Dict[str, List[str]]`, *optional*):
            The mapping backend name to list of backend-specific objects. If not passed, will be obtained by calling
            `read_init()`.

    Returns:
        `Dict[str, str]`: A dictionary mapping backend name to code of the corresponding backend file.
    """
    if backend_specific_objects is None:
        backend_specific_objects = read_init()

    dummy_files = {}

    for backend, objects in backend_specific_objects.items():
        # Render the backend(s) as a list literal in source form, e.g. `xxx_and_yyy` -> `["xxx", "yyy"]`.
        quoted_backends = ", ".join(f'"{b}"' for b in backend.split("_and_"))
        backend_name = "[" + quoted_backends + "]"

        header = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
        header += "from ..utils import DummyObject, requires_backends\n\n"
        body = "\n".join([create_dummy_object(o, backend_name) for o in objects])
        dummy_files[backend] = header + body

    return dummy_files
183

184

185
def check_dummies(overwrite: bool = False):
    """
    Check if the dummy files are up to date and maybe `overwrite` with the right content.

    Args:
        overwrite (`bool`, *optional*, default to `False`):
            Whether or not to overwrite the content of the dummy files. Will raise an error if they are not up to date
            when `overwrite=False`.

    Raises:
        ValueError: If a dummy file differs from its expected content and `overwrite=False`.
    """
    dummy_files = create_dummy_files()
    # For special correspondence backend name to shortcut as used in utils/dummy_xxx_objects.py
    short_names = {"torch": "pt"}

    # Locate actual dummy modules and read their content.
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
    dummy_file_paths = {
        backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
        for backend in dummy_files
    }

    actual_dummies = {}
    for backend, file_path in dummy_file_paths.items():
        if not os.path.isfile(file_path):
            # A missing file is treated as empty so it is reported (or created) below.
            actual_dummies[backend] = ""
            continue
        with open(file_path, "r", encoding="utf-8", newline="\n") as f:
            actual_dummies[backend] = f.read()

    # Compare actual with what they should be.
    for backend, expected_content in dummy_files.items():
        if expected_content == actual_dummies[backend]:
            continue

        if not overwrite:
            raise ValueError(
                "The main __init__ has objects that are not present in "
                f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
                "to fix this."
            )

        print(
            f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
            "__init__ has new objects."
        )
        with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
            f.write(expected_content)
229

230

231
if __name__ == "__main__":
    # Command-line entry point: check the dummy files, rewriting them when `--fix_and_overwrite` is passed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_dummies(args.fix_and_overwrite)
237

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.