15
"""Convert BERT checkpoint."""
23
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
26
# Configure the root logger so that INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
29
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a saved PyTorch state dict.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint location handed to
            ``load_tf_weights_in_bert``.
        bert_config_file: JSON file read by ``BertConfig.from_json_file`` to
            describe the model architecture.
        pytorch_dump_path: Destination passed to ``torch.save`` for the
            model's ``state_dict``.
    """
    # Instantiate the PyTorch model from the JSON configuration.
    bert_config = BertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {bert_config!s}")
    pretraining_model = BertForPreTraining(bert_config)

    # Load the TensorFlow checkpoint weights into the freshly built model.
    load_tf_weights_in_bert(pretraining_model, bert_config, tf_checkpoint_path)

    # Persist only the weights (state_dict), not the whole module object.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    weights = pretraining_model.state_dict()
    torch.save(weights, pytorch_dump_path)
43
def build_arg_parser():
    """Return the command-line parser for the conversion script.

    Registers three required string options: the TensorFlow checkpoint to
    read, the BERT config JSON describing the architecture, and the output
    path for the PyTorch model.
    """
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    # NOTE(review): the option name is taken from the `args.bert_config_file`
    # lookup in the entry point below; default/type/required mirror the two
    # sibling options -- confirm against the original script.
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    return parser


if __name__ == "__main__":
    # Parse the CLI options and run the conversion with them.
    args = build_arg_parser().parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)