CSS-LM

Форк
0
/
__init__.py 
671 строка · 21.8 Кб
1
# flake8: noqa
2
# There's no way to ignore only the "F401 '...' imported but unused" warnings in
# this module while still reporting other warnings, so don't check this module at all.
4

5
# Version string of this package, exposed as the module attribute ``__version__``.
__version__ = "3.0.2"
6

7
# Workaround: when absl (pulled in by TensorFlow) is installed, its logging setup
# alters the default behavior of Python's standard logging output. Explicitly set
# absl's verbosity and stderr threshold to "info" here.
# see: https://github.com/abseil/abseil-py/issues/99
# and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
try:
    import absl.logging
except ImportError:
    # absl is optional; without it there is nothing to reconfigure.
    pass
else:
    _absl_logging = absl.logging
    _absl_logging.set_verbosity("info")
    _absl_logging.set_stderrthreshold("info")
    # Private absl attribute; presumably silences absl's "logging before flag
    # parsing" stderr notice -- confirm against the installed absl version.
    _absl_logging._warn_preinit_stderr = False
    # Drop the temporary alias so the package namespace only gains ``absl``.
    del _absl_logging
19

20
import logging
21

22
# Configurations
23
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
24
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
25
from .configuration_bart import BartConfig, MBartConfig
26
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
27
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
28
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
29
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
30
from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
31
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
32
from .configuration_encoder_decoder import EncoderDecoderConfig
33
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
34
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
35
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
36
from .configuration_marian import MarianConfig
37
from .configuration_mmbt import MMBTConfig
38
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
39
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
40
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
41
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
42
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
43
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
44
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
45
from .configuration_utils import PretrainedConfig
46
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
47
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
48
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
49
from .data import (
50
    DataProcessor,
51
    InputExample,
52
    InputFeatures,
53
    SingleSentenceClassificationProcessor,
54
    SquadExample,
55
    SquadFeatures,
56
    SquadV1Processor,
57
    SquadV2Processor,
58
    glue_convert_examples_to_features,
59
    glue_output_modes,
60
    glue_processors,
61
    glue_tasks_num_labels,
62
    is_sklearn_available,
63
    squad_convert_examples_to_features,
64
    xnli_output_modes,
65
    xnli_processors,
66
    xnli_tasks_num_labels,
67
)
68

69
# Files and general utilities
70
from .file_utils import (
71
    CONFIG_NAME,
72
    MODEL_CARD_NAME,
73
    PYTORCH_PRETRAINED_BERT_CACHE,
74
    PYTORCH_TRANSFORMERS_CACHE,
75
    TF2_WEIGHTS_NAME,
76
    TF_WEIGHTS_NAME,
77
    TRANSFORMERS_CACHE,
78
    WEIGHTS_NAME,
79
    add_end_docstrings,
80
    add_start_docstrings,
81
    cached_path,
82
    is_apex_available,
83
    is_psutil_available,
84
    is_py3nvml_available,
85
    is_tf_available,
86
    is_torch_available,
87
    is_torch_tpu_available,
88
)
89
from .hf_argparser import HfArgumentParser
90

91
# Model Cards
92
from .modelcard import ModelCard
93

94
# TF 2.0 <=> PyTorch conversion utilities
95
from .modeling_tf_pytorch_utils import (
96
    convert_tf_weight_name_to_pt_weight_name,
97
    load_pytorch_checkpoint_in_tf2_model,
98
    load_pytorch_model_in_tf2_model,
99
    load_pytorch_weights_in_tf2_model,
100
    load_tf2_checkpoint_in_pytorch_model,
101
    load_tf2_model_in_pytorch_model,
102
    load_tf2_weights_in_pytorch_model,
103
)
104

105
# Pipelines
106
from .pipelines import (
107
    Conversation,
108
    ConversationalPipeline,
109
    CsvPipelineDataFormat,
110
    FeatureExtractionPipeline,
111
    FillMaskPipeline,
112
    JsonPipelineDataFormat,
113
    NerPipeline,
114
    PipedPipelineDataFormat,
115
    Pipeline,
116
    PipelineDataFormat,
117
    QuestionAnsweringPipeline,
118
    SummarizationPipeline,
119
    TextClassificationPipeline,
120
    TextGenerationPipeline,
121
    TokenClassificationPipeline,
122
    TranslationPipeline,
123
    pipeline,
124
)
125

126
# Tokenizers
127
from .tokenization_albert import AlbertTokenizer
128
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
129
from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer
130
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
131
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
132
from .tokenization_camembert import CamembertTokenizer
133
from .tokenization_ctrl import CTRLTokenizer
134
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
135
from .tokenization_dpr import (
136
    DPRContextEncoderTokenizer,
137
    DPRContextEncoderTokenizerFast,
138
    DPRQuestionEncoderTokenizer,
139
    DPRQuestionEncoderTokenizerFast,
140
    DPRReaderTokenizer,
141
    DPRReaderTokenizerFast,
142
)
143
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
144
from .tokenization_flaubert import FlaubertTokenizer
145
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
146
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
147
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
148
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
149
from .tokenization_reformer import ReformerTokenizer
150
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
151
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
152
from .tokenization_t5 import T5Tokenizer
153
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
154
from .tokenization_utils import PreTrainedTokenizer
155
from .tokenization_utils_base import (
156
    BatchEncoding,
157
    CharSpan,
158
    PreTrainedTokenizerBase,
159
    SpecialTokensMixin,
160
    TensorType,
161
    TokenSpan,
162
)
163
from .tokenization_utils_fast import PreTrainedTokenizerFast
164
from .tokenization_xlm import XLMTokenizer
165
from .tokenization_xlm_roberta import XLMRobertaTokenizer
166
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
167

168
# Trainer
169
from .trainer_utils import EvalPrediction, set_seed
170
from .training_args import TrainingArguments
171
from .training_args_tf import TFTrainingArguments
172

173

174
# Logger named after this module; used below for import-time warnings.
logger = logging.getLogger(name=__name__)  # pylint: disable=invalid-name
175

176

177
if is_sklearn_available():
178
    from .data import glue_compute_metrics, xnli_compute_metrics
179

180

181
# Modeling
182
if is_torch_available():
183
    from .generation_utils import top_k_top_p_filtering
184
    from .modeling_utils import PreTrainedModel, prune_layer, Conv1D, apply_chunking_to_forward
185
    from .modeling_auto import (
186
        AutoModel,
187
        AutoModelForPreTraining,
188
        AutoModelForSequenceClassification,
189
        AutoModelForQuestionAnswering,
190
        AutoModelWithLMHead,
191
        AutoModelForCausalLM,
192
        AutoModelForMaskedLM,
193
        AutoModelForSeq2SeqLM,
194
        AutoModelForTokenClassification,
195
        AutoModelForMultipleChoice,
196
        MODEL_MAPPING,
197
        MODEL_FOR_PRETRAINING_MAPPING,
198
        MODEL_WITH_LM_HEAD_MAPPING,
199
        MODEL_FOR_CAUSAL_LM_MAPPING,
200
        MODEL_FOR_MASKED_LM_MAPPING,
201
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
202
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
203
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
204
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
205
        MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
206
    )
207

208
    from .modeling_mobilebert import (
209
        MobileBertPreTrainedModel,
210
        MobileBertModel,
211
        MobileBertForPreTraining,
212
        MobileBertForSequenceClassification,
213
        MobileBertForQuestionAnswering,
214
        MobileBertForMaskedLM,
215
        MobileBertForNextSentencePrediction,
216
        MobileBertForMultipleChoice,
217
        MobileBertForTokenClassification,
218
        load_tf_weights_in_mobilebert,
219
        MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
220
        MobileBertLayer,
221
    )
222

223
    from .modeling_bert import (
224
        BertPreTrainedModel,
225
        BertModel,
226
        BertForPreTraining,
227
        BertForMaskedLM,
228
        BertLMHeadModel,
229
        BertForNextSentencePrediction,
230
        BertForSequenceClassification,
231
        BertForMultipleChoice,
232
        BertForTokenClassification,
233
        BertForQuestionAnswering,
234
        load_tf_weights_in_bert,
235
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
236
        BertLayer,
237
    )
238
    from .modeling_openai import (
239
        OpenAIGPTPreTrainedModel,
240
        OpenAIGPTModel,
241
        OpenAIGPTLMHeadModel,
242
        OpenAIGPTDoubleHeadsModel,
243
        load_tf_weights_in_openai_gpt,
244
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
245
    )
246
    from .modeling_transfo_xl import (
247
        TransfoXLPreTrainedModel,
248
        TransfoXLModel,
249
        TransfoXLLMHeadModel,
250
        AdaptiveEmbedding,
251
        load_tf_weights_in_transfo_xl,
252
        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
253
    )
254
    from .modeling_gpt2 import (
255
        GPT2PreTrainedModel,
256
        GPT2Model,
257
        GPT2LMHeadModel,
258
        GPT2DoubleHeadsModel,
259
        load_tf_weights_in_gpt2,
260
        GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
261
    )
262
    from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
263
    from .modeling_xlnet import (
264
        XLNetPreTrainedModel,
265
        XLNetModel,
266
        XLNetLMHeadModel,
267
        XLNetForSequenceClassification,
268
        XLNetForTokenClassification,
269
        XLNetForMultipleChoice,
270
        XLNetForQuestionAnsweringSimple,
271
        XLNetForQuestionAnswering,
272
        load_tf_weights_in_xlnet,
273
        XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
274
    )
275
    from .modeling_xlm import (
276
        XLMPreTrainedModel,
277
        XLMModel,
278
        XLMWithLMHeadModel,
279
        XLMForSequenceClassification,
280
        XLMForTokenClassification,
281
        XLMForQuestionAnswering,
282
        XLMForQuestionAnsweringSimple,
283
        XLMForMultipleChoice,
284
        XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
285
    )
286
    from .modeling_bart import (
287
        PretrainedBartModel,
288
        BartForSequenceClassification,
289
        BartModel,
290
        BartForConditionalGeneration,
291
        BartForQuestionAnswering,
292
        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
293
    )
294
    from .modeling_marian import MarianMTModel
295
    from .tokenization_marian import MarianTokenizer
296
    from .modeling_roberta import (
297
        RobertaForMaskedLM,
298
        RobertaModel,
299
        RobertaForSequenceClassification,
300
        RobertaForMultipleChoice,
301
        RobertaForTokenClassification,
302
        RobertaForQuestionAnswering,
303
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
304
    )
305
    from .modeling_distilbert import (
306
        DistilBertPreTrainedModel,
307
        DistilBertForMaskedLM,
308
        DistilBertModel,
309
        DistilBertForMultipleChoice,
310
        DistilBertForSequenceClassification,
311
        DistilBertForQuestionAnswering,
312
        DistilBertForTokenClassification,
313
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
314
    )
315
    from .modeling_camembert import (
316
        CamembertForMaskedLM,
317
        CamembertModel,
318
        CamembertForSequenceClassification,
319
        CamembertForMultipleChoice,
320
        CamembertForTokenClassification,
321
        CamembertForQuestionAnswering,
322
        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
323
    )
324
    from .modeling_encoder_decoder import EncoderDecoderModel
325
    from .modeling_t5 import (
326
        T5PreTrainedModel,
327
        T5Model,
328
        T5ForConditionalGeneration,
329
        load_tf_weights_in_t5,
330
        T5_PRETRAINED_MODEL_ARCHIVE_LIST,
331
    )
332
    from .modeling_albert import (
333
        AlbertPreTrainedModel,
334
        AlbertModel,
335
        AlbertForPreTraining,
336
        AlbertForMaskedLM,
337
        AlbertForMultipleChoice,
338
        AlbertForSequenceClassification,
339
        AlbertForQuestionAnswering,
340
        AlbertForTokenClassification,
341
        load_tf_weights_in_albert,
342
        ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
343
    )
344
    from .modeling_xlm_roberta import (
345
        XLMRobertaForMaskedLM,
346
        XLMRobertaModel,
347
        XLMRobertaForMultipleChoice,
348
        XLMRobertaForSequenceClassification,
349
        XLMRobertaForTokenClassification,
350
        XLMRobertaForQuestionAnswering,
351
        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
352
    )
353
    from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
354

355
    # FlauBERT PyTorch model classes; each exported name is listed exactly once.
    from .modeling_flaubert import (
        FlaubertModel,
        FlaubertWithLMHeadModel,
        FlaubertForSequenceClassification,
        FlaubertForTokenClassification,
        FlaubertForQuestionAnswering,
        FlaubertForQuestionAnsweringSimple,
        FlaubertForMultipleChoice,
        FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
366

367
    from .modeling_electra import (
368
        ElectraForPreTraining,
369
        ElectraForMaskedLM,
370
        ElectraForTokenClassification,
371
        ElectraPreTrainedModel,
372
        ElectraForMultipleChoice,
373
        ElectraForSequenceClassification,
374
        ElectraForQuestionAnswering,
375
        ElectraModel,
376
        load_tf_weights_in_electra,
377
        ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
378
    )
379

380
    from .modeling_reformer import (
381
        ReformerAttention,
382
        ReformerLayer,
383
        ReformerModel,
384
        ReformerForMaskedLM,
385
        ReformerModelWithLMHead,
386
        ReformerForSequenceClassification,
387
        ReformerForQuestionAnswering,
388
        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
389
    )
390

391
    from .modeling_longformer import (
392
        LongformerModel,
393
        LongformerForMaskedLM,
394
        LongformerForSequenceClassification,
395
        LongformerForMultipleChoice,
396
        LongformerForTokenClassification,
397
        LongformerForQuestionAnswering,
398
        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
399
    )
400

401
    from .modeling_dpr import (
402
        DPRPretrainedContextEncoder,
403
        DPRPretrainedQuestionEncoder,
404
        DPRPretrainedReader,
405
        DPRContextEncoder,
406
        DPRQuestionEncoder,
407
        DPRReader,
408
    )
409
    from .modeling_retribert import (
410
        RetriBertPreTrainedModel,
411
        RetriBertModel,
412
        RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
413
    )
414

415
    # Optimization
416
    from .optimization import (
417
        AdamW,
418
        get_constant_schedule,
419
        get_constant_schedule_with_warmup,
420
        get_cosine_schedule_with_warmup,
421
        get_cosine_with_hard_restarts_schedule_with_warmup,
422
        get_linear_schedule_with_warmup,
423
    )
424

425
    # Trainer
426
    from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
427
    from .data.data_collator import (
428
        default_data_collator,
429
        DataCollator,
430
        DataCollatorForLanguageModeling,
431
        DataCollatorForPermutationLanguageModeling,
432
    )
433
    from .data.datasets import (
434
        GlueDataset,
435
        TextDataset,
436
        LineByLineTextDataset,
437
        GlueDataTrainingArguments,
438
        SquadDataset,
439
        SquadDataTrainingArguments,
440
    )
441

442
    # Benchmarks
443
    from .benchmark.benchmark import PyTorchBenchmark
444
    from .benchmark.benchmark_args import PyTorchBenchmarkArguments
445

446
# TensorFlow
447
if is_tf_available():
448
    from .generation_tf_utils import tf_top_k_top_p_filtering
449
    from .modeling_tf_utils import (
450
        shape_list,
451
        TFPreTrainedModel,
452
        TFSequenceSummary,
453
        TFSharedEmbeddings,
454
    )
455
    from .modeling_tf_auto import (
456
        TF_MODEL_MAPPING,
457
        TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
458
        TF_MODEL_FOR_PRETRAINING_MAPPING,
459
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
460
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
461
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
462
        TF_MODEL_WITH_LM_HEAD_MAPPING,
463
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
464
        TF_MODEL_FOR_MASKED_LM_MAPPING,
465
        TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
466
        TFAutoModel,
467
        TFAutoModelForMultipleChoice,
468
        TFAutoModelForPreTraining,
469
        TFAutoModelForQuestionAnswering,
470
        TFAutoModelForSequenceClassification,
471
        TFAutoModelForTokenClassification,
472
        TFAutoModelWithLMHead,
473
        TFAutoModelForCausalLM,
474
        TFAutoModelForMaskedLM,
475
        TFAutoModelForSeq2SeqLM,
476
    )
477

478
    from .modeling_tf_albert import (
479
        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
480
        TFAlbertForMaskedLM,
481
        TFAlbertForMultipleChoice,
482
        TFAlbertForPreTraining,
483
        TFAlbertForQuestionAnswering,
484
        TFAlbertForSequenceClassification,
485
        TFAlbertForTokenClassification,
486
        TFAlbertMainLayer,
487
        TFAlbertModel,
488
        TFAlbertPreTrainedModel,
489
    )
490

491
    from .modeling_tf_bert import (
492
        TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
493
        TFBertEmbeddings,
494
        TFBertLMHeadModel,
495
        TFBertForMaskedLM,
496
        TFBertForMultipleChoice,
497
        TFBertForNextSentencePrediction,
498
        TFBertForPreTraining,
499
        TFBertForQuestionAnswering,
500
        TFBertForSequenceClassification,
501
        TFBertForTokenClassification,
502
        TFBertMainLayer,
503
        TFBertModel,
504
        TFBertPreTrainedModel,
505
    )
506

507
    from .modeling_tf_camembert import (
508
        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
509
        TFCamembertForMaskedLM,
510
        TFCamembertModel,
511
        TFCamembertForMultipleChoice,
512
        TFCamembertForQuestionAnswering,
513
        TFCamembertForSequenceClassification,
514
        TFCamembertForTokenClassification,
515
    )
516

517
    from .modeling_tf_ctrl import (
518
        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
519
        TFCTRLLMHeadModel,
520
        TFCTRLModel,
521
        TFCTRLPreTrainedModel,
522
    )
523

524
    from .modeling_tf_distilbert import (
525
        TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
526
        TFDistilBertForMaskedLM,
527
        TFDistilBertForMultipleChoice,
528
        TFDistilBertForQuestionAnswering,
529
        TFDistilBertForSequenceClassification,
530
        TFDistilBertForTokenClassification,
531
        TFDistilBertMainLayer,
532
        TFDistilBertModel,
533
        TFDistilBertPreTrainedModel,
534
    )
535

536
    from .modeling_tf_electra import (
537
        TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
538
        TFElectraForMaskedLM,
539
        TFElectraForPreTraining,
540
        TFElectraForQuestionAnswering,
541
        TFElectraForTokenClassification,
542
        TFElectraModel,
543
        TFElectraPreTrainedModel,
544
    )
545

546
    from .modeling_tf_flaubert import (
547
        TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
548
        TFFlaubertForMultipleChoice,
549
        TFFlaubertForQuestionAnsweringSimple,
550
        TFFlaubertForSequenceClassification,
551
        TFFlaubertForTokenClassification,
552
        TFFlaubertWithLMHeadModel,
553
        TFFlaubertModel,
554
    )
555

556
    from .modeling_tf_gpt2 import (
557
        TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
558
        TFGPT2DoubleHeadsModel,
559
        TFGPT2LMHeadModel,
560
        TFGPT2MainLayer,
561
        TFGPT2Model,
562
        TFGPT2PreTrainedModel,
563
    )
564

565
    from .modeling_tf_mobilebert import (
566
        TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
567
        TFMobileBertModel,
568
        TFMobileBertPreTrainedModel,
569
        TFMobileBertForPreTraining,
570
        TFMobileBertForSequenceClassification,
571
        TFMobileBertForQuestionAnswering,
572
        TFMobileBertForMaskedLM,
573
        TFMobileBertForNextSentencePrediction,
574
        TFMobileBertForMultipleChoice,
575
        TFMobileBertForTokenClassification,
576
        TFMobileBertMainLayer,
577
    )
578

579
    from .modeling_tf_openai import (
580
        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
581
        TFOpenAIGPTDoubleHeadsModel,
582
        TFOpenAIGPTLMHeadModel,
583
        TFOpenAIGPTMainLayer,
584
        TFOpenAIGPTModel,
585
        TFOpenAIGPTPreTrainedModel,
586
    )
587

588
    from .modeling_tf_roberta import (
589
        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
590
        TFRobertaForMaskedLM,
591
        TFRobertaForMultipleChoice,
592
        TFRobertaForQuestionAnswering,
593
        TFRobertaForSequenceClassification,
594
        TFRobertaForTokenClassification,
595
        TFRobertaMainLayer,
596
        TFRobertaModel,
597
        TFRobertaPreTrainedModel,
598
    )
599

600
    from .modeling_tf_t5 import (
601
        TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
602
        TFT5ForConditionalGeneration,
603
        TFT5Model,
604
        TFT5PreTrainedModel,
605
    )
606

607
    from .modeling_tf_transfo_xl import (
608
        TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
609
        TFAdaptiveEmbedding,
610
        TFTransfoXLLMHeadModel,
611
        TFTransfoXLMainLayer,
612
        TFTransfoXLModel,
613
        TFTransfoXLPreTrainedModel,
614
    )
615

616
    from .modeling_tf_xlm import (
617
        TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
618
        TFXLMForMultipleChoice,
619
        TFXLMForQuestionAnsweringSimple,
620
        TFXLMForSequenceClassification,
621
        TFXLMForTokenClassification,
622
        TFXLMWithLMHeadModel,
623
        TFXLMMainLayer,
624
        TFXLMModel,
625
        TFXLMPreTrainedModel,
626
    )
627

628
    from .modeling_tf_xlm_roberta import (
629
        TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
630
        TFXLMRobertaForMaskedLM,
631
        TFXLMRobertaForMultipleChoice,
632
        TFXLMRobertaForQuestionAnswering,
633
        TFXLMRobertaForSequenceClassification,
634
        TFXLMRobertaForTokenClassification,
635
        TFXLMRobertaModel,
636
    )
637

638
    from .modeling_tf_xlnet import (
639
        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
640
        TFXLNetForMultipleChoice,
641
        TFXLNetForQuestionAnsweringSimple,
642
        TFXLNetForSequenceClassification,
643
        TFXLNetForTokenClassification,
644
        TFXLNetLMHeadModel,
645
        TFXLNetMainLayer,
646
        TFXLNetModel,
647
        TFXLNetPreTrainedModel,
648
    )
649

650
    # Optimization
651
    from .optimization_tf import (
652
        AdamWeightDecay,
653
        create_optimizer,
654
        GradientAccumulator,
655
        WarmUp,
656
    )
657

658
    # Trainer
659
    from .trainer_tf import TFTrainer
660

661
    # Benchmarks
662
    from .benchmark.benchmark_tf import TensorFlowBenchmark
663
    from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments
664

665

666
if not is_tf_available() and not is_torch_available():
    # Model classes are only imported under the PyTorch / TensorFlow availability
    # guards above, so with neither backend installed only the framework-agnostic
    # pieces (tokenizers, configurations, file/data utilities) are usable.
    # Adjacent string literals are concatenated verbatim by Python, so each
    # fragment must carry its own trailing space.
    logger.warning(
        "Neither PyTorch nor TensorFlow >= 2.0 have been found. "
        "Models won't be available and only tokenizers, configuration "
        "and file/data utilities can be used."
    )
672

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.