CSS-LM
1745 lines · 100.7 KB
1# coding=utf-8
2# Copyright 2018 The HuggingFace Inc. team.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Auto Model class. """
16
17
18import logging19import warnings20from collections import OrderedDict21
22from .configuration_auto import (23AlbertConfig,24AutoConfig,25BartConfig,26BertConfig,27CamembertConfig,28CTRLConfig,29DistilBertConfig,30ElectraConfig,31EncoderDecoderConfig,32FlaubertConfig,33GPT2Config,34LongformerConfig,35MobileBertConfig,36OpenAIGPTConfig,37ReformerConfig,38RetriBertConfig,39RobertaConfig,40T5Config,41TransfoXLConfig,42XLMConfig,43XLMRobertaConfig,44XLNetConfig,45)
46from .configuration_marian import MarianConfig47from .configuration_utils import PretrainedConfig48from .modeling_albert import (49AlbertForMaskedLM,50AlbertForMultipleChoice,51AlbertForPreTraining,52AlbertForQuestionAnswering,53AlbertForSequenceClassification,54AlbertForTokenClassification,55AlbertModel,56)
57from .modeling_bart import (58BartForConditionalGeneration,59BartForQuestionAnswering,60BartForSequenceClassification,61BartModel,62)
63from .modeling_bert import (64BertForMaskedLM,65BertForMultipleChoice,66BertForPreTraining,67BertForQuestionAnswering,68BertForSequenceClassification,69BertForTokenClassification,70BertLMHeadModel,71BertModel,72)
73from .modeling_camembert import (74CamembertForMaskedLM,75CamembertForMultipleChoice,76CamembertForQuestionAnswering,77CamembertForSequenceClassification,78CamembertForTokenClassification,79CamembertModel,80)
81from .modeling_ctrl import CTRLLMHeadModel, CTRLModel82from .modeling_distilbert import (83DistilBertForMaskedLM,84DistilBertForMultipleChoice,85DistilBertForQuestionAnswering,86DistilBertForSequenceClassification,87DistilBertForTokenClassification,88DistilBertModel,89)
90from .modeling_electra import (91ElectraForMaskedLM,92ElectraForMultipleChoice,93ElectraForPreTraining,94ElectraForQuestionAnswering,95ElectraForSequenceClassification,96ElectraForTokenClassification,97ElectraModel,98)
99from .modeling_encoder_decoder import EncoderDecoderModel100from .modeling_flaubert import (101FlaubertForMultipleChoice,102FlaubertForQuestionAnsweringSimple,103FlaubertForSequenceClassification,104FlaubertForTokenClassification,105FlaubertModel,106FlaubertWithLMHeadModel,107)
108from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model109from .modeling_longformer import (110LongformerForMaskedLM,111LongformerForMultipleChoice,112LongformerForQuestionAnswering,113LongformerForSequenceClassification,114LongformerForTokenClassification,115LongformerModel,116)
117from .modeling_marian import MarianMTModel118from .modeling_mobilebert import (119MobileBertForMaskedLM,120MobileBertForMultipleChoice,121MobileBertForPreTraining,122MobileBertForQuestionAnswering,123MobileBertForSequenceClassification,124MobileBertForTokenClassification,125MobileBertModel,126)
127from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel128from .modeling_reformer import (129ReformerForMaskedLM,130ReformerForQuestionAnswering,131ReformerModel,132ReformerModelWithLMHead,133)
134from .modeling_retribert import RetriBertModel135from .modeling_roberta import (136RobertaForMaskedLM,137RobertaForMultipleChoice,138RobertaForQuestionAnswering,139RobertaForSequenceClassification,140RobertaForTokenClassification,141RobertaModel,142)
143from .modeling_t5 import T5ForConditionalGeneration, T5Model144from .modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel145from .modeling_xlm import (146XLMForMultipleChoice,147XLMForQuestionAnsweringSimple,148XLMForSequenceClassification,149XLMForTokenClassification,150XLMModel,151XLMWithLMHeadModel,152)
153from .modeling_xlm_roberta import (154XLMRobertaForMaskedLM,155XLMRobertaForMultipleChoice,156XLMRobertaForQuestionAnswering,157XLMRobertaForSequenceClassification,158XLMRobertaForTokenClassification,159XLMRobertaModel,160)
161from .modeling_xlnet import (162XLNetForMultipleChoice,163XLNetForQuestionAnsweringSimple,164XLNetForSequenceClassification,165XLNetForTokenClassification,166XLNetLMHeadModel,167XLNetModel,168)
169
170
# Module-level logger, namespaced by this module's dotted import path.
logger = logging.getLogger(__name__)
173
# Maps a configuration class to the bare (headless) model class for that
# architecture.  Consumed by AutoModel.from_config / from_pretrained below,
# which return the FIRST entry whose config class passes isinstance(), so
# more specific config classes must precede any config class they derive
# from (order in this list is significant).
MODEL_MAPPING = OrderedDict(
    [
        (RetriBertConfig, RetriBertModel),
        (T5Config, T5Model),
        (DistilBertConfig, DistilBertModel),
        (AlbertConfig, AlbertModel),
        (CamembertConfig, CamembertModel),
        (XLMRobertaConfig, XLMRobertaModel),
        (BartConfig, BartModel),
        (LongformerConfig, LongformerModel),
        (RobertaConfig, RobertaModel),
        (BertConfig, BertModel),
        (OpenAIGPTConfig, OpenAIGPTModel),
        (GPT2Config, GPT2Model),
        (MobileBertConfig, MobileBertModel),
        (TransfoXLConfig, TransfoXLModel),
        (XLNetConfig, XLNetModel),
        (FlaubertConfig, FlaubertModel),
        (XLMConfig, XLMModel),
        (CTRLConfig, CTRLModel),
        (ElectraConfig, ElectraModel),
        (ReformerConfig, ReformerModel),
    ]
)
198
# Maps a configuration class to the model class carrying that architecture's
# original pretraining head (e.g. BertForPreTraining, ElectraForPreTraining;
# plain MaskedLM/LMHead classes where no dedicated pretraining class exists).
# Consumed by AutoModelForPreTraining below; lookup returns the first
# isinstance() match, so entry order is significant.
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (RetriBertConfig, RetriBertModel),
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForPreTraining),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForPreTraining),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (MobileBertConfig, MobileBertForPreTraining),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForPreTraining),
    ]
)
222
# Maps a configuration class to a language-modeling-head model class
# (masked, causal, or seq2seq, depending on the architecture).  Consumed by
# the deprecated AutoModelWithLMHead class below; lookup returns the first
# isinstance() match, so entry order is significant.
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (MarianConfig, MarianMTModel),
        (BartConfig, BartForConditionalGeneration),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (MobileBertConfig, MobileBertForMaskedLM),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
        (EncoderDecoderConfig, EncoderDecoderModel),
        (ReformerConfig, ReformerModelWithLMHead),
    ]
)
248
# Maps a configuration class to its causal (left-to-right) LM model class.
# Presumably consumed by an AutoModelForCausalLM class (not shown in this
# chunk) via the same first-isinstance-match scan used by AutoModel.from_config,
# so entry order is significant.
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
    [
        (BertConfig, BertLMHeadModel),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (
            XLMConfig,
            XLMWithLMHeadModel,
        ),  # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
        (CTRLConfig, CTRLLMHeadModel),
        (ReformerConfig, ReformerModelWithLMHead),
    ]
)
264
# Maps a configuration class to its masked-LM model class.  Presumably
# consumed by an AutoModelForMaskedLM class (not shown in this chunk) via a
# first-isinstance-match scan like AutoModel.from_config, so entry order is
# significant.
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (MobileBertConfig, MobileBertForMaskedLM),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
        (ReformerConfig, ReformerForMaskedLM),
    ]
)
282
# Maps a configuration class to its encoder-decoder (seq2seq) LM model
# class.  Presumably consumed by an AutoModelForSeq2SeqLM class (not shown
# in this chunk) via a first-isinstance-match scan, so entry order is
# significant.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (MarianConfig, MarianMTModel),
        (BartConfig, BartForConditionalGeneration),
        (EncoderDecoderConfig, EncoderDecoderModel),
    ]
)
291
# Maps a configuration class to its sequence-classification model class.
# Presumably consumed by an AutoModelForSequenceClassification class (not
# shown in this chunk) via a first-isinstance-match scan, so entry order is
# significant.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
        (BartConfig, BartForSequenceClassification),
        (LongformerConfig, LongformerForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
        (BertConfig, BertForSequenceClassification),
        (XLNetConfig, XLNetForSequenceClassification),
        (MobileBertConfig, MobileBertForSequenceClassification),
        (FlaubertConfig, FlaubertForSequenceClassification),
        (XLMConfig, XLMForSequenceClassification),
        (ElectraConfig, ElectraForSequenceClassification),
    ]
)
309
# Maps a configuration class to its extractive question-answering model
# class.  Presumably consumed by an AutoModelForQuestionAnswering class (not
# shown in this chunk) via a first-isinstance-match scan, so entry order is
# significant.
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (CamembertConfig, CamembertForQuestionAnswering),
        (BartConfig, BartForQuestionAnswering),
        (LongformerConfig, LongformerForQuestionAnswering),
        (XLMRobertaConfig, XLMRobertaForQuestionAnswering),
        (RobertaConfig, RobertaForQuestionAnswering),
        (BertConfig, BertForQuestionAnswering),
        (XLNetConfig, XLNetForQuestionAnsweringSimple),
        (FlaubertConfig, FlaubertForQuestionAnsweringSimple),
        (MobileBertConfig, MobileBertForQuestionAnswering),
        (XLMConfig, XLMForQuestionAnsweringSimple),
        (ElectraConfig, ElectraForQuestionAnswering),
        (ReformerConfig, ReformerForQuestionAnswering),
    ]
)
328
# Maps a configuration class to its token-classification model class.
# Presumably consumed by an AutoModelForTokenClassification class (not shown
# in this chunk) via a first-isinstance-match scan, so entry order is
# significant.
# Fix: the original listed (FlaubertConfig, FlaubertForTokenClassification)
# twice; OrderedDict keeps only one entry for a repeated key (same value
# here, so behavior is unchanged), and the redundant trailing pair is removed.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForTokenClassification),
        (CamembertConfig, CamembertForTokenClassification),
        (FlaubertConfig, FlaubertForTokenClassification),
        (XLMConfig, XLMForTokenClassification),
        (XLMRobertaConfig, XLMRobertaForTokenClassification),
        (LongformerConfig, LongformerForTokenClassification),
        (RobertaConfig, RobertaForTokenClassification),
        (BertConfig, BertForTokenClassification),
        (MobileBertConfig, MobileBertForTokenClassification),
        (XLNetConfig, XLNetForTokenClassification),
        (AlbertConfig, AlbertForTokenClassification),
        (ElectraConfig, ElectraForTokenClassification),
    ]
)
346
# Maps a configuration class to its multiple-choice model class.  Presumably
# consumed by an AutoModelForMultipleChoice class (not shown in this chunk)
# via a first-isinstance-match scan, so entry order is significant.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    [
        (CamembertConfig, CamembertForMultipleChoice),
        (ElectraConfig, ElectraForMultipleChoice),
        (XLMRobertaConfig, XLMRobertaForMultipleChoice),
        (LongformerConfig, LongformerForMultipleChoice),
        (RobertaConfig, RobertaForMultipleChoice),
        (BertConfig, BertForMultipleChoice),
        (DistilBertConfig, DistilBertForMultipleChoice),
        (MobileBertConfig, MobileBertForMultipleChoice),
        (XLNetConfig, XLNetForMultipleChoice),
        (AlbertConfig, AlbertForMultipleChoice),
        (XLMConfig, XLMForMultipleChoice),
        (FlaubertConfig, FlaubertForMultipleChoice),
    ]
)
363
364
class AutoModel:
    r""":class:`~transformers.AutoModel` is a generic model class
    that will be instantiated as one of the base model classes of the library
    when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
    or the `AutoModel.from_config(config)` class methods.

    This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        # Factory-only class: force callers through the classmethod constructors.
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModel.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
            the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModel` (DistilBERT model)
                - isInstance of `longformer` configuration class: :class:`~transformers.LongformerModel` (Longformer model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraModel` (Electra model)

        Raises:
            ValueError: if ``config`` is not an instance of any configuration
                class registered in ``MODEL_MAPPING``.

        Examples::

            >>> config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
            >>> model = AutoModel.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # First isinstance() match wins, so MODEL_MAPPING's order is significant
        # (derived config classes are listed before their bases).
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the base model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `t5`: :class:`~transformers.T5Model` (T5 model)
            - `distilbert`: :class:`~transformers.DistilBertModel` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertModel` (ALBERT model)
            - `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
            - `longformer`: :class:`~transformers.LongformerModel` (Longformer model)
            - `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model)
            - `bert`: :class:`~transformers.BertModel` (Bert model)
            - `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
            - `gpt2`: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
            - `transfo-xl`: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
            - `xlnet`: :class:`~transformers.XLNetModel` (XLNet model)
            - `xlm`: :class:`~transformers.XLMModel` (XLM model)
            - `ctrl`: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
            - `flaubert`: :class:`~transformers.FlaubertModel` (Flaubert model)
            - `electra`: :class:`~transformers.ElectraModel` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any
                configuration class registered in ``MODEL_MAPPING``.

        Examples::

            model = AutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No explicit config supplied: resolve one from the name/path and
            # keep only the kwargs AutoConfig did not consume for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # First isinstance() match wins (see MODEL_MAPPING ordering).
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )
521
class AutoModelForPreTraining:
    r""":class:`~transformers.AutoModelForPreTraining` is a generic model class
    that will be instantiated as one of the model classes of the library -with the architecture used for pretraining this model- when created with the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)`
    class method.

    This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        # Factory-only class: force callers through the classmethod constructors.
        raise EnvironmentError(
            "AutoModelForPreTraining is designed to be instantiated "
            "using the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForPreTraining.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
            the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForPreTraining` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForPreTraining` (Electra model)

        Raises:
            ValueError: if ``config`` is not an instance of any configuration
                class registered in ``MODEL_FOR_PRETRAINING_MAPPING``.

        Examples::

            >>> config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
            >>> model = AutoModelForPreTraining.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # First isinstance() match wins, so mapping order is significant.
        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the model classes of the library -with the architecture used for pretraining this model- from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `t5`: :class:`~transformers.T5ModelWithLMHead` (T5 model)
            - `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
            - `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - `bert`: :class:`~transformers.BertForPreTraining` (Bert model)
            - `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
            - `electra`: :class:`~transformers.ElectraForPreTraining` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any
                configuration class registered in ``MODEL_FOR_PRETRAINING_MAPPING``.

        Examples::

            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForPreTraining.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForPreTraining.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No explicit config supplied: resolve one from the name/path and
            # keep only the kwargs AutoConfig did not consume for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # First isinstance() match wins (see MODEL_FOR_PRETRAINING_MAPPING ordering).
        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )
670
671class AutoModelWithLMHead:672r"""673:class:`~transformers.AutoModelWithLMHead` is a generic model class
674that will be instantiated as one of the language modeling model classes of the library
675when created with the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
676class method.
677
678This class cannot be instantiated using `__init__()` (throws an error).
679"""
680
681def __init__(self):682raise EnvironmentError(683"AutoModelWithLMHead is designed to be instantiated "684"using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "685"`AutoModelWithLMHead.from_config(config)` methods."686)687
688@classmethod689def from_config(cls, config):690r""" Instantiates one of the base model classes of the library691from a configuration.
692
693Note:
694Loading a model from its configuration file does **not** load the model weights.
695It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
696the model weights
697
698Args:
699config (:class:`~transformers.PretrainedConfig`):
700The model class to instantiate is selected based on the configuration class:
701
702- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
703- isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
704- isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
705- isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
706- isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
707- isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
708- isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
709- isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
710- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
711- isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
712- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
713- isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model)
714
715Examples::
716
717config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
718model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
719"""
720warnings.warn(721"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.",722FutureWarning,723)724for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():725if isinstance(config, config_class):726return model_class(config)727raise ValueError(728"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"729"Model type should be one of {}.".format(730config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())731)732)733
734@classmethod735def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):736r""" Instantiates one of the language modeling model classes of the library737from a pre-trained model configuration.
738
739The `from_pretrained()` method takes care of returning the correct model class instance
740based on the `model_type` property of the config object, or when it's missing,
741falling back to using pattern matching on the `pretrained_model_name_or_path` string:
742
743- `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
744- `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
745- `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
746- `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
747- `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
748- `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
749- `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
750- `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)
751- `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
752- `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
753- `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
754- `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
755- `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
756- `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
757- `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
758- `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)
759
760The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
761To train the model, you should first set it back in training mode with `model.train()`
762
763Args:
764pretrained_model_name_or_path:
765Either:
766
767- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
768- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
769- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
770- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
771model_args: (`optional`) Sequence of positional arguments:
772All remaning positional arguments will be passed to the underlying model's ``__init__`` method
773config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
774Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
775
776- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
777- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
778- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
779
780state_dict: (`optional`) dict:
781an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
782This option can be used if you want to create a model from a pretrained configuration but load your own weights.
783In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
784cache_dir: (`optional`) string:
785Path to a directory in which a downloaded pre-trained model
786configuration should be cached if the standard cache should not be used.
787force_download: (`optional`) boolean, default False:
788Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
789resume_download: (`optional`) boolean, default False:
790Do not delete incompletely received file. Attempt to resume the download if such a file exists.
791proxies: (`optional`) dict, default None:
792A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
793The proxies are used on each request.
794output_loading_info: (`optional`) boolean:
795Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
796kwargs: (`optional`) Remaining dictionary of keyword arguments:
797These arguments will be passed to the configuration and the model.
798
799Examples::
800
801model = AutoModelWithLMHead.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
802model = AutoModelWithLMHead.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
803assert model.config.output_attention == True
804# Loading from a TF checkpoint file instead of a PyTorch model (slower)
805config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
806model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
807
808"""
809warnings.warn(810"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.",811FutureWarning,812)813config = kwargs.pop("config", None)814if not isinstance(config, PretrainedConfig):815config, kwargs = AutoConfig.from_pretrained(816pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs817)818
819for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():820if isinstance(config, config_class):821return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)822raise ValueError(823"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"824"Model type should be one of {}.".format(825config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())826)827)828
829
class AutoModelForCausalLM:
    r""":class:`~transformers.AutoModelForCausalLM` is a generic factory class that
    builds one of the causal language modeling models of the library when created
    with the ``AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)``
    or ``AutoModelForCausalLM.from_config(config)`` class methods.

    Supported architectures: Bert (:class:`~transformers.BertLMHeadModel`),
    OpenAI GPT (:class:`~transformers.OpenAIGPTLMHeadModel`), GPT-2
    (:class:`~transformers.GPT2LMHeadModel`), Salesforce CTRL
    (:class:`~transformers.CTRLLMHeadModel`), Transformer-XL
    (:class:`~transformers.TransfoXLLMHeadModel`), XLNet
    (:class:`~transformers.XLNetLMHeadModel`) and Reformer
    (:class:`~transformers.ReformerModelWithLMHead`).

    This class cannot be instantiated using ``__init__()`` (it raises an error).
    """

    def __init__(self):
        # Factory-only class: direct construction is never valid.
        message = (
            "AutoModelForCausalLM is designed to be instantiated "
            "using the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForCausalLM.from_config(config)` methods."
        )
        raise EnvironmentError(message)

    @classmethod
    def from_config(cls, config):
        r"""Instantiate (without loading weights) the causal-LM model matching ``config``.

        Note:
            Loading a model from its configuration file does **not** load the model
            weights; use :func:`~transformers.AutoModel.from_pretrained` for that.

        Args:
            config (:class:`~transformers.PretrainedConfig`): the configuration whose
                class determines which model class is instantiated (see the class
                docstring for the supported architectures).

        Raises:
            ValueError: if ``config`` matches none of the supported configuration classes.

        Examples::

            config = GPT2Config.from_pretrained('gpt2')  # Download configuration from S3 and cache.
            model = AutoModelForCausalLM.from_config(config)
        """
        # First isinstance match wins; the mapping is ordered so that subclassed
        # configurations are tried before their parents.
        for candidate_config_cls, candidate_model_cls in MODEL_FOR_CAUSAL_LM_MAPPING.items():
            if isinstance(config, candidate_config_cls):
                return candidate_model_cls(config)
        supported = ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(config.__class__, cls.__name__, supported)
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained causal-LM model from a shortcut name, S3
        identifier, local directory, or TensorFlow index checkpoint (``from_tf=True``).

        The concrete class is selected from the ``model_type`` of the resolved
        configuration. The returned model is in evaluation mode (``model.eval()``);
        call ``model.train()`` before fine-tuning.

        Args:
            pretrained_model_name_or_path: shortcut name, user-uploaded identifier,
                path to a directory saved with
                :func:`~transformers.PreTrainedModel.save_pretrained`, or path/url to
                a TensorFlow index checkpoint (requires ``from_tf=True`` and ``config``).
            model_args: positional arguments forwarded to the model ``__init__``.
            kwargs: remaining keyword arguments forwarded to the configuration and the
                model, e.g. ``config``, ``state_dict``, ``cache_dir``, ``force_download``,
                ``resume_download``, ``proxies``, ``output_loading_info``, ``from_tf``.

        Raises:
            ValueError: if the resolved configuration matches no supported class.

        Examples::

            model = AutoModelForCausalLM.from_pretrained('gpt2')  # Download model and configuration from S3 and cache.
            model = AutoModelForCausalLM.from_pretrained('./test/gpt2_model/')
        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No usable config supplied: resolve one from the name/path and keep only
            # the kwargs the configuration did not consume.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # First isinstance match wins; the mapping is ordered so that subclassed
        # configurations are tried before their parents.
        for candidate_config_cls, candidate_model_cls in MODEL_FOR_CAUSAL_LM_MAPPING.items():
            if isinstance(config, candidate_config_cls):
                return candidate_model_cls.from_pretrained(
                    pretrained_model_name_or_path, *model_args, config=config, **kwargs
                )
        supported = ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(config.__class__, cls.__name__, supported)
        )
966
class AutoModelForMaskedLM:
    r""":class:`~transformers.AutoModelForMaskedLM` is a generic factory class that
    builds one of the masked language modeling models of the library when created
    with the ``AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)``
    or ``AutoModelForMaskedLM.from_config(config)`` class methods.

    Supported architectures: DistilBERT (:class:`~transformers.DistilBertForMaskedLM`),
    ALBERT (:class:`~transformers.AlbertForMaskedLM`), CamemBERT
    (:class:`~transformers.CamembertForMaskedLM`), XLM-RoBERTa
    (:class:`~transformers.XLMRobertaForMaskedLM`), Longformer
    (:class:`~transformers.LongformerForMaskedLM`), RoBERTa
    (:class:`~transformers.RobertaForMaskedLM`), Bert
    (:class:`~transformers.BertForMaskedLM`), XLM
    (:class:`~transformers.XLMWithLMHeadModel`), Flaubert
    (:class:`~transformers.FlaubertWithLMHeadModel`) and Electra
    (:class:`~transformers.ElectraForMaskedLM`).

    This class cannot be instantiated using ``__init__()`` (it raises an error).
    """

    def __init__(self):
        # Factory-only class: direct construction is never valid.
        message = (
            "AutoModelForMaskedLM is designed to be instantiated "
            "using the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForMaskedLM.from_config(config)` methods."
        )
        raise EnvironmentError(message)

    @classmethod
    def from_config(cls, config):
        r"""Instantiate (without loading weights) the masked-LM model matching ``config``.

        Note:
            Loading a model from its configuration file does **not** load the model
            weights; use :func:`~transformers.AutoModel.from_pretrained` for that.

        Args:
            config (:class:`~transformers.PretrainedConfig`): the configuration whose
                class determines which model class is instantiated (see the class
                docstring for the supported architectures).

        Raises:
            ValueError: if ``config`` matches none of the supported configuration classes.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
            model = AutoModelForMaskedLM.from_config(config)
        """
        # First isinstance match wins; the mapping is ordered so that subclassed
        # configurations are tried before their parents.
        for candidate_config_cls, candidate_model_cls in MODEL_FOR_MASKED_LM_MAPPING.items():
            if isinstance(config, candidate_config_cls):
                return candidate_model_cls(config)
        supported = ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(config.__class__, cls.__name__, supported)
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained masked-LM model from a shortcut name, S3
        identifier, local directory, or TensorFlow index checkpoint (``from_tf=True``).

        The concrete class is selected from the ``model_type`` of the resolved
        configuration. The returned model is in evaluation mode (``model.eval()``);
        call ``model.train()`` before fine-tuning.

        Args:
            pretrained_model_name_or_path: shortcut name, user-uploaded identifier,
                path to a directory saved with
                :func:`~transformers.PreTrainedModel.save_pretrained`, or path/url to
                a TensorFlow index checkpoint (requires ``from_tf=True`` and ``config``).
            model_args: positional arguments forwarded to the model ``__init__``.
            kwargs: remaining keyword arguments forwarded to the configuration and the
                model, e.g. ``config``, ``state_dict``, ``cache_dir``, ``force_download``,
                ``resume_download``, ``proxies``, ``output_loading_info``, ``from_tf``.

        Raises:
            ValueError: if the resolved configuration matches no supported class.

        Examples::

            model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
            model = AutoModelForMaskedLM.from_pretrained('./test/bert_model/')
        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No usable config supplied: resolve one from the name/path and keep only
            # the kwargs the configuration did not consume.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # First isinstance match wins; the mapping is ordered so that subclassed
        # configurations are tried before their parents.
        for candidate_config_cls, candidate_model_cls in MODEL_FOR_MASKED_LM_MAPPING.items():
            if isinstance(config, candidate_config_cls):
                return candidate_model_cls.from_pretrained(
                    pretrained_model_name_or_path, *model_args, config=config, **kwargs
                )
        supported = ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(config.__class__, cls.__name__, supported)
        )
1109
class AutoModelForSeq2SeqLM:
    r""":class:`~transformers.AutoModelForSeq2SeqLM` is a generic factory class that
    builds one of the sequence-to-sequence language modeling models of the library
    when created with the ``AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)``
    or ``AutoModelForSeq2SeqLM.from_config(config)`` class methods.

    Supported architectures: T5 (:class:`~transformers.T5ForConditionalGeneration`),
    Bart (:class:`~transformers.BartForConditionalGeneration`), Marian
    (:class:`~transformers.MarianMTModel`) and Encoder-Decoder
    (:class:`~transformers.EncoderDecoderModel`).

    This class cannot be instantiated using ``__init__()`` (it raises an error).
    """

    def __init__(self):
        # Factory-only class: direct construction is never valid.
        message = (
            "AutoModelForSeq2SeqLM is designed to be instantiated "
            "using the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSeq2SeqLM.from_config(config)` methods."
        )
        raise EnvironmentError(message)

    @classmethod
    def from_config(cls, config):
        r"""Instantiate (without loading weights) the seq2seq-LM model matching ``config``.

        Note:
            Loading a model from its configuration file does **not** load the model
            weights; use :func:`~transformers.AutoModel.from_pretrained` for that.

        Args:
            config (:class:`~transformers.PretrainedConfig`): the configuration whose
                class determines which model class is instantiated (see the class
                docstring for the supported architectures).

        Raises:
            ValueError: if ``config`` matches none of the supported configuration classes.

        Examples::

            config = T5Config.from_pretrained('t5')
            model = AutoModelForSeq2SeqLM.from_config(config)
        """
        # First isinstance match wins; the mapping is ordered so that subclassed
        # configurations are tried before their parents.
        for candidate_config_cls, candidate_model_cls in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
            if isinstance(config, candidate_config_cls):
                return candidate_model_cls(config)
        supported = ", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(config.__class__, cls.__name__, supported)
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained seq2seq-LM model from a shortcut name, S3
        identifier, local directory, or TensorFlow index checkpoint (``from_tf=True``).

        The concrete class is selected from the ``model_type`` of the resolved
        configuration. The returned model is in evaluation mode (``model.eval()``);
        call ``model.train()`` before fine-tuning.

        Args:
            pretrained_model_name_or_path: shortcut name, user-uploaded identifier,
                path to a directory saved with
                :func:`~transformers.PreTrainedModel.save_pretrained`, or path/url to
                a TensorFlow index checkpoint (requires ``from_tf=True`` and ``config``).
            model_args: positional arguments forwarded to the model ``__init__``.
            kwargs: remaining keyword arguments forwarded to the configuration and the
                model, e.g. ``config``, ``state_dict``, ``cache_dir``, ``force_download``,
                ``resume_download``, ``proxies``, ``output_loading_info``, ``from_tf``.

        Raises:
            ValueError: if the resolved configuration matches no supported class.

        Examples::

            model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')  # Download model and configuration from S3 and cache.
            model = AutoModelForSeq2SeqLM.from_pretrained('./test/t5_model/')
        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No usable config supplied: resolve one from the name/path and keep only
            # the kwargs the configuration did not consume.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # First isinstance match wins; the mapping is ordered so that subclassed
        # configurations are tried before their parents.
        for candidate_config_cls, candidate_model_cls in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
            if isinstance(config, candidate_config_cls):
                return candidate_model_cls.from_pretrained(
                    pretrained_model_name_or_path, *model_args, config=config, **kwargs
                )
        supported = ", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(config.__class__, cls.__name__, supported)
        )
1244
class AutoModelForSequenceClassification:
    r"""
    :class:`~transformers.AutoModelForSequenceClassification` is a generic model class
    that will be instantiated as one of the sequence classification model classes of the library
    when created with the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
    class method.

    This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        # Guard against direct instantiation: this class only dispatches to
        # concrete model classes via its classmethods.
        raise EnvironmentError(
            "AutoModelForSequenceClassification is designed to be instantiated "
            "using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSequenceClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the sequence classification model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
            the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
                - isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForSequenceClassification` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForSequenceClassification` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # Dispatch on the configuration class; the first matching entry wins.
        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the sequence classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `distilbert`: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
            - `camembert`: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
            - `roberta`: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
            - `bert`: :class:`~transformers.BertForSequenceClassification` (Bert model)
            - `xlnet`: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
            - `xlm`: :class:`~transformers.XLMForSequenceClassification` (XLM model)
            - `flaubert`: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No explicit config supplied: resolve one from the model name/path,
            # keeping any kwargs the config did not consume for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )
1396
class AutoModelForQuestionAnswering:
    r"""
    :class:`~transformers.AutoModelForQuestionAnswering` is a generic model class
    that will be instantiated as one of the question answering model classes of the library
    when created with the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
    class method.

    This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        # Guard against direct instantiation: this class only dispatches to
        # concrete model classes via its classmethods.
        raise EnvironmentError(
            "AutoModelForQuestionAnswering is designed to be instantiated "
            "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForQuestionAnswering.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the question answering model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
            the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForQuestionAnswering` (Bert model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # Dispatch on the configuration class; the first matching entry wins.
        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `distilbert`: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
            - `bert`: :class:`~transformers.BertForQuestionAnswering` (Bert model)
            - `xlnet`: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
            - `xlm`: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
            - `flaubert`: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No explicit config supplied: resolve one from the model name/path,
            # keeping any kwargs the config did not consume for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )
1541
class AutoModelForTokenClassification:
    r"""
    :class:`~transformers.AutoModelForTokenClassification` is a generic model class
    that will be instantiated as one of the token classification model classes of the library
    when created with the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)`
    class method.

    This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        # Guard against direct instantiation: this class only dispatches to
        # concrete model classes via its classmethods.
        raise EnvironmentError(
            "AutoModelForTokenClassification is designed to be instantiated "
            "using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForTokenClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the token classification model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
            the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForTokenClassification` (XLM model)
                - isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForTokenClassification` (Bert model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForTokenClassification` (ALBERT model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForTokenClassification` (Flaubert model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForTokenClassification` (CamemBERT model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForTokenClassification` (RoBERTa model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForTokenClassification` (Electra model)

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        """
        # Dispatch on the configuration class; the first matching entry wins.
        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the token classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `distilbert`: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
            - `xlm`: :class:`~transformers.XLMForTokenClassification` (XLM model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
            - `camembert`: :class:`~transformers.CamembertForTokenClassification` (CamemBERT model)
            - `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)
            - `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
            - `flaubert`: :class:`~transformers.FlaubertForTokenClassification` (Flaubert model)
            - `roberta`: :class:`~transformers.RobertaForTokenClassification` (RoBERTa model)
            - `electra`: :class:`~transformers.ElectraForTokenClassification` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Examples::

            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No explicit config supplied: resolve one from the model name/path,
            # keeping any kwargs the config did not consume for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
            )
        )
1693
class AutoModelForMultipleChoice:
    r"""
    :class:`~transformers.AutoModelForMultipleChoice` is a generic model class
    that will be instantiated as one of the multiple choice model classes of the library
    when created with the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)`
    class method.

    This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        # Guard against direct instantiation: this class only dispatches to
        # concrete model classes via its classmethods.
        raise EnvironmentError(
            "AutoModelForMultipleChoice is designed to be instantiated "
            "using the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForMultipleChoice.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the multiple choice model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
            the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class.
        """
        # Dispatch on the configuration class; the first matching entry wins.
        for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the multiple choice model classes of the library
        from a pre-trained model configuration.

        The correct model class is selected based on the configuration class of the
        (supplied or automatically loaded) config — see :func:`from_config`.

        Args:
            pretrained_model_name_or_path:
                A model shortcut name, S3 identifier, or path/url to saved weights —
                see the other `AutoModelFor*` classes in this module for details.
            model_args: (`optional`) remaining positional arguments passed to the
                underlying model's ``__init__`` method.
            kwargs: (`optional`) remaining keyword arguments passed to the
                configuration and the model (``config`` may be supplied explicitly).
        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No explicit config supplied: resolve one from the model name/path,
            # keeping any kwargs the config did not consume for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
            )
        )