transformers

Форк
0
/
update_metadata.py 
340 строк · 14.4 Кб
1
# coding=utf-8
2
# Copyright 2021 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""
16
Utility that updates the metadata of the Transformers library in the repository `huggingface/transformers-metadata`.
17

18
Usage for an update (as used by the GitHub action `update_metadata`):
19

20
```bash
21
python utils/update_metadata.py --token <token> --commit_sha <commit_sha>
22
```
23

24
Usage to check all pipelines are properly defined in the constant `PIPELINE_TAGS_AND_AUTO_MODELS` of this script, so
25
that new pipelines are properly added as metadata (as used in `make repo-consistency`):
26

27
```bash
28
python utils/update_metadata.py --check-only
29
```
30
"""
31
import argparse
32
import collections
33
import os
34
import re
35
import tempfile
36
from typing import Dict, List, Tuple
37

38
import pandas as pd
39
from datasets import Dataset
40
from huggingface_hub import hf_hub_download, upload_folder
41

42
from transformers.utils import direct_transformers_import
43

44

45
# Every path in this script is relative to the root of the repo: the script is meant to be launched from there with
# the command `python utils/update_metadata.py`.
TRANSFORMERS_PATH = "src/transformers"

# Import Transformers from TRANSFORMERS_PATH, so that the module inspected below is the one in the repo.
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
52

53

54
# Patterns recognizing TF/Flax/PT model class names. In each of them, the capturing group holds the text found between
# the framework prefix (if any) and the last `Model`/`Encoder`/`Decoder`/`ForConditionalGeneration` marker.
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# This pattern has no framework prefix, so it matches any TF or Flax model name too: it needs to be tried in an else
# branch after the two previous regexes.
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
59

60

61
# Fill this with tuples (pipeline_tag, model_mapping, auto_model):
# - pipeline_tag: the pipeline tag recorded in the metadata for the models of the mapping;
# - model_mapping: the name of the PyTorch mapping constant in the auto module (the TF and Flax constants are looked up
#   by prefixing this name with `TF_` and `FLAX_`);
# - auto_model: the name of the auto class the models of the mapping should be loaded with.
# Entries are processed in order, so when a model class belongs to several mappings the last entry listing it wins.
PIPELINE_TAGS_AND_AUTO_MODELS = [
    ("pretraining", "MODEL_FOR_PRETRAINING_MAPPING_NAMES", "AutoModelForPreTraining"),
    ("feature-extraction", "MODEL_MAPPING_NAMES", "AutoModel"),
    ("image-feature-extraction", "MODEL_FOR_IMAGE_MAPPING_NAMES", "AutoModel"),
    ("audio-classification", "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForAudioClassification"),
    ("text-generation", "MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", "AutoModelForCausalLM"),
    ("automatic-speech-recognition", "MODEL_FOR_CTC_MAPPING_NAMES", "AutoModelForCTC"),
    ("image-classification", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForImageClassification"),
    ("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
    ("image-to-image", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES", "AutoModelForImageToImage"),
    ("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
    ("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
    (
        "zero-shot-object-detection",
        "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES",
        "AutoModelForZeroShotObjectDetection",
    ),
    ("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"),
    ("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
    ("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),
    ("automatic-speech-recognition", "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "AutoModelForSpeechSeq2Seq"),
    (
        "table-question-answering",
        "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES",
        "AutoModelForTableQuestionAnswering",
    ),
    ("token-classification", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES", "AutoModelForTokenClassification"),
    ("multiple-choice", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES", "AutoModelForMultipleChoice"),
    (
        "next-sentence-prediction",
        "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES",
        "AutoModelForNextSentencePrediction",
    ),
    (
        "audio-frame-classification",
        "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES",
        "AutoModelForAudioFrameClassification",
    ),
    ("audio-xvector", "MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "AutoModelForAudioXVector"),
    (
        "document-question-answering",
        "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES",
        "AutoModelForDocumentQuestionAnswering",
    ),
    (
        "visual-question-answering",
        "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES",
        "AutoModelForVisualQuestionAnswering",
    ),
    # The mapping name must match the constant of the auto module exactly: a name that does not exist there is
    # skipped without any error when the table is built.
    ("image-to-text", "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
    (
        "zero-shot-image-classification",
        "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
        "AutoModelForZeroShotImageClassification",
    ),
    ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
    ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
    ("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
    ("text-to-audio", "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES", "AutoModelForTextToSpectrogram"),
    ("text-to-audio", "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "AutoModelForTextToWaveform"),
]
123

124

125
def camel_case_split(identifier: str) -> List[str]:
    """
    Split a camel-cased name into words.

    Args:
        identifier (`str`): The camel-cased name to parse.

    Returns:
        `List[str]`: The list of words in the identifier (as separated by capital letters). A run of capital letters
        is kept together up to the capital that starts the next lower-cased word, and an empty identifier gives an
        empty list.

    Example:

    ```py
    >>> camel_case_split("CamelCasedClass")
    ["Camel", "Cased", "Class"]
    ```
    """
    # Regex thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
    # A word stops at a lower-case/upper-case boundary, before the last capital of a run of capitals followed by a
    # lower-case letter (e.g. "TF|Bert"), or at the end of the string.
    matches = re.finditer(r".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
    return [m.group(0) for m in matches]
145

146

147
def get_frameworks_table() -> pd.DataFrame:
    """
    Generates a dataframe containing the supported auto classes for each model type, using the content of the auto
    modules.

    Returns:
        `pd.DataFrame`: One row per model type (sorted by name) with the columns `model_type`, `pytorch`,
        `tensorflow` and `flax` (booleans flagging whether an object of that framework was found for the model type)
        and `processor` (the name of the auto class to use for preprocessing).
    """
    # Dictionary mapping the prefix of the model names (config class name without "Config") to the model type.
    config_mapping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
    model_prefix_to_model_type = {
        config.replace("Config", ""): model_type for model_type, config in config_mapping_names.items()
    }

    # Dictionaries flagging if each model type has a backend in PT/TF/Flax.
    pt_models = collections.defaultdict(bool)
    tf_models = collections.defaultdict(bool)
    flax_models = collections.defaultdict(bool)
    # The order matters: the PT regex has no prefix and matches TF/Flax names too, so it has to be tried last.
    backends = [(_re_tf_models, tf_models), (_re_flax_models, flax_models), (_re_pt_models, pt_models)]

    # Let's lookup through all transformers object (once) and find if models are supported by a given backend.
    for attr_name in dir(transformers_module):
        for pattern, lookup_dict in backends:
            match = pattern.match(attr_name)
            if match is not None:
                break
        else:
            # Not a model name for any framework.
            continue

        prefix = match.group(1)
        while prefix:
            if prefix in model_prefix_to_model_type:
                lookup_dict[model_prefix_to_model_type[prefix]] = True
                break
            # Try again after removing the last word in the name
            prefix = "".join(camel_case_split(prefix)[:-1])

    all_models = sorted(set(pt_models) | set(tf_models) | set(flax_models))

    data = {"model_type": all_models}
    data["pytorch"] = [pt_models[t] for t in all_models]
    data["tensorflow"] = [tf_models[t] for t in all_models]
    data["flax"] = [flax_models[t] for t in all_models]

    # Now let's find the right processing class for each model. In order we check if there is a Processor, then a
    # Tokenizer, then an ImageProcessor, then a FeatureExtractor.
    auto = transformers_module.models.auto
    processors = {}
    for t in all_models:
        if t in auto.processing_auto.PROCESSOR_MAPPING_NAMES:
            processors[t] = "AutoProcessor"
        elif t in auto.tokenization_auto.TOKENIZER_MAPPING_NAMES:
            processors[t] = "AutoTokenizer"
        elif t in auto.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES:
            processors[t] = "AutoImageProcessor"
        elif t in auto.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES:
            processors[t] = "AutoFeatureExtractor"
        else:
            # Default to AutoTokenizer if a model has nothing, for backward compatibility.
            processors[t] = "AutoTokenizer"

    data["processor"] = [processors[t] for t in all_models]

    return pd.DataFrame(data)
212

213

214
def update_pipeline_and_auto_class_table(table: Dict[str, Tuple[str, str]]) -> Dict[str, Tuple[str, str]]:
    """
    Update the table mapping models to pipelines and auto classes without removing old keys if they don't exist
    anymore.

    Args:
        table (`Dict[str, Tuple[str, str]]`):
            The existing table mapping model names to a tuple containing the pipeline tag and the auto-class name with
            which they should be used. It is modified in place.

    Returns:
        `Dict[str, Tuple[str, str]]`: The updated table in the same format (the same object as `table`).
    """
    auto_modules = [
        transformers_module.models.auto.modeling_auto,
        transformers_module.models.auto.modeling_tf_auto,
        transformers_module.models.auto.modeling_flax_auto,
    ]
    for pipeline_tag, model_mapping, auto_class in PIPELINE_TAGS_AND_AUTO_MODELS:
        model_mappings = [model_mapping, f"TF_{model_mapping}", f"FLAX_{model_mapping}"]
        # NOTE(review): these strings (with an underscore after the framework prefix) are stored as is in the
        # metadata; presumably its consumers rely on this exact format -- confirm before changing it.
        auto_classes = [auto_class, f"TF_{auto_class}", f"Flax_{auto_class}"]
        # Loop through all three frameworks
        for module, cls, mapping in zip(auto_modules, auto_classes, model_mappings):
            # The type of pipeline may not exist in this framework
            if not hasattr(module, mapping):
                continue
            for name in getattr(module, mapping).values():
                # A mapping value is either a single model name or a collection of model names.
                model_names = [name] if isinstance(name, str) else list(name)
                # Add pipeline tag and auto model class for those models
                for model_name in model_names:
                    table[model_name] = (pipeline_tag, cls)

    return table
251

252

253
def update_metadata(token: str, commit_sha: str):
    """
    Update the metadata for the Transformers repo in `huggingface/transformers-metadata`.

    The frameworks table is rebuilt from scratch, while the pipeline tags table is downloaded, updated with the
    current content of the auto modules and uploaded again together with the frameworks table.

    Args:
        token (`str`): A valid token giving write access to `huggingface/transformers-metadata`.
        commit_sha (`str`):
            The commit SHA on Transformers corresponding to this update. When it is `None`, a generic commit message
            is used.
    """
    frameworks_table = get_frameworks_table()
    frameworks_dataset = Dataset.from_pandas(frameworks_table)

    resolved_tags_file = hf_hub_download(
        "huggingface/transformers-metadata", "pipeline_tags.json", repo_type="dataset", token=token
    )
    tags_dataset = Dataset.from_json(resolved_tags_file)
    # Each row of the dataset is a dictionary with the keys `model_class`, `pipeline_tag` and `auto_class`.
    table = {row["model_class"]: (row["pipeline_tag"], row["auto_class"]) for row in tags_dataset}
    table = update_pipeline_and_auto_class_table(table)

    # Sort the model classes to avoid some nondeterministic updates to create false update commits.
    model_classes = sorted(table.keys())
    tags_table = pd.DataFrame(
        {
            "model_class": model_classes,
            "pipeline_tag": [table[m][0] for m in model_classes],
            "auto_class": [table[m][1] for m in model_classes],
        }
    )
    tags_dataset = Dataset.from_pandas(tags_table)

    # Both files are written in a temporary folder which is uploaded as a whole, then discarded.
    with tempfile.TemporaryDirectory() as tmp_dir:
        frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
        tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))

        if commit_sha is not None:
            commit_message = (
                f"Update with commit {commit_sha}\n\nSee: "
                f"https://github.com/huggingface/transformers/commit/{commit_sha}"
            )
        else:
            commit_message = "Update"

        upload_folder(
            repo_id="huggingface/transformers-metadata",
            folder_path=tmp_dir,
            repo_type="dataset",
            token=token,
            commit_message=commit_message,
        )
304

305

306
def check_pipeline_tags():
    """
    Check all pipeline tags are properly defined in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant of this script.

    A pipeline task is considered covered when its name is one of the pipeline tags of the constant, or when its
    (first) PyTorch model class is one of the auto classes of the constant.

    Raises:
        `ValueError`: If some pipeline tasks are covered neither by their name nor by their model class.
    """
    # Tags and auto classes are collected separately: several entries of the constant share the same tag, so a
    # dictionary keyed by tag would lose the auto classes of all but the last of those entries.
    known_tags = {tag for tag, _, _ in PIPELINE_TAGS_AND_AUTO_MODELS}
    known_auto_classes = {cls for _, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS}
    pipeline_tasks = transformers_module.pipelines.SUPPORTED_TASKS
    missing = []
    for key in pipeline_tasks:
        if key in known_tags:
            continue
        model = pipeline_tasks[key]["pt"]
        if isinstance(model, (list, tuple)):
            # NOTE(review): an empty list/tuple would raise an `IndexError` here; presumably "pt" always holds at
            # least one class when this check is run -- confirm.
            model = model[0]
        if model.__name__ not in known_auto_classes:
            missing.append(key)

    if len(missing) > 0:
        msg = ", ".join(missing)
        raise ValueError(
            "The following pipeline tags are not present in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant inside "
            f"`utils/update_metadata.py`: {msg}. Please add them!"
        )
328

329

330
if __name__ == "__main__":
    # Command-line entry point: either check the `PIPELINE_TAGS_AND_AUTO_MODELS` constant (`--check-only`) or push
    # updated metadata using the given token and commit SHA.
    parser = argparse.ArgumentParser()
    parser.add_argument("--token", type=str, help="The token to use to push to the transformers-metadata dataset.")
    parser.add_argument("--commit_sha", type=str, help="The sha of the commit going with this update.")
    parser.add_argument("--check-only", action="store_true", help="Activate to just check all pipelines are present.")
    args = parser.parse_args()

    if args.check_only:
        check_pipeline_tags()
    else:
        # `token` and `commit_sha` are `None` when the corresponding options are not passed.
        update_metadata(args.token, args.commit_sha)
341

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.