CSS-LM

Форк
0
/
convert_bert_original_tf_checkpoint_to_pytorch.py 
61 строка · 2.1 Кб
1
# coding=utf-8
2
# Copyright 2018 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Convert BERT checkpoint."""
16

17

18
import argparse
19
import logging
20

21
import torch
22

23
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
24

25

26
logging.basicConfig(level=logging.INFO)
27

28

29
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
30
    # Initialise PyTorch model
31
    config = BertConfig.from_json_file(bert_config_file)
32
    print("Building PyTorch model from configuration: {}".format(str(config)))
33
    model = BertForPreTraining(config)
34

35
    # Load weights from tf checkpoint
36
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)
37

38
    # Save pytorch-model
39
    print("Save PyTorch model to {}".format(pytorch_dump_path))
40
    torch.save(model.state_dict(), pytorch_dump_path)
41

42

43
if __name__ == "__main__":
44
    parser = argparse.ArgumentParser()
45
    # Required parameters
46
    parser.add_argument(
47
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
48
    )
49
    parser.add_argument(
50
        "--bert_config_file",
51
        default=None,
52
        type=str,
53
        required=True,
54
        help="The config json file corresponding to the pre-trained BERT model. \n"
55
        "This specifies the model architecture.",
56
    )
57
    parser.add_argument(
58
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
59
    )
60
    args = parser.parse_args()
61
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
62

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.