CSS-LM

convert_pytorch_checkpoint_to_tf2.py 
374 lines · 14.1 KB
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert PyTorch checkpoints to TensorFlow 2.0 models."""


import argparse
import logging
import os

from transformers import (
    ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
    TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    WEIGHTS_NAME,
    XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    AlbertConfig,
    BertConfig,
    CamembertConfig,
    CTRLConfig,
    DistilBertConfig,
    ElectraConfig,
    FlaubertConfig,
    GPT2Config,
    OpenAIGPTConfig,
    RobertaConfig,
    T5Config,
    TFAlbertForPreTraining,
    TFBertForPreTraining,
    TFBertForQuestionAnswering,
    TFBertForSequenceClassification,
    TFCamembertForMaskedLM,
    TFCTRLLMHeadModel,
    TFDistilBertForMaskedLM,
    TFDistilBertForQuestionAnswering,
    TFElectraForPreTraining,
    TFFlaubertWithLMHeadModel,
    TFGPT2LMHeadModel,
    TFOpenAIGPTLMHeadModel,
    TFRobertaForMaskedLM,
    TFRobertaForSequenceClassification,
    TFT5ForConditionalGeneration,
    TFTransfoXLLMHeadModel,
    TFXLMRobertaForMaskedLM,
    TFXLMWithLMHeadModel,
    TFXLNetLMHeadModel,
    TransfoXLConfig,
    XLMConfig,
    XLMRobertaConfig,
    XLNetConfig,
    cached_path,
    is_torch_available,
    load_pytorch_checkpoint_in_tf2_model,
)
from transformers.file_utils import hf_bucket_url


if is_torch_available():
    import torch
    import numpy as np
    from transformers import (
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        GPT2LMHeadModel,
        XLNetLMHeadModel,
        XLMWithLMHeadModel,
        XLMRobertaForMaskedLM,
        TransfoXLLMHeadModel,
        OpenAIGPTLMHeadModel,
        RobertaForMaskedLM,
        RobertaForSequenceClassification,
        CamembertForMaskedLM,
        FlaubertWithLMHeadModel,
        DistilBertForMaskedLM,
        DistilBertForQuestionAnswering,
        CTRLLMHeadModel,
        AlbertForPreTraining,
        T5ForConditionalGeneration,
        ElectraForPreTraining,
    )


logging.basicConfig(level=logging.INFO)

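# Each supported model type (or fine-tuned shortcut name) maps to a 4-tuple:
# (config class, TF 2.0 model class, PyTorch model class,
#  archive map of shortcut names -> pretrained config URLs).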
MODEL_CLASSES = {
    "bert": (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),
    "bert-large-uncased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-large-cased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-base-cased-finetuned-mrpc": (
        BertConfig,
        TFBertForSequenceClassification,
        BertForSequenceClassification,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "gpt2": (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,),
    "xlnet": (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,),
    "xlm": (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,),
    "xlm-roberta": (
        XLMRobertaConfig,
        TFXLMRobertaForMaskedLM,
        XLMRobertaForMaskedLM,
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "transfo-xl": (
        TransfoXLConfig,
        TFTransfoXLLMHeadModel,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "openai-gpt": (
        OpenAIGPTConfig,
        TFOpenAIGPTLMHeadModel,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "roberta": (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,),
    "roberta-large-mnli": (
        RobertaConfig,
        TFRobertaForSequenceClassification,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "camembert": (
        CamembertConfig,
        TFCamembertForMaskedLM,
        CamembertForMaskedLM,
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "flaubert": (
        FlaubertConfig,
        TFFlaubertWithLMHeadModel,
        FlaubertWithLMHeadModel,
        FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert": (
        DistilBertConfig,
        TFDistilBertForMaskedLM,
        DistilBertForMaskedLM,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert-base-distilled-squad": (
        DistilBertConfig,
        TFDistilBertForQuestionAnswering,
        DistilBertForQuestionAnswering,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "ctrl": (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,),
    "albert": (AlbertConfig, TFAlbertForPreTraining, AlbertForPreTraining, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),
    "t5": (T5Config, TFT5ForConditionalGeneration, T5ForConditionalGeneration, T5_PRETRAINED_CONFIG_ARCHIVE_MAP,),
    "electra": (ElectraConfig, TFElectraForPreTraining, ElectraForPreTraining, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,),
}


def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
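    """Convert a single PyTorch checkpoint to a TF 2.0 model and save its
    weights as an HDF5 file at `tf_dump_path`.

    `model_type` must be a key of MODEL_CLASSES; `pytorch_checkpoint_path` and
    `config_file` may be local paths or shortcut names, which are resolved and
    downloaded via `cached_path`.
    """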
    if model_type not in MODEL_CLASSES:
        raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

    config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

    # Initialise the TF model
    if config_file in aws_config_map:
        config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
    config = config_class.from_json_file(config_file)
    config.output_hidden_states = True
    config.output_attentions = True
    print("Building TensorFlow model from configuration: {}".format(str(config)))
    tf_model = model_class(config)

    # Resolve a shortcut name to a downloadable PyTorch checkpoint URL
    if pytorch_checkpoint_path in aws_config_map:
        pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
        pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
    # Load the PyTorch checkpoint into the TF 2.0 model
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

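    # Optional sanity check: run the TF and PyTorch models on their dummy inputs
    # and compare the first output tensors element-wise (tolerance 2e-2).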
    if compare_with_pt_model:
        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

        np_pt = pto[0].numpy()
        np_tf = tfo[0].numpy()
        diff = np.amax(np.abs(np_pt - np_tf))
        print("Max absolute difference between model outputs: {}".format(diff))
        assert diff <= 2e-2, "Error, model absolute difference is >2e-2: {}".format(diff)

    # Save the TensorFlow model
    print("Save TensorFlow model to {}".format(tf_dump_path))
    tf_model.save_weights(tf_dump_path, save_format="h5")


def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
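    """Convert a batch of PyTorch checkpoints to TF 2.0 weight files.

    Converts every known shortcut name for the given model type (or for all
    model types when `args_model_type` is None), writing one
    `<name>-tf_model.h5` file per checkpoint under `tf_dump_path`.
    """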

    if args_model_type is None:
        model_types = list(MODEL_CLASSES.keys())
    else:
        model_types = [args_model_type]

    for j, model_type in enumerate(model_types, start=1):
        print("=" * 100)
        print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(
                "Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
            )

        # MODEL_CLASSES values are 4-tuples; the original 5-name unpacking here
        # (with a separate aws_model_maps) would raise a ValueError at runtime.
        config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

        if model_shortcut_names_or_path is None:
            model_shortcut_names_or_path = list(aws_config_map.keys())
        if config_shortcut_names_or_path is None:
            config_shortcut_names_or_path = model_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
            zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
        ):
            print("-" * 100)
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print("    Skipping finetuned checkpoint {}".format(model_shortcut_name))
                    continue
                model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                print("    Skipping non-finetuned checkpoint {}".format(model_shortcut_name))
                continue
            print(
                "    Converting checkpoint {}/{}: {} - model_type {}".format(
                    i, len(model_shortcut_names_or_path), model_shortcut_name, model_type
                )
            )
            print("-" * 100)

            if config_shortcut_name in aws_config_map:
                config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
            else:
                config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)

            # Resolve model shortcut names the same way convert_pt_checkpoint_to_tf
            # does, via hf_bucket_url (there is no separate model archive map).
            if model_shortcut_name in aws_config_map:
                model_file = cached_path(
                    hf_bucket_url(model_shortcut_name, filename=WEIGHTS_NAME), force_download=not use_cached_models
                )
            else:
                model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)

            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            convert_pt_checkpoint_to_tf(
                model_type=model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                os.remove(config_file)
                os.remove(model_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output TensorFlow dump file."
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
            list(MODEL_CLASSES.keys())
        ),
    )
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=None,
        type=str,
        help="Path to the PyTorch checkpoint or shortcut name to download from AWS. "
        "If not given, will download and convert all the checkpoints from AWS.",
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        help="The config json file corresponding to the pre-trained model. \n"
        "This specifies the model architecture. If not given and "
        "--pytorch_checkpoint_path is not given or is a shortcut name, "
        "use the configuration associated with the shortcut name on AWS.",
    )
    parser.add_argument(
        "--compare_with_pt_model", action="store_true", help="Compare TensorFlow and PyTorch model predictions."
    )
    parser.add_argument(
        "--use_cached_models",
        action="store_true",
        help="Use cached models if possible instead of updating to latest checkpoint versions.",
    )
    parser.add_argument(
        "--remove_cached_files",
        action="store_true",
        help="Remove downloaded PyTorch files after conversion (saves disk space when converting in batches).",
    )
    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
    args = parser.parse_args()

    # if args.pytorch_checkpoint_path is not None:
    #     convert_pt_checkpoint_to_tf(args.model_type.lower(),
    #                                 args.pytorch_checkpoint_path,
    #                                 args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
    #                                 args.tf_dump_path,
    #                                 compare_with_pt_model=args.compare_with_pt_model,
    #                                 use_cached_models=args.use_cached_models)
    # else:
    convert_all_pt_checkpoints_to_tf(
        args.model_type.lower() if args.model_type is not None else None,
        args.tf_dump_path,
        model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
        if args.pytorch_checkpoint_path is not None
        else None,
        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
        compare_with_pt_model=args.compare_with_pt_model,
        use_cached_models=args.use_cached_models,
        remove_cached_files=args.remove_cached_files,
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )
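
# Example invocation (a sketch, not part of the original file; assumes the
# shortcut names and download URLs for this transformers version are still live):
#
#   python convert_pytorch_checkpoint_to_tf2.py \
#       --model_type bert \
#       --pytorch_checkpoint_path bert-base-uncased \
#       --config_file bert-base-uncased \
#       --tf_dump_path ./tf_models \
#       --compare_with_pt_model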