# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 CTRL model."""


import logging

import numpy as np
import tensorflow as tf

from .configuration_ctrl import CTRLConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFPreTrainedModel,
    TFSharedEmbeddings,
    cast_bool_to_primitive,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

_TOKENIZER_FOR_DOC = "CTRLTokenizer"

TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ctrl"
    # See all CTRL models at https://huggingface.co/models?filter=ctrl
]


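# Sinusoidal positional encoding used by CTRL. For position `pos` and channel `i`,
# the angle is pos / 10000 ** (2 * (i // 2) / d_model_size); the sines of the even
# channels and the cosines of the odd channels are then concatenated along the last
# axis (first half sines, second half cosines) rather than interleaved.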
def angle_defn(pos, i, d_model_size):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size))
    return pos * angle_rates


def positional_encoding(position, d_model_size):
    # create the sinusoidal pattern for the positional encoding
    angle_rads = angle_defn(np.arange(position)[:, np.newaxis], np.arange(d_model_size)[np.newaxis, :], d_model_size)

    sines = np.sin(angle_rads[:, 0::2])
    cosines = np.cos(angle_rads[:, 1::2])

    # pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
    pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1), dtype=tf.float32)
    return pos_encoding


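# Core attention used by TFMultiHeadAttention below. q, k and v have shape
# (batch_size, num_heads, seq_len, depth). `mask` is the causal mask (1.0 at
# future positions, scaled by -1e4 before the softmax), while `attention_mask`
# is the additive padding mask prepared in TFCTRLMainLayer.call.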
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
    # calculate attention
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    dk = tf.cast(shape_list(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    if mask is not None:
        scaled_attention_logits += mask * -1e4

    if attention_mask is not None:
        # Apply the attention mask
        scaled_attention_logits = scaled_attention_logits + attention_mask

    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    # Mask heads if we want to
    if head_mask is not None:
        attention_weights = attention_weights * head_mask

    output = tf.matmul(attention_weights, v)

    return output, attention_weights


class TFMultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model_size, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model_size = d_model_size

        self.depth = int(d_model_size / self.num_heads)

        self.Wq = tf.keras.layers.Dense(d_model_size, name="Wq")
        self.Wk = tf.keras.layers.Dense(d_model_size, name="Wk")
        self.Wv = tf.keras.layers.Dense(d_model_size, name="Wv")

        self.dense = tf.keras.layers.Dense(d_model_size, name="dense")

    def split_into_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=False):
        v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
        batch_size = shape_list(q)[0]

        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q, batch_size)
        k = self.split_into_heads(k, batch_size)
        v = self.split_into_heads(v, batch_size)

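        # layer_past, if provided, is the stacked (key, value) cache from earlier
        # decoding steps; the new keys and values are appended along the sequence axis.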
        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=0)
            k = tf.concat((past_key, k), axis=-2)
            v = tf.concat((past_value, v), axis=-2)

        # to cope with keras serialization
        use_cache = cast_bool_to_primitive(use_cache, True)

        if use_cache is True:
            present = tf.stack((k, v), axis=0)
        else:
            present = (None,)

        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
        scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
        attn = output[1]
        original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
        output = self.dense(original_size_attention)

        outputs = (output, present)
        if cast_bool_to_primitive(output_attentions) is True:
            outputs = outputs + (attn,)
        return outputs


class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
    def __init__(self, d_model_size, dff, **kwargs):
        super().__init__(**kwargs)

        self.dense_0 = tf.keras.layers.Dense(dff, activation="relu", name="0")
        self.dense_2 = tf.keras.layers.Dense(d_model_size, name="2")

    def call(self, inputs, training=False):
        dense_0_output = self.dense_0(inputs)
        dense_2_output = self.dense_2(dense_0_output)

        return dense_2_output


class TFEncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)

        self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
        self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

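    # Pre-layernorm transformer block: layernorm -> self-attention -> dropout -> residual,
    # then layernorm -> feed-forward -> dropout -> residual.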
    def call(self, inputs, training=False):
        x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
        normed = self.layernorm1(x)
        attn_outputs = self.multi_head_attention(
            [normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
            training=training,
        )
        attn_output = attn_outputs[0]
        attn_output = self.dropout1(attn_output, training=training)
        out1 = x + attn_output

        out2 = self.layernorm2(out1)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = out1 + ffn_output

        outputs = (out2,) + attn_outputs[1:]
        return outputs


@keras_serializable
class TFCTRLMainLayer(tf.keras.layers.Layer):
    config_class = CTRLConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache

        self.d_model_size = config.n_embd
        self.num_layers = config.n_layer

        self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)

        self.w = TFSharedEmbeddings(
            config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="w"
        )

        self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [
            TFEncoderLayer(
                config.n_embd,
                config.n_head,
                config.dff,
                config.resid_pdrop,
                config.layer_norm_epsilon,
                name="h_._{}".format(i),
            )
            for i in range(config.n_layer)
        ]
        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")

    def get_input_embeddings(self):
        return self.w

    def set_input_embeddings(self, value):
        self.w.weight = value
        self.w.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
                heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):

        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            use_cache = inputs[7] if len(inputs) > 7 else use_cache
            output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
            output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
            assert len(inputs) <= 10, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            past = inputs.get("past", past)
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.use_cache

        # If using past key value states, only the last tokens
        # should be given as an input
        if past is not None:
            if input_ids is not None:
                input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1:]

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
            position_ids = tf.tile(position_ids, [input_shape[0], 1])

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.

            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0
        else:
            attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_layers

        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.w(token_type_ids, mode="embedding")
            token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
        else:
            token_type_embeds = 0
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            inputs_embeds = self.w(input_ids, mode="embedding")
        seq_len = input_shape[-1]
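        # Causal mask: 1.0 above the diagonal marks future positions; scaled_dot_product_attention
        # adds mask * -1e4 to the attention logits so those positions are suppressed.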
        mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

        inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))

        pos_embeds = tf.gather(self.pos_encoding, position_ids)

        hidden_states = inputs_embeds + pos_embeds + token_type_embeds

        hidden_states = self.dropout(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]
        presents = ()
        all_hidden_states = ()
        all_attentions = []
        for i, (h, layer_past) in enumerate(zip(self.h, past)):
            if cast_bool_to_primitive(output_hidden_states) is True:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
            outputs = h(
                [hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
                training=training,
            )
            hidden_states, present = outputs[:2]

            if use_cache is True:
                presents = presents + (present,)

            if cast_bool_to_primitive(output_attentions) is True:
                all_attentions.append(outputs[2])

        hidden_states = self.layernorm(hidden_states)
        hidden_states = tf.reshape(hidden_states, output_shape)
        if cast_bool_to_primitive(output_hidden_states) is True:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if use_cache is True:
            outputs = outputs + (presents,)
        if cast_bool_to_primitive(output_hidden_states) is True:
            outputs = outputs + (all_hidden_states,)
        if cast_bool_to_primitive(output_attentions) is True:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs


class TFCTRLPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = CTRLConfig
    base_model_prefix = "transformer"


CTRL_START_DOCSTRING = r"""

    .. note::
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated with the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.CTRLConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

CTRL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).

            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only input_ids that do not have their past calculated should be passed as input_ids (see `past`).

            Indices can be obtained using :class:`transformers.CTRLTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        use_cache (:obj:`bool`):
            If `use_cache` is True, `past` key value states are returned and
            can be used to speed up decoding (see `past`). Defaults to `True`.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
    CTRL_START_DOCSTRING,
)
class TFCTRLModel(TFCTRLPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFCTRLMainLayer(config, name="transformer")

    @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        outputs = self.transformer(inputs, **kwargs)
        return outputs


class TFCTRLLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias
        return hidden_states


@add_start_docstrings(
    """The CTRL Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    CTRL_START_DOCSTRING,
)
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFCTRLMainLayer(config, name="transformer")

        self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")

    def get_output_embeddings(self):
        return self.lm_head.input_embeddings

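    # Hook used by generate(): once `past` is available, only the most recent token id
    # needs to be fed back in, since earlier key/value states are already cached.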
    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            inputs = tf.expand_dims(inputs[:, -1], -1)

        return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the cross entropy classification loss.
            Indices should be in ``[0, ..., config.vocab_size - 1]``.

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
        prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        if isinstance(inputs, (tuple, list)):
            labels = inputs[10] if len(inputs) > 10 else labels
            if len(inputs) > 10:
                inputs = inputs[:10]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        transformer_outputs = self.transformer(
            inputs,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            training=training,
        )

        hidden_states = transformer_outputs[0]

        logits = self.lm_head(hidden_states)

        outputs = (logits,) + transformer_outputs[1:]
        if labels is not None:
            # shift labels to the left and cut last logit token
            logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # lm_logits, presents, (all hidden_states), (attentions)
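

# Minimal usage sketch (illustration only, not part of this module): it assumes the
# standard `transformers` CTRL tokenizer and the pretrained "ctrl" weights are available.
#
#     from transformers import CTRLTokenizer
#
#     tokenizer = CTRLTokenizer.from_pretrained("ctrl")
#     model = TFCTRLLMHeadModel.from_pretrained("ctrl")
#     input_ids = tokenizer("Links Hello, my dog is cute", return_tensors="tf")["input_ids"]
#     logits = model(input_ids)[0]  # (batch_size, sequence_length, vocab_size)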
