CSS-LM

modeling_tf_xlm_roberta.py · 147 lines · 5.9 KB
# coding=utf-8
# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 XLM-RoBERTa model. """


import logging

from .configuration_xlm_roberta import XLMRobertaConfig
from .file_utils import add_start_docstrings
from .modeling_tf_roberta import (
    TFRobertaForMaskedLM,
    TFRobertaForMultipleChoice,
    TFRobertaForQuestionAnswering,
    TFRobertaForSequenceClassification,
    TFRobertaForTokenClassification,
    TFRobertaModel,
)
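
# Note: XLM-RoBERTa shares its architecture with RoBERTa; only the vocabulary
# and the pretrained weights differ, so the TF classes in this module are thin
# subclasses of their RoBERTa counterparts that swap in XLMRobertaConfig.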


logger = logging.getLogger(__name__)

TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
]

XLM_ROBERTA_START_DOCSTRING = r"""

    .. note::

        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires
        having all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated with the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.XLMRobertaConfig`): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the
            model weights.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attention tensors of all attention layers are returned. See ``attentions`` under
            returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaModel(TFRobertaModel):
    """
    This class overrides :class:`~transformers.TFRobertaModel`. Please check the
    superclass for the appropriate documentation alongside usage examples.
    """

    config_class = XLMRobertaConfig


@add_start_docstrings(
    """XLM-RoBERTa Model with a `language modeling` head on top. """, XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaForMaskedLM(TFRobertaForMaskedLM):
    """
    This class overrides :class:`~transformers.TFRobertaForMaskedLM`. Please check the
    superclass for the appropriate documentation alongside usage examples.
    """

    config_class = XLMRobertaConfig


@add_start_docstrings(
    """XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
    on top of the pooled output) e.g. for GLUE tasks. """,
    XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaForSequenceClassification(TFRobertaForSequenceClassification):
    """
    This class overrides :class:`~transformers.TFRobertaForSequenceClassification`. Please check the
    superclass for the appropriate documentation alongside usage examples.
    """

    config_class = XLMRobertaConfig


@add_start_docstrings(
    """XLM-RoBERTa Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaForTokenClassification(TFRobertaForTokenClassification):
    """
    This class overrides :class:`~transformers.TFRobertaForTokenClassification`. Please check the
    superclass for the appropriate documentation alongside usage examples.
    """

    config_class = XLMRobertaConfig


@add_start_docstrings(
    """XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD
    (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaForQuestionAnswering(TFRobertaForQuestionAnswering):
    """
    This class overrides :class:`~transformers.TFRobertaForQuestionAnswering`. Please check the
    superclass for the appropriate documentation alongside usage examples.
    """

    config_class = XLMRobertaConfig


@add_start_docstrings(
    """XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaForMultipleChoice(TFRobertaForMultipleChoice):
    """
    This class overrides :class:`~transformers.TFRobertaForMultipleChoice`. Please check the
    superclass for the appropriate documentation alongside usage examples.
    """

    config_class = XLMRobertaConfig
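
Since every class in this file is a thin wrapper, usage follows the TFRoberta* superclasses directly. Below is a minimal usage sketch (not part of the file): it assumes a transformers installation with TensorFlow support and the public "xlm-roberta-base" checkpoint; if that checkpoint only ships PyTorch weights, from_pt=True would be needed in from_pretrained(). The three calls mirror the input formats listed in XLM_ROBERTA_START_DOCSTRING.

# A minimal sketch, assuming transformers with TF support and the public
# "xlm-roberta-base" checkpoint (pass from_pt=True if only PyTorch weights exist).
from transformers import TFXLMRobertaModel, XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
model = TFXLMRobertaModel.from_pretrained("xlm-roberta-base")

# Encode a sentence into a (1, sequence_length) tf.Tensor of token ids.
input_ids = tokenizer.encode("Hello, world!", return_tensors="tf")

# The three input formats described in XLM_ROBERTA_START_DOCSTRING:
outputs = model(input_ids)                 # a single tensor: input_ids only
outputs = model([input_ids])               # a list, in the documented order
outputs = model({"input_ids": input_ids})  # a dict keyed by input name

sequence_output = outputs[0]  # shape: (batch_size, sequence_length, hidden_size)

# The task heads are used the same way, e.g. (hypothetical label count):
# model = TFXLMRobertaForSequenceClassification.from_pretrained(
#     "xlm-roberta-base", num_labels=3)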