transformers
236 строк · 8.2 Кб
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is responsible for making sure the dummies in `src/transformers/utils/dummy_xxx_objects.py` are up to date
with the main init.

Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't
have the necessary extra libs installed. Those objects will then raise a helpful error message whenever the user tries
to access one of their methods.

Usage (from the root of the repo):

Check that the dummy files are up to date (used in `make repo-consistency`):

```bash
python utils/check_dummies.py
```

Update the dummy files if needed (used in `make fix-copies`):

```bash
python utils/check_dummies.py --fix_and_overwrite
```
"""
36import argparse37import os38import re39from typing import Dict, List, Optional40
41
# All paths are relative: run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"

# Matches an indented `if not is_xxx_available()` (or `if not (is_xxx_available()`) at the start of a line.
# `^` is not combined with re.MULTILINE, so it only anchors at the beginning of the searched string.
_re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)")

# Matches each `is_xxx_available` and captures `xxx`. The trailing `()` is an empty regex group, not literal
# parentheses, so `findall` returns `(xxx, "")` tuples rather than plain strings.
_re_backend = re.compile(r"is\_([a-z_]*)_available()")

# Matches a single-line `from xxx import a, b` and captures `a, b`; parenthesized imports are excluded.
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
53
# Templates for the dummy objects. `{0}` is replaced by the object name and `{1}` by the required backends rendered
# as a list literal such as `["torch"]`. The template bodies must stay indented: they are written verbatim into the
# generated dummy modules, which have to be valid Python.
DUMMY_CONSTANT = """
{0} = None
"""


DUMMY_CLASS = """
class {0}(metaclass=DummyObject):
    _backends = {1}

    def __init__(self, *args, **kwargs):
        requires_backends(self, {1})
"""


DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
    requires_backends({0}, {1})
"""
73
74
def find_backend(line: str) -> Optional[str]:
    """
    Extract the backend(s) tested by one line of the main init.

    Args:
        line (`str`): A single line of the init file.

    Returns:
        `Optional[str]`: `None` if the line does not start with an indented `if not is_xxx_available()` test.
        Otherwise, every backend named in an `is_xxx_available` call on the line, sorted alphabetically and joined
        with `_and_` (for instance `torch_and_vision`).
    """
    if _re_test_backend.search(line) is not None:
        # `_re_backend` ends with an empty group, so each `findall` hit is a `(name, "")` tuple.
        names = sorted(hit[0] for hit in _re_backend.findall(line))
        return "_and_".join(names)
    return None
93
def read_init() -> Dict[str, List[str]]:
    """
    Read the main init and extract backend-specific objects.

    Scanning starts at the first line beginning with `if TYPE_CHECKING`. Each line recognized by `find_backend` opens
    a backend section. Its objects are collected from the lines after the next `else:` that is indented by exactly 4
    spaces, for as long as those lines are empty or indented by at least 8 spaces. Objects come either from
    single-line `from xxx import a, b` statements or, for parenthesized imports, from lines indented by 12 spaces
    (the last two characters, presumably `,\n`, are stripped).

    Returns:
        `Dict[str, List[str]]`: A dictionary mapping backend name to the list of object names requiring that backend.

    Raises:
        `ValueError`: If the init has no `if TYPE_CHECKING` line, or if a backend test is not followed by a
        4-space-indented `else:` line.
    """
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    num_lines = len(lines)

    # Get to the point we do the actual imports for type checking.
    line_index = 0
    while line_index < num_lines and not lines[line_index].startswith("if TYPE_CHECKING"):
        line_index += 1
    if line_index == num_lines:
        raise ValueError("Could not find the `if TYPE_CHECKING` section in the main init.")

    backend_specific_objects = {}
    # Go through the end of the file.
    while line_index < num_lines:
        # If the line is an `if not is_backend_available()` test, we grab all objects associated.
        backend = find_backend(lines[line_index])
        if backend is None:
            line_index += 1
            continue

        # Skip ahead to the 4-space-indented `else:` below which the backend objects are imported. The prefix must
        # keep all four spaces: a single space would never match and the scan would run off the end of the file.
        while line_index < num_lines and not lines[line_index].startswith("    else:"):
            line_index += 1
        if line_index == num_lines:
            raise ValueError(f"Could not find the `else:` block listing the objects of the {backend} backend.")
        line_index += 1

        objects = []
        # Until we unindent, add backend objects to the list (empty lines do not end the section).
        while line_index < num_lines and (len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8)):
            line = lines[line_index]
            single_line_import_search = _re_single_line_import.search(line)
            if single_line_import_search is not None:
                # Single-line imports
                objects.extend(single_line_import_search.groups()[0].split(", "))
            elif line.startswith(" " * 12):
                # Multiple-line imports (with 3 indent level)
                objects.append(line[12:-2])
            line_index += 1

        backend_specific_objects[backend] = objects

    return backend_specific_objects
138
def create_dummy_object(name: str, backend_name: str) -> str:
    """
    Render the dummy code for one object of the init.

    The kind of dummy depends on the casing of `name`: an all-uppercase name becomes a constant set to `None`, an
    all-lowercase name becomes a function, and anything else becomes a class.

    Args:
        name (`str`): The name of the object.
        backend_name (`str`): The backend(s) required by the object, already rendered as a list literal
            (as built by `create_dummy_files`, e.g. `["torch"]`).

    Returns:
        `str`: The code of the dummy object.
    """
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    template = DUMMY_FUNCTION if name.islower() else DUMMY_CLASS
    return template.format(name, backend_name)
157
def create_dummy_files(backend_specific_objects: Optional[Dict[str, List[str]]] = None) -> Dict[str, str]:
    """
    Build the expected content of every dummy file.

    Args:
        backend_specific_objects (`Dict[str, List[str]]`, *optional*):
            Mapping from backend name (several backends joined by `_and_`) to the names of the objects requiring
            it. When omitted, it is obtained by calling `read_init()`.

    Returns:
        `Dict[str, str]`: A dictionary mapping backend name to the source code of the corresponding dummy file.
    """
    if backend_specific_objects is None:
        backend_specific_objects = read_init()

    header = (
        "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
        "from ..utils import DummyObject, requires_backends\n\n"
    )

    dummy_files = {}
    for backend, objects in backend_specific_objects.items():
        # "torch_and_vision" -> '["torch", "vision"]', the list literal handed to `requires_backends`.
        quoted = [f'"{part}"' for part in backend.split("_and_")]
        backend_name = f"[{', '.join(quoted)}]"
        bodies = [create_dummy_object(obj, backend_name) for obj in objects]
        dummy_files[backend] = header + "\n".join(bodies)
    return dummy_files
184
def check_dummies(overwrite: bool = False):
    """
    Compare the dummy files on disk with the ones generated from the main init, optionally rewriting them.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            If `True`, every dummy file whose content differs from the generated one (a missing file counts as
            empty) is rewritten. If `False`, a `ValueError` is raised for the first such file.
    """
    expected = create_dummy_files()
    # Backends whose dummy module name uses a shortcut, as in utils/dummy_xxx_objects.py.
    short_names = {"torch": "pt"}
    utils_dir = os.path.join(PATH_TO_TRANSFORMERS, "utils")

    def module_name(backend):
        return short_names.get(backend, backend)

    paths = {b: os.path.join(utils_dir, f"dummy_{module_name(b)}_objects.py") for b in expected}

    # Read every existing dummy module first; a missing one is treated as empty.
    current = {}
    for backend, file_path in paths.items():
        content = ""
        if os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8", newline="\n") as f:
                content = f.read()
        current[backend] = content

    for backend, new_content in expected.items():
        if new_content == current[backend]:
            continue
        if not overwrite:
            raise ValueError(
                "The main __init__ has objects that are not present in "
                f"transformers.utils.dummy_{module_name(backend)}_objects.py. Run `make fix-copies` "
                "to fix this."
            )
        print(
            f"Updating transformers.utils.dummy_{module_name(backend)}_objects.py as the main "
            "__init__ has new objects."
        )
        with open(paths[backend], "w", encoding="utf-8", newline="\n") as f:
            f.write(new_content)
230
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    check_dummies(overwrite=cli.parse_args().fix_and_overwrite)