# CSS-LM
# (source listing metadata: 374 lines / 14.1 KB)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert pytorch checkpoints to TensorFlow """
16
17
18import argparse19import logging20import os21
22from transformers import (23ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,24BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,25CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,26CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,27DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,28ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,29FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,30GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,31OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,32ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,33T5_PRETRAINED_CONFIG_ARCHIVE_MAP,34TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,35WEIGHTS_NAME,36XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,37XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,38XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,39AlbertConfig,40BertConfig,41CamembertConfig,42CTRLConfig,43DistilBertConfig,44ElectraConfig,45FlaubertConfig,46GPT2Config,47OpenAIGPTConfig,48RobertaConfig,49T5Config,50TFAlbertForPreTraining,51TFBertForPreTraining,52TFBertForQuestionAnswering,53TFBertForSequenceClassification,54TFCamembertForMaskedLM,55TFCTRLLMHeadModel,56TFDistilBertForMaskedLM,57TFDistilBertForQuestionAnswering,58TFElectraForPreTraining,59TFFlaubertWithLMHeadModel,60TFGPT2LMHeadModel,61TFOpenAIGPTLMHeadModel,62TFRobertaForMaskedLM,63TFRobertaForSequenceClassification,64TFT5ForConditionalGeneration,65TFTransfoXLLMHeadModel,66TFXLMRobertaForMaskedLM,67TFXLMWithLMHeadModel,68TFXLNetLMHeadModel,69TransfoXLConfig,70XLMConfig,71XLMRobertaConfig,72XLNetConfig,73cached_path,74is_torch_available,75load_pytorch_checkpoint_in_tf2_model,76)
77from transformers.file_utils import hf_bucket_url78
79
80if is_torch_available():81import torch82import numpy as np83from transformers import (84BertForPreTraining,85BertForQuestionAnswering,86BertForSequenceClassification,87GPT2LMHeadModel,88XLNetLMHeadModel,89XLMWithLMHeadModel,90XLMRobertaForMaskedLM,91TransfoXLLMHeadModel,92OpenAIGPTLMHeadModel,93RobertaForMaskedLM,94RobertaForSequenceClassification,95CamembertForMaskedLM,96FlaubertWithLMHeadModel,97DistilBertForMaskedLM,98DistilBertForQuestionAnswering,99CTRLLMHeadModel,100AlbertForPreTraining,101T5ForConditionalGeneration,102ElectraForPreTraining,103)104
105
logging.basicConfig(level=logging.INFO)

# Mapping: model-type shortcut -> (config class, TF 2.0 model class, PyTorch model
# class, pretrained-config archive map). NOTE(review): the PyTorch classes are only
# bound when torch is available (see the guarded import above), so importing this
# module without torch installed would raise NameError here — confirm intended.
MODEL_CLASSES = {
    "bert": (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
    "bert-large-uncased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-large-cased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "bert-base-cased-finetuned-mrpc": (
        BertConfig,
        TFBertForSequenceClassification,
        BertForSequenceClassification,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "gpt2": (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
    "xlnet": (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
    "xlm": (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
    "xlm-roberta": (
        XLMRobertaConfig,
        TFXLMRobertaForMaskedLM,
        XLMRobertaForMaskedLM,
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "transfo-xl": (
        TransfoXLConfig,
        TFTransfoXLLMHeadModel,
        TransfoXLLMHeadModel,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "openai-gpt": (
        OpenAIGPTConfig,
        TFOpenAIGPTLMHeadModel,
        OpenAIGPTLMHeadModel,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "roberta": (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
    "roberta-large-mnli": (
        RobertaConfig,
        TFRobertaForSequenceClassification,
        RobertaForSequenceClassification,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "camembert": (
        CamembertConfig,
        TFCamembertForMaskedLM,
        CamembertForMaskedLM,
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "flaubert": (
        FlaubertConfig,
        TFFlaubertWithLMHeadModel,
        FlaubertWithLMHeadModel,
        FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert": (
        DistilBertConfig,
        TFDistilBertForMaskedLM,
        DistilBertForMaskedLM,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "distilbert-base-distilled-squad": (
        DistilBertConfig,
        TFDistilBertForQuestionAnswering,
        DistilBertForQuestionAnswering,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "ctrl": (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
    "albert": (AlbertConfig, TFAlbertForPreTraining, AlbertForPreTraining, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
    "t5": (T5Config, TFT5ForConditionalGeneration, T5ForConditionalGeneration, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
    "electra": (ElectraConfig, TFElectraForPreTraining, ElectraForPreTraining, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP),
}
def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    """Convert a single PyTorch checkpoint to a TF 2.0 ``.h5`` weights file.

    Args:
        model_type: key of ``MODEL_CLASSES`` selecting the architecture.
        pytorch_checkpoint_path: local path to ``pytorch_model.bin`` or a shortcut
            name (resolved to a hub URL via ``hf_bucket_url``).
        config_file: local path to the config JSON or a shortcut name present in
            the architecture's config archive map.
        tf_dump_path: output path for the converted TF weights (HDF5).
        compare_with_pt_model: if True, run both models on dummy inputs and check
            that the max absolute output difference is <= 2e-2.
        use_cached_models: reuse previously downloaded files instead of forcing a
            re-download.

    Raises:
        ValueError: if ``model_type`` is unknown, or if the PT/TF outputs diverge
            by more than 2e-2 when ``compare_with_pt_model`` is set.
    """
    if model_type not in MODEL_CLASSES:
        raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

    config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

    # Initialise TF model
    if config_file in aws_config_map:
        config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
    config = config_class.from_json_file(config_file)
    # Expose hidden states / attentions so the comparison below sees full outputs.
    config.output_hidden_states = True
    config.output_attentions = True
    print("Building TensorFlow model from configuration: {}".format(str(config)))
    tf_model = model_class(config)

    # Resolve shortcut names to a downloadable pytorch_model.bin URL.
    if pytorch_checkpoint_path in aws_config_map.keys():
        pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
        pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
    # Load PyTorch checkpoint in tf2 model:
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

        np_pt = pto[0].numpy()
        np_tf = tfo[0].numpy()
        diff = np.amax(np.abs(np_pt - np_tf))
        print("Max absolute difference between models outputs {}".format(diff))
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently skip this correctness check.
        if diff > 2e-2:
            raise ValueError("Error, model absolute difference is >2e-2: {}".format(diff))

    # Save the converted model.
    print("Save TensorFlow model to {}".format(tf_dump_path))
    tf_model.save_weights(tf_dump_path, save_format="h5")
def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    """Convert one or all PyTorch checkpoints of the selected model type(s) to TF 2.0.

    Args:
        args_model_type: key of ``MODEL_CLASSES``, or None to process every type.
        tf_dump_path: directory where ``<shortcut>-tf_model.h5`` files are written.
        model_shortcut_names_or_path: checkpoint shortcut names or local paths;
            defaults to every shortcut known for the model type.
        config_shortcut_names_or_path: matching config shortcut names or paths;
            defaults to ``model_shortcut_names_or_path``.
        compare_with_pt_model: forward both models and check output agreement.
        use_cached_models: reuse previously downloaded files.
        remove_cached_files: delete downloaded config/weight files after converting.
        only_convert_finetuned_models: restrict to the finetuned (-squad / -mrpc /
            -mnli) checkpoints, skipping the plain pretrained ones.

    Raises:
        ValueError: if a requested model type is not in ``MODEL_CLASSES``.
    """
    if args_model_type is None:
        model_types = list(MODEL_CLASSES.keys())
    else:
        model_types = [args_model_type]

    for j, model_type in enumerate(model_types, start=1):
        print("=" * 100)
        print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(
                "Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))
            )

        # BUGFIX: MODEL_CLASSES values are 4-tuples, but this line previously
        # unpacked 5 names (expecting a separate model-weights archive map) and
        # raised ValueError at runtime. The config archive map doubles as the
        # index of known shortcut names; weight URLs are derived from the
        # shortcut via hf_bucket_url below, as in convert_pt_checkpoint_to_tf.
        config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

        if model_shortcut_names_or_path is None:
            model_shortcut_names_or_path = list(aws_config_map.keys())
        if config_shortcut_names_or_path is None:
            config_shortcut_names_or_path = model_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
            zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
        ):
            print("-" * 100)
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
                    continue
                # Finetuned checkpoints have their own entry in MODEL_CLASSES.
                model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
                continue
            print(
                " Converting checkpoint {}/{}: {} - model_type {}".format(
                    i, len(aws_config_map), model_shortcut_name, model_type
                )
            )
            print("-" * 100)

            if config_shortcut_name in aws_config_map:
                config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
            else:
                config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)

            if model_shortcut_name in aws_config_map:
                # Known shortcut: resolve the pytorch_model.bin URL on the model hub.
                model_url = hf_bucket_url(model_shortcut_name, filename=WEIGHTS_NAME)
                model_file = cached_path(model_url, force_download=not use_cached_models)
            else:
                model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)

            # Local files get a generic output name instead of their full path.
            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            convert_pt_checkpoint_to_tf(
                model_type=model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                os.remove(config_file)
                os.remove(model_file)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(
            list(MODEL_CLASSES.keys())
        ),
    )
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=None,
        type=str,
        help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
        "If not given, will download and convert all the checkpoints from AWS.",
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        # Fixed a missing space between the concatenated fragments ("name" / "use").
        help="The config json file corresponding to the pre-trained model. \n"
        "This specifies the model architecture. If not given and "
        "--pytorch_checkpoint_path is not given or is a shortcut name "
        "use the configuration associated to the shortcut name on the AWS",
    )
    parser.add_argument(
        "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
    )
    parser.add_argument(
        "--use_cached_models",
        action="store_true",
        help="Use cached models if possible instead of updating to latest checkpoint versions.",
    )
    parser.add_argument(
        "--remove_cached_files",
        action="store_true",
        help="Remove pytorch models after conversion (save memory when converting in batches).",
    )
    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
    args = parser.parse_args()

    # Delegate to the batch converter; a single --pytorch_checkpoint_path /
    # --config_file is passed through as a one-element list.
    convert_all_pt_checkpoints_to_tf(
        args.model_type.lower() if args.model_type is not None else None,
        args.tf_dump_path,
        model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
        if args.pytorch_checkpoint_path is not None
        else None,
        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
        compare_with_pt_model=args.compare_with_pt_model,
        use_cached_models=args.use_cached_models,
        remove_cached_files=args.remove_cached_files,
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )