# Repository: CSS-LM (forks: 0) — file: modeling_auto.py — 1745 lines · 100.7 KB
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
16

17

18
import logging
19
import warnings
20
from collections import OrderedDict
21

22
from .configuration_auto import (
23
    AlbertConfig,
24
    AutoConfig,
25
    BartConfig,
26
    BertConfig,
27
    CamembertConfig,
28
    CTRLConfig,
29
    DistilBertConfig,
30
    ElectraConfig,
31
    EncoderDecoderConfig,
32
    FlaubertConfig,
33
    GPT2Config,
34
    LongformerConfig,
35
    MobileBertConfig,
36
    OpenAIGPTConfig,
37
    ReformerConfig,
38
    RetriBertConfig,
39
    RobertaConfig,
40
    T5Config,
41
    TransfoXLConfig,
42
    XLMConfig,
43
    XLMRobertaConfig,
44
    XLNetConfig,
45
)
46
from .configuration_marian import MarianConfig
47
from .configuration_utils import PretrainedConfig
48
from .modeling_albert import (
49
    AlbertForMaskedLM,
50
    AlbertForMultipleChoice,
51
    AlbertForPreTraining,
52
    AlbertForQuestionAnswering,
53
    AlbertForSequenceClassification,
54
    AlbertForTokenClassification,
55
    AlbertModel,
56
)
57
from .modeling_bart import (
58
    BartForConditionalGeneration,
59
    BartForQuestionAnswering,
60
    BartForSequenceClassification,
61
    BartModel,
62
)
63
from .modeling_bert import (
64
    BertForMaskedLM,
65
    BertForMultipleChoice,
66
    BertForPreTraining,
67
    BertForQuestionAnswering,
68
    BertForSequenceClassification,
69
    BertForTokenClassification,
70
    BertLMHeadModel,
71
    BertModel,
72
)
73
from .modeling_camembert import (
74
    CamembertForMaskedLM,
75
    CamembertForMultipleChoice,
76
    CamembertForQuestionAnswering,
77
    CamembertForSequenceClassification,
78
    CamembertForTokenClassification,
79
    CamembertModel,
80
)
81
from .modeling_ctrl import CTRLLMHeadModel, CTRLModel
82
from .modeling_distilbert import (
83
    DistilBertForMaskedLM,
84
    DistilBertForMultipleChoice,
85
    DistilBertForQuestionAnswering,
86
    DistilBertForSequenceClassification,
87
    DistilBertForTokenClassification,
88
    DistilBertModel,
89
)
90
from .modeling_electra import (
91
    ElectraForMaskedLM,
92
    ElectraForMultipleChoice,
93
    ElectraForPreTraining,
94
    ElectraForQuestionAnswering,
95
    ElectraForSequenceClassification,
96
    ElectraForTokenClassification,
97
    ElectraModel,
98
)
99
from .modeling_encoder_decoder import EncoderDecoderModel
100
from .modeling_flaubert import (
101
    FlaubertForMultipleChoice,
102
    FlaubertForQuestionAnsweringSimple,
103
    FlaubertForSequenceClassification,
104
    FlaubertForTokenClassification,
105
    FlaubertModel,
106
    FlaubertWithLMHeadModel,
107
)
108
from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model
109
from .modeling_longformer import (
110
    LongformerForMaskedLM,
111
    LongformerForMultipleChoice,
112
    LongformerForQuestionAnswering,
113
    LongformerForSequenceClassification,
114
    LongformerForTokenClassification,
115
    LongformerModel,
116
)
117
from .modeling_marian import MarianMTModel
118
from .modeling_mobilebert import (
119
    MobileBertForMaskedLM,
120
    MobileBertForMultipleChoice,
121
    MobileBertForPreTraining,
122
    MobileBertForQuestionAnswering,
123
    MobileBertForSequenceClassification,
124
    MobileBertForTokenClassification,
125
    MobileBertModel,
126
)
127
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
128
from .modeling_reformer import (
129
    ReformerForMaskedLM,
130
    ReformerForQuestionAnswering,
131
    ReformerModel,
132
    ReformerModelWithLMHead,
133
)
134
from .modeling_retribert import RetriBertModel
135
from .modeling_roberta import (
136
    RobertaForMaskedLM,
137
    RobertaForMultipleChoice,
138
    RobertaForQuestionAnswering,
139
    RobertaForSequenceClassification,
140
    RobertaForTokenClassification,
141
    RobertaModel,
142
)
143
from .modeling_t5 import T5ForConditionalGeneration, T5Model
144
from .modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel
145
from .modeling_xlm import (
146
    XLMForMultipleChoice,
147
    XLMForQuestionAnsweringSimple,
148
    XLMForSequenceClassification,
149
    XLMForTokenClassification,
150
    XLMModel,
151
    XLMWithLMHeadModel,
152
)
153
from .modeling_xlm_roberta import (
154
    XLMRobertaForMaskedLM,
155
    XLMRobertaForMultipleChoice,
156
    XLMRobertaForQuestionAnswering,
157
    XLMRobertaForSequenceClassification,
158
    XLMRobertaForTokenClassification,
159
    XLMRobertaModel,
160
)
161
from .modeling_xlnet import (
162
    XLNetForMultipleChoice,
163
    XLNetForQuestionAnsweringSimple,
164
    XLNetForSequenceClassification,
165
    XLNetForTokenClassification,
166
    XLNetLMHeadModel,
167
    XLNetModel,
168
)
169

170

171
# Module-level logger, named after this module so handlers can be configured per package.
logger = logging.getLogger(name=__name__)
172

173

174
MODEL_MAPPING = OrderedDict(
175
    [
176
        (RetriBertConfig, RetriBertModel),
177
        (T5Config, T5Model),
178
        (DistilBertConfig, DistilBertModel),
179
        (AlbertConfig, AlbertModel),
180
        (CamembertConfig, CamembertModel),
181
        (XLMRobertaConfig, XLMRobertaModel),
182
        (BartConfig, BartModel),
183
        (LongformerConfig, LongformerModel),
184
        (RobertaConfig, RobertaModel),
185
        (BertConfig, BertModel),
186
        (OpenAIGPTConfig, OpenAIGPTModel),
187
        (GPT2Config, GPT2Model),
188
        (MobileBertConfig, MobileBertModel),
189
        (TransfoXLConfig, TransfoXLModel),
190
        (XLNetConfig, XLNetModel),
191
        (FlaubertConfig, FlaubertModel),
192
        (XLMConfig, XLMModel),
193
        (CTRLConfig, CTRLModel),
194
        (ElectraConfig, ElectraModel),
195
        (ReformerConfig, ReformerModel),
196
    ]
197
)
198

199
# Configuration class -> model class with the head used to pretrain that architecture,
# used by ``AutoModelForPreTraining``. Scanned in order with ``isinstance`` (first match
# wins), so subclassed configuration classes must precede their parents.
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (RetriBertConfig, RetriBertModel),
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForPreTraining),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForPreTraining),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (MobileBertConfig, MobileBertForPreTraining),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForPreTraining),
    ]
)
222

223
# Configuration class -> model class with a language-modeling head (masked, causal or
# seq2seq, depending on the architecture); presumably backs ``AutoModelWithLMHead``.
# Order matters if consumers pick the first ``isinstance`` match, as the other Auto
# classes do: subclassed configuration classes must precede their parents.
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (MarianConfig, MarianMTModel),
        (BartConfig, BartForConditionalGeneration),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (MobileBertConfig, MobileBertForMaskedLM),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
        (EncoderDecoderConfig, EncoderDecoderModel),
        (ReformerConfig, ReformerModelWithLMHead),
    ]
)
248

249
# Configuration class -> model class with a causal (left-to-right) language-modeling head.
# Order matters if consumers pick the first ``isinstance`` match: subclassed configuration
# classes must precede their parents.
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
    [
        (BertConfig, BertLMHeadModel),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ReformerConfig, ReformerModelWithLMHead),
    ]
)
264

265
# Configuration class -> model class with a masked language-modeling head.
# Order matters if consumers pick the first ``isinstance`` match: subclassed configuration
# classes must precede their parents.
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (MobileBertConfig, MobileBertForMaskedLM),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
        (ReformerConfig, ReformerForMaskedLM),
    ]
)
282

283
# Configuration class -> encoder-decoder (sequence-to-sequence) generation model class.
# Order matters if consumers pick the first ``isinstance`` match: subclassed configuration
# classes (Marian is listed before BART here) must precede their parents.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (MarianConfig, MarianMTModel),
        (BartConfig, BartForConditionalGeneration),
        (EncoderDecoderConfig, EncoderDecoderModel),
    ]
)
291

292
# Configuration class -> model class with a sequence-classification head.
# Order matters if consumers pick the first ``isinstance`` match: subclassed configuration
# classes must precede their parents.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
        (BartConfig, BartForSequenceClassification),
        (LongformerConfig, LongformerForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
        (BertConfig, BertForSequenceClassification),
        (XLNetConfig, XLNetForSequenceClassification),
        (MobileBertConfig, MobileBertForSequenceClassification),
        (FlaubertConfig, FlaubertForSequenceClassification),
        (XLMConfig, XLMForSequenceClassification),
        (ElectraConfig, ElectraForSequenceClassification),
    ]
)
309

310
# Configuration class -> model class with an extractive question-answering head.
# Order matters if consumers pick the first ``isinstance`` match: subclassed configuration
# classes must precede their parents.
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (CamembertConfig, CamembertForQuestionAnswering),
        (BartConfig, BartForQuestionAnswering),
        (LongformerConfig, LongformerForQuestionAnswering),
        (XLMRobertaConfig, XLMRobertaForQuestionAnswering),
        (RobertaConfig, RobertaForQuestionAnswering),
        (BertConfig, BertForQuestionAnswering),
        (XLNetConfig, XLNetForQuestionAnsweringSimple),
        (FlaubertConfig, FlaubertForQuestionAnsweringSimple),
        (MobileBertConfig, MobileBertForQuestionAnswering),
        (XLMConfig, XLMForQuestionAnsweringSimple),
        (ElectraConfig, ElectraForQuestionAnswering),
        (ReformerConfig, ReformerForQuestionAnswering),
    ]
)
328

329
# Configuration class -> model class with a token-classification head.
# Order matters if consumers pick the first ``isinstance`` match: subclassed configuration
# classes must precede their parents. Each configuration class appears exactly once; the
# former trailing duplicate ``(FlaubertConfig, FlaubertForTokenClassification)`` entry only
# re-assigned the existing key (OrderedDict keeps the first position), so it was dropped.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (DistilBertConfig, DistilBertForTokenClassification),
        (CamembertConfig, CamembertForTokenClassification),
        (FlaubertConfig, FlaubertForTokenClassification),
        (XLMConfig, XLMForTokenClassification),
        (XLMRobertaConfig, XLMRobertaForTokenClassification),
        (LongformerConfig, LongformerForTokenClassification),
        (RobertaConfig, RobertaForTokenClassification),
        (BertConfig, BertForTokenClassification),
        (MobileBertConfig, MobileBertForTokenClassification),
        (XLNetConfig, XLNetForTokenClassification),
        (AlbertConfig, AlbertForTokenClassification),
        (ElectraConfig, ElectraForTokenClassification),
    ]
)
346

347
# Configuration class -> model class with a multiple-choice head.
# Order matters if consumers pick the first ``isinstance`` match: subclassed configuration
# classes must precede their parents. In particular ``FlaubertConfig`` derives from
# ``XLMConfig``, so Flaubert is listed before XLM (as in every other mapping in this
# module); otherwise a Flaubert configuration would resolve to ``XLMForMultipleChoice``.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    [
        (CamembertConfig, CamembertForMultipleChoice),
        (ElectraConfig, ElectraForMultipleChoice),
        (XLMRobertaConfig, XLMRobertaForMultipleChoice),
        (LongformerConfig, LongformerForMultipleChoice),
        (RobertaConfig, RobertaForMultipleChoice),
        (BertConfig, BertForMultipleChoice),
        (DistilBertConfig, DistilBertForMultipleChoice),
        (MobileBertConfig, MobileBertForMultipleChoice),
        (XLNetConfig, XLNetForMultipleChoice),
        (AlbertConfig, AlbertForMultipleChoice),
        (FlaubertConfig, FlaubertForMultipleChoice),
        (XLMConfig, XLMForMultipleChoice),
    ]
)
363

364

365
class AutoModel:
    r"""
        :class:`~transformers.AutoModel` is a generic model class
        that will be instantiated as one of the base model classes of the library
        when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModel.from_config(config)` class methods.

        The concrete class is the first entry of ``MODEL_MAPPING`` (scanned in order)
        whose configuration class the given configuration is an instance of.

        This class cannot be instantiated using `__init__()` (it raises an ``EnvironmentError``).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModel is designed to be instantiated "
            "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModel.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the base model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModel.from_pretrained` to load
            the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class
                (the first match in this order is used):

                - instance of `retribert` configuration class: :class:`~transformers.RetriBertModel` (RetriBERT model)
                - instance of `t5` configuration class: :class:`~transformers.T5Model` (T5 model)
                - instance of `distilbert` configuration class: :class:`~transformers.DistilBertModel` (DistilBERT model)
                - instance of `albert` configuration class: :class:`~transformers.AlbertModel` (ALBERT model)
                - instance of `camembert` configuration class: :class:`~transformers.CamembertModel` (CamemBERT model)
                - instance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
                - instance of `bart` configuration class: :class:`~transformers.BartModel` (BART model)
                - instance of `longformer` configuration class: :class:`~transformers.LongformerModel` (Longformer model)
                - instance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
                - instance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model)
                - instance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
                - instance of `gpt2` configuration class: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
                - instance of `mobilebert` configuration class: :class:`~transformers.MobileBertModel` (MobileBERT model)
                - instance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
                - instance of `xlnet` configuration class: :class:`~transformers.XLNetModel` (XLNet model)
                - instance of `flaubert` configuration class: :class:`~transformers.FlaubertModel` (Flaubert model)
                - instance of `xlm` configuration class: :class:`~transformers.XLMModel` (XLM model)
                - instance of `ctrl` configuration class: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
                - instance of `electra` configuration class: :class:`~transformers.ElectraModel` (Electra model)
                - instance of `reformer` configuration class: :class:`~transformers.ReformerModel` (Reformer model)

        Raises:
            ValueError: if the configuration class matches no entry of ``MODEL_MAPPING``.

        Examples::

            >>> config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            >>> model = AutoModel.from_config(config)  # Built from the configuration only; no pretrained weights are loaded.
        """
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the base model classes of the library
        from a pre-trained model configuration.

        The configuration is either the ``config`` keyword argument or, when that is absent or not a
        :class:`~transformers.PretrainedConfig` instance, the one loaded by
        :func:`~transformers.AutoConfig.from_pretrained` (which relies on the `model_type` property of the
        config, or when it's missing, on pattern matching on the `pretrained_model_name_or_path` string).
        The model class is then chosen from the configuration class:

            - `retribert`: :class:`~transformers.RetriBertModel` (RetriBERT model)
            - `t5`: :class:`~transformers.T5Model` (T5 model)
            - `distilbert`: :class:`~transformers.DistilBertModel` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertModel` (ALBERT model)
            - `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
            - `bart`: :class:`~transformers.BartModel` (BART model)
            - `longformer`: :class:`~transformers.LongformerModel` (Longformer model)
            - `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model)
            - `bert`: :class:`~transformers.BertModel` (Bert model)
            - `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
            - `gpt2`: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
            - `mobilebert`: :class:`~transformers.MobileBertModel` (MobileBERT model)
            - `transfo-xl`: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
            - `xlnet`: :class:`~transformers.XLNetModel` (XLNet model)
            - `flaubert`: :class:`~transformers.FlaubertModel` (Flaubert model)
            - `xlm`: :class:`~transformers.XLMModel` (XLM model)
            - `ctrl`: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
            - `electra`: :class:`~transformers.ElectraModel` (Electra model)
            - `reformer`: :class:`~transformers.ReformerModel` (Reformer model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration.
                A value that is not a :class:`~transformers.PretrainedConfig` instance is ignored and the
                configuration is loaded automatically. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                When ``config`` is not provided, these are first passed to
                :func:`~transformers.AutoConfig.from_pretrained`, and only the ones it reports as unused are
                forwarded to the model's ``from_pretrained``. When ``config`` is provided, all of them are
                forwarded to the model's ``from_pretrained``.

        Raises:
            ValueError: if the configuration class matches no entry of ``MODEL_MAPPING``.

        Examples::

            model = AutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No usable config given: load one and keep only the kwargs it did not consume.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
            )
        )
520

521

522
class AutoModelForPreTraining:
    r"""
        :class:`~transformers.AutoModelForPreTraining` is a generic model class
        that will be instantiated as one of the model classes of the library (with the architecture used
        for pretraining this model) when created with the
        `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)` or the
        `AutoModelForPreTraining.from_config(config)` class methods.

        The concrete class is the first entry of ``MODEL_FOR_PRETRAINING_MAPPING`` (scanned in order)
        whose configuration class the given configuration is an instance of.

        This class cannot be instantiated using `__init__()` (it raises an ``EnvironmentError``).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForPreTraining is designed to be instantiated "
            "using the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForPreTraining.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the model classes of the library (with the architecture used for
        pretraining this model) from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use
            :func:`~transformers.AutoModelForPreTraining.from_pretrained` to load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class
                (the first match in this order is used):

                - instance of `retribert` configuration class: :class:`~transformers.RetriBertModel` (RetriBERT model)
                - instance of `t5` configuration class: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
                - instance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - instance of `albert` configuration class: :class:`~transformers.AlbertForPreTraining` (ALBERT model)
                - instance of `camembert` configuration class: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
                - instance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
                - instance of `bart` configuration class: :class:`~transformers.BartForConditionalGeneration` (BART model)
                - instance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
                - instance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - instance of `bert` configuration class: :class:`~transformers.BertForPreTraining` (Bert model)
                - instance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - instance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - instance of `mobilebert` configuration class: :class:`~transformers.MobileBertForPreTraining` (MobileBERT model)
                - instance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - instance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - instance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
                - instance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - instance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - instance of `electra` configuration class: :class:`~transformers.ElectraForPreTraining` (Electra model)

        Raises:
            ValueError: if the configuration class matches no entry of ``MODEL_FOR_PRETRAINING_MAPPING``.

        Examples::

            >>> config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            >>> model = AutoModelForPreTraining.from_config(config)  # Built from the configuration only; no pretrained weights are loaded.
        """
        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the model classes of the library (with the architecture used for
        pretraining this model) from a pre-trained model configuration.

        The configuration is either the ``config`` keyword argument or, when that is absent or not a
        :class:`~transformers.PretrainedConfig` instance, the one loaded by
        :func:`~transformers.AutoConfig.from_pretrained` (which relies on the `model_type` property of the
        config, or when it's missing, on pattern matching on the `pretrained_model_name_or_path` string).
        The model class is then chosen from the configuration class:

            - `retribert`: :class:`~transformers.RetriBertModel` (RetriBERT model)
            - `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
            - `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertForPreTraining` (ALBERT model)
            - `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - `bart`: :class:`~transformers.BartForConditionalGeneration` (BART model)
            - `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
            - `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - `bert`: :class:`~transformers.BertForPreTraining` (Bert model)
            - `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - `mobilebert`: :class:`~transformers.MobileBertForPreTraining` (MobileBERT model)
            - `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
            - `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - `electra`: :class:`~transformers.ElectraForPreTraining` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration.
                A value that is not a :class:`~transformers.PretrainedConfig` instance is ignored and the
                configuration is loaded automatically. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                When ``config`` is not provided, these are first passed to
                :func:`~transformers.AutoConfig.from_pretrained`, and only the ones it reports as unused are
                forwarded to the model's ``from_pretrained``. When ``config`` is provided, all of them are
                forwarded to the model's ``from_pretrained``.

        Raises:
            ValueError: if the configuration class matches no entry of ``MODEL_FOR_PRETRAINING_MAPPING``.

        Examples::

            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForPreTraining.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForPreTraining.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No usable config given: load one and keep only the kwargs it did not consume.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())
            )
        )
669

670

671
class AutoModelWithLMHead:
    r"""
        :class:`~transformers.AutoModelWithLMHead` is a generic model class
        that will be instantiated as one of the language modeling model classes of the library
        when created with the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).

        This class is deprecated: both factory methods emit a ``FutureWarning``. Use
        :class:`~transformers.AutoModelForCausalLM` for causal language models,
        :class:`~transformers.AutoModelForMaskedLM` for masked language models and
        :class:`~transformers.AutoModelForSeq2SeqLM` for encoder-decoder models.
    """

    # Single source for the deprecation notice so ``from_config`` and
    # ``from_pretrained`` can never drift apart.
    _DEPRECATION_MESSAGE = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. "
        "Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language "
        "models and `AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    def __init__(self):
        raise EnvironmentError(
            "AutoModelWithLMHead is designed to be instantiated "
            "using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelWithLMHead.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the language modeling model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModelWithLMHead.from_pretrained`
            to load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `t5` configuration class: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
                - isInstance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
                - isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model)

        Raises:
            ValueError: if ``config`` is not an instance of any configuration class in ``MODEL_WITH_LM_HEAD_MAPPING``.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelWithLMHead.from_config(config)  # Model built from the configuration only; no pretrained weights are loaded.
        """
        # stacklevel=2 attributes the FutureWarning to the caller's line instead of this file.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        # The first isinstance() match wins, so the mapping's iteration order decides between
        # related configuration classes (presumably subclasses are listed first — TODO confirm).
        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
            - `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
            - `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)
            - `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
            - `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration.
                A value that is not a :class:`~transformers.PretrainedConfig` instance is ignored and the
                configuration is loaded from ``pretrained_model_name_or_path`` instead.
                Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                If ``config`` is not provided, these are first passed to :func:`~transformers.AutoConfig.from_pretrained`
                and the arguments it does not use are forwarded to the model; if ``config`` is provided,
                they are all forwarded to the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any configuration class in ``MODEL_WITH_LM_HEAD_MAPPING``.

        Examples::

            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelWithLMHead.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        # stacklevel=2 attributes the FutureWarning to the caller's line instead of this file.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # Load the configuration ourselves; unused kwargs come back for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
            )
        )
828

829

830
class AutoModelForCausalLM:
    r"""
        :class:`~transformers.AutoModelForCausalLM` is a generic model class
        that will be instantiated as one of the causal language modeling model classes of the library
        when created with the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForCausalLM is designed to be instantiated "
            "using the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForCausalLM.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the causal language modeling model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModelForCausalLM.from_pretrained`
            to load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `bert` configuration class: :class:`~transformers.BertLMHeadModel` (Bert model)
                - isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
                - isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
                - isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
                - isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
                - isInstance of `reformer` configuration class: :class:`~transformers.ReformerModelWithLMHead` (Reformer model)

        Raises:
            ValueError: if ``config`` is not an instance of any configuration class in ``MODEL_FOR_CAUSAL_LM_MAPPING``.

        Examples::

            config = GPT2Config.from_pretrained('gpt2')    # Download configuration from S3 and cache.
            model = AutoModelForCausalLM.from_config(config)  # Model built from the configuration only; no pretrained weights are loaded.
        """
        # The first isinstance() match wins, so the mapping's iteration order decides between
        # related configuration classes (presumably subclasses are listed first — TODO confirm).
        for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the causal language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `bert`: :class:`~transformers.BertLMHeadModel` (Bert model)
            - `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
            - `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
            - `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
            - `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
            - `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
            - `reformer`: :class:`~transformers.ReformerModelWithLMHead` (Google Reformer model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration.
                A value that is not a :class:`~transformers.PretrainedConfig` instance is ignored and the
                configuration is loaded from ``pretrained_model_name_or_path`` instead.
                Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                If ``config`` is not provided, these are first passed to :func:`~transformers.AutoConfig.from_pretrained`
                and the arguments it does not use are forwarded to the model; if ``config`` is provided,
                they are all forwarded to the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any configuration class in ``MODEL_FOR_CAUSAL_LM_MAPPING``.

        Examples::

            model = AutoModelForCausalLM.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
            model = AutoModelForCausalLM.from_pretrained('./test/gpt2_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForCausalLM.from_pretrained('gpt2', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/gpt2_tf_model_config.json')
            model = AutoModelForCausalLM.from_pretrained('./tf_model/gpt2_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # Load the configuration ourselves; unused kwargs come back for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())
            )
        )
965

966

967
class AutoModelForMaskedLM:
    r"""
        :class:`~transformers.AutoModelForMaskedLM` is a generic model class
        that will be instantiated as one of the masked language modeling model classes of the library
        when created with the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)`
        class method.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForMaskedLM is designed to be instantiated "
            "using the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForMaskedLM.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the masked language modeling model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModelForMaskedLM.from_pretrained`
            to load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
                - isInstance of `longformer` configuration class: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
                - isInstance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-Roberta model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForMaskedLM` (Camembert model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForMaskedLM` (Albert model)

        Raises:
            ValueError: if ``config`` is not an instance of any configuration class in ``MODEL_FOR_MASKED_LM_MAPPING``.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForMaskedLM.from_config(config)  # Model built from the configuration only; no pretrained weights are loaded.
        """
        # The first isinstance() match wins, so the mapping's iteration order decides between
        # related configuration classes (presumably subclasses are listed first — TODO confirm).
        for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the masked language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
            - `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
            - `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
            - `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
            - `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
            - `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
            - `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)
            - `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration.
                A value that is not a :class:`~transformers.PretrainedConfig` instance is ignored and the
                configuration is loaded from ``pretrained_model_name_or_path`` instead.
                Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                If ``config`` is not provided, these are first passed to :func:`~transformers.AutoConfig.from_pretrained`
                and the arguments it does not use are forwarded to the model; if ``config`` is provided,
                they are all forwarded to the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any configuration class in ``MODEL_FOR_MASKED_LM_MAPPING``.

        Examples::

            model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForMaskedLM.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForMaskedLM.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # Load the configuration ourselves; unused kwargs come back for the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())
            )
        )
1108

1109

1110
class AutoModelForSeq2SeqLM:
    r"""
        :class:`~transformers.AutoModelForSeq2SeqLM` is a generic model class
        that will be instantiated as one of the sequence-to-sequence language modeling model classes of the library
        when created with the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModelForSeq2SeqLM.from_config(config)` class methods.

        This class cannot be instantiated using `__init__()` (it raises an ``EnvironmentError``).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForSeq2SeqLM is designed to be instantiated "
            "using the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSeq2SeqLM.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the sequence-to-sequence language modeling model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use :func:`~transformers.AutoModelForSeq2SeqLM.from_pretrained`
            to load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                - instance of `t5` configuration class: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
                - instance of `bart` configuration class: :class:`~transformers.BartForConditionalGeneration` (Bart model)
                - instance of `marian` configuration class: :class:`~transformers.MarianMTModel` (Marian model)
                - instance of `encoder-decoder` configuration class: :class:`~transformers.EncoderDecoderModel` (Encoder Decoder model)

        Returns:
            An instance of the model class registered for ``type(config)`` in
            ``MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING``, built as ``model_class(config)``.

        Raises:
            ValueError: if ``config`` is not an instance of any configuration class registered in
            ``MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING``.

        Examples::

            config = T5Config.from_pretrained('t5-base')    # Download configuration from S3 and cache.
            model = AutoModelForSeq2SeqLM.from_config(config)  # Builds the architecture only; no pretrained weights are loaded.
        """
        # The first isinstance match wins, so the iteration order of the mapping
        # matters if one configuration class subclasses another.
        for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the sequence-to-sequence language modeling model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:

            - `t5`: :class:`~transformers.T5ForConditionalGeneration` (T5 model)
            - `bart`: :class:`~transformers.BartForConditionalGeneration` (Bart model)
            - `marian`: :class:`~transformers.MarianMTModel` (Marian model)
            - `encoder-decoder`: :class:`~transformers.EncoderDecoderModel` (Encoder Decoder model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``t5-base``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``facebook/bart-large``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration.
                If the value passed is not a :class:`~transformers.PretrainedConfig` instance, it is ignored and the
                configuration is loaded from ``pretrained_model_name_or_path`` instead.
                Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                When no ``config`` instance is given, they are first passed to
                :func:`~transformers.AutoConfig.from_pretrained` and only the ones it does not consume are forwarded
                to the model; when a ``config`` instance is given, they are all forwarded to the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any configuration class registered in
            ``MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING``.

        Examples::

            model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')    # Download model and configuration from S3 and cache.
            model = AutoModelForSeq2SeqLM.from_pretrained('./test/t5_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForSeq2SeqLM.from_pretrained('t5-base', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/t5_tf_model_config.json')
            model = AutoModelForSeq2SeqLM.from_pretrained('./tf_model/t5_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        # Anything that is not a ready-made config object (including None) is
        # replaced by an auto-loaded one; unused kwargs are handed back for the model.
        if not isinstance(config, PretrainedConfig):
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # First isinstance match wins; see from_config for why order matters.
        for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()),
            )
        )
1243

1244

1245
class AutoModelForSequenceClassification:
    r"""
        :class:`~transformers.AutoModelForSequenceClassification` is a generic model class
        that will be instantiated as one of the sequence classification model classes of the library
        when created with the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModelForSequenceClassification.from_config(config)` class methods.

        This class cannot be instantiated using `__init__()` (it raises an ``EnvironmentError``).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForSequenceClassification is designed to be instantiated "
            "using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForSequenceClassification.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the sequence classification model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use
            :func:`~transformers.AutoModelForSequenceClassification.from_pretrained` to load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class
                (dispatch is driven by ``MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING``; this list may not be exhaustive):

                - instance of `distilbert` configuration class: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
                - instance of `albert` configuration class: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
                - instance of `camembert` configuration class: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
                - instance of `xlm-roberta` configuration class: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
                - instance of `roberta` configuration class: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
                - instance of `bert` configuration class: :class:`~transformers.BertForSequenceClassification` (Bert model)
                - instance of `xlnet` configuration class: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
                - instance of `xlm` configuration class: :class:`~transformers.XLMForSequenceClassification` (XLM model)
                - instance of `flaubert` configuration class: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)

        Returns:
            An instance of the model class registered for ``type(config)`` in
            ``MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING``, built as ``model_class(config)``.

        Raises:
            ValueError: if ``config`` is not an instance of any configuration class registered in
            ``MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING``.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_config(config)  # Builds the architecture only; no pretrained weights are loaded.
        """
        # The first isinstance match wins, so the iteration order of the mapping
        # matters if one configuration class subclasses another.
        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the sequence classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string
        (dispatch is driven by ``MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING``; this list may not be exhaustive):

            - `distilbert`: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
            - `camembert`: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
            - `roberta`: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
            - `bert`: :class:`~transformers.BertForSequenceClassification` (Bert model)
            - `xlnet`: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
            - `xlm`: :class:`~transformers.XLMForSequenceClassification` (XLM model)
            - `flaubert`: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration.
                If the value passed is not a :class:`~transformers.PretrainedConfig` instance, it is ignored and the
                configuration is loaded from ``pretrained_model_name_or_path`` instead.
                Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                When no ``config`` instance is given, they are first passed to
                :func:`~transformers.AutoConfig.from_pretrained` and only the ones it does not consume are forwarded
                to the model; when a ``config`` instance is given, they are all forwarded to the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any configuration class registered in
            ``MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING``.

        Examples::

            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        # Anything that is not a ready-made config object (including None) is
        # replaced by an auto-loaded one; unused kwargs are handed back for the model.
        if not isinstance(config, PretrainedConfig):
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # First isinstance match wins; see from_config for why order matters.
        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
            )
        )
1395

1396

1397
class AutoModelForQuestionAnswering:
    r"""
        :class:`~transformers.AutoModelForQuestionAnswering` is a generic model class
        that will be instantiated as one of the question answering model classes of the library
        when created with the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModelForQuestionAnswering.from_config(config)` class methods.

        This class cannot be instantiated using `__init__()` (it raises an ``EnvironmentError``).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForQuestionAnswering is designed to be instantiated "
            "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForQuestionAnswering.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the question answering model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use
            :func:`~transformers.AutoModelForQuestionAnswering.from_pretrained` to load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class
                (dispatch is driven by ``MODEL_FOR_QUESTION_ANSWERING_MAPPING``; this list may not be exhaustive):

                - instance of `distilbert` configuration class: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
                - instance of `albert` configuration class: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
                - instance of `bert` configuration class: :class:`~transformers.BertForQuestionAnswering` (Bert model)
                - instance of `xlnet` configuration class: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
                - instance of `xlm` configuration class: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
                - instance of `flaubert` configuration class: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        Returns:
            An instance of the model class registered for ``type(config)`` in
            ``MODEL_FOR_QUESTION_ANSWERING_MAPPING``, built as ``model_class(config)``.

        Raises:
            ValueError: if ``config`` is not an instance of any configuration class registered in
            ``MODEL_FOR_QUESTION_ANSWERING_MAPPING``.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_config(config)  # Builds the architecture only; no pretrained weights are loaded.
        """
        # The first isinstance match wins, so the iteration order of the mapping
        # matters if one configuration class subclasses another.
        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class(config)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the question answering model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string
        (dispatch is driven by ``MODEL_FOR_QUESTION_ANSWERING_MAPPING``; this list may not be exhaustive):

            - `distilbert`: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
            - `albert`: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
            - `bert`: :class:`~transformers.BertForQuestionAnswering` (Bert model)
            - `xlnet`: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
            - `xlm`: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
            - `flaubert`: :class:`~transformers.FlaubertForQuestionAnswering` (Flaubert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration.
                If the value passed is not a :class:`~transformers.PretrainedConfig` instance, it is ignored and the
                configuration is loaded from ``pretrained_model_name_or_path`` instead.
                Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                When no ``config`` instance is given, they are first passed to
                :func:`~transformers.AutoConfig.from_pretrained` and only the ones it does not consume are forwarded
                to the model; when a ``config`` instance is given, they are all forwarded to the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any configuration class registered in
            ``MODEL_FOR_QUESTION_ANSWERING_MAPPING``.

        Examples::

            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForQuestionAnswering.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        # Anything that is not a ready-made config object (including None) is
        # replaced by an auto-loaded one; unused kwargs are handed back for the model.
        if not isinstance(config, PretrainedConfig):
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # First isinstance match wins; see from_config for why order matters.
        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
            if isinstance(config, config_class):
                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
            )
        )
1540

1541

1542
class AutoModelForTokenClassification:
    r"""
        :class:`~transformers.AutoModelForTokenClassification` is a generic model class
        that will be instantiated as one of the token classification model classes of the library
        when created with the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)`
        or the `AutoModelForTokenClassification.from_config(config)` class methods.

        This class cannot be instantiated using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForTokenClassification is designed to be instantiated "
            "using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForTokenClassification.from_config(config)` methods."
        )

    @classmethod
    def _get_model_class(cls, config):
        """Return the token classification model class registered for ``config``.

        ``MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING`` is scanned in order and the first entry whose
        configuration class ``config`` is an instance of wins. The mapping order therefore decides
        which model class is picked when ``config`` matches several entries (e.g. through subclassing).

        Raises:
            ValueError: if no configuration class in the mapping matches ``config``.
        """
        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
            if isinstance(config, config_class):
                return model_class

        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
            )
        )

    @classmethod
    def from_config(cls, config):
        r""" Instantiates one of the token classification model classes of the library
        from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights.
            It only affects the model's configuration. Use
            :func:`~transformers.AutoModelForTokenClassification.from_pretrained` to load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class
                (see ``MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING`` for the authoritative list), e.g.:

                - isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
                - isInstance of `xlm` configuration class: :class:`~transformers.XLMForTokenClassification` (XLM model)
                - isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
                - isInstance of `bert` configuration class: :class:`~transformers.BertForTokenClassification` (Bert model)
                - isInstance of `albert` configuration class: :class:`~transformers.AlbertForTokenClassification` (ALBERT model)
                - isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
                - isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForTokenClassification` (Flaubert model)
                - isInstance of `camembert` configuration class: :class:`~transformers.CamembertForTokenClassification` (Camembert model)
                - isInstance of `roberta` configuration class: :class:`~transformers.RobertaForTokenClassification` (Roberta model)
                - isInstance of `electra` configuration class: :class:`~transformers.ElectraForTokenClassification` (Electra model)

        Returns:
            An instance of the matching model class, built as ``model_class(config)``.

        Raises:
            ValueError: if ``config`` is not an instance of any configuration class in the mapping.

        Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_config(config)  # Model built from the config only; pretrained weights are NOT loaded.
        """
        model_class = cls._get_model_class(config)
        return model_class(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Instantiates one of the token classification model classes of the library
        from a pre-trained model configuration.

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string
        (see ``MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING`` for the authoritative list), e.g.:

            - `distilbert`: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
            - `xlm`: :class:`~transformers.XLMForTokenClassification` (XLM model)
            - `xlm-roberta`: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
            - `camembert`: :class:`~transformers.CamembertForTokenClassification` (Camembert model)
            - `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)
            - `albert`: :class:`~transformers.AlbertForTokenClassification` (ALBERT model)
            - `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
            - `flaubert`: :class:`~transformers.FlaubertForTokenClassification` (Flaubert model)
            - `roberta`: :class:`~transformers.RobertaForTokenClassification` (Roberta model)
            - `electra`: :class:`~transformers.ElectraForTokenClassification` (Electra model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

        Args:
            pretrained_model_name_or_path:
                Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                These arguments will be passed to the configuration and the model.

        Raises:
            ValueError: if the resolved configuration is not an instance of any configuration class in the mapping.

        Examples::

            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = AutoModelForTokenClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
            model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # No usable config was supplied: load one. With return_unused_kwargs=True, AutoConfig
            # hands back the kwargs it did not consume, and those are forwarded to the model.
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        model_class = cls._get_model_class(config)
        return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)

1693

1694
class AutoModelForMultipleChoice:
    r"""
        :class:`~transformers.AutoModelForMultipleChoice` is a generic model class that resolves to
        one of the library's multiple choice model classes. Build it through the
        `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)` or
        `AutoModelForMultipleChoice.from_config(config)` class methods.

        Calling `__init__()` directly is not supported and raises an error.
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForMultipleChoice is designed to be instantiated "
            "using the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForMultipleChoice.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config):
        r""" Builds a multiple choice model from ``config`` alone (no pretrained weights are loaded).

        The model class is the first entry of ``MODEL_FOR_MULTIPLE_CHOICE_MAPPING`` whose
        configuration class ``config`` is an instance of; it is constructed as ``model_class(config)``.

        Raises:
            ValueError: if no configuration class in the mapping matches ``config``.
        """
        not_found = object()
        chosen = next(
            (
                candidate
                for cfg_type, candidate in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items()
                if isinstance(config, cfg_type)
            ),
            not_found,
        )
        if chosen is not_found:
            supported = ", ".join(cfg_type.__name__ for cfg_type in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys())
            raise ValueError(
                "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
                "Model type should be one of {}.".format(config.__class__, cls.__name__, supported)
            )
        return chosen(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r""" Loads a pretrained multiple choice model.

        If ``config`` (a :class:`~transformers.PretrainedConfig`) is passed as a keyword argument it is
        used as-is; otherwise the configuration is loaded with ``AutoConfig.from_pretrained`` and only
        the keyword arguments it leaves unused are forwarded to the model. The model class is the first
        entry of ``MODEL_FOR_MULTIPLE_CHOICE_MAPPING`` whose configuration class matches, and its own
        ``from_pretrained`` is called with ``model_args``, ``config`` and the remaining kwargs.

        Raises:
            ValueError: if no configuration class in the mapping matches the resolved configuration.
        """
        config = kwargs.pop("config", None)
        if isinstance(config, PretrainedConfig):
            forwarded = kwargs
        else:
            config, forwarded = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        not_found = object()
        chosen = next(
            (
                candidate
                for cfg_type, candidate in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items()
                if isinstance(config, cfg_type)
            ),
            not_found,
        )
        if chosen is not_found:
            supported = ", ".join(cfg_type.__name__ for cfg_type in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys())
            raise ValueError(
                "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
                "Model type should be one of {}.".format(config.__class__, cls.__name__, supported)
            )
        return chosen.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **forwarded)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.