CSS-LM

modeling_tf_xlnet.py 
1595 lines · 74.4 KB

# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 XLNet model.
"""


import logging

import numpy as np
import tensorflow as tf

from .configuration_xlnet import XLNetConfig
from .file_utils import (
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
)
from .modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFSequenceSummary,
    TFSharedEmbeddings,
    TFTokenClassificationLoss,
    cast_bool_to_primitive,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

_TOKENIZER_FOR_DOC = "XLNetTokenizer"

TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xlnet-base-cased",
    "xlnet-large-cased",
    # See all XLNet models at https://huggingface.co/models?filter=xlnet
]


def gelu(x):
    """ Implementation of the gelu activation function.
        XLNet uses OpenAI GPT's gelu.
        Also see https://arxiv.org/abs/1606.08415
    """
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf
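
# For intuition: gelu() above is the tanh approximation of GELU,
#   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))),
# e.g. gelu(1.0) evaluates to roughly 0.8412, close to the exact x * Phi(x) ~= 0.8413.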


def swish(x):
    return x * tf.sigmoid(x)


ACT2FN = {
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
}


class TFXLNetRelativeAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        if config.d_model % config.n_head != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.d_model, config.n_head)
            )

        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        self.scale = 1 / (config.d_head ** 0.5)
        self.initializer_range = config.initializer_range

        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def build(self, input_shape):
        initializer = get_initializer(self.initializer_range)
        self.q = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="q"
        )
        self.k = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="k"
        )
        self.v = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="v"
        )
        self.o = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="o"
        )
        self.r = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="r"
        )
        self.r_r_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
        )
        self.r_s_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_s_bias"
        )
        self.r_w_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
        )
        self.seg_embed = self.add_weight(
            shape=(2, self.n_head, self.d_head), initializer=initializer, trainable=True, name="seg_embed"
        )
        super().build(input_shape)

    def prune_heads(self, heads):
        raise NotImplementedError

    def rel_shift(self, x, klen=-1):
        """perform relative shift to form the relative attention score."""
        x_size = shape_list(x)

        x = tf.reshape(x, (x_size[1], x_size[0], x_size[2], x_size[3]))
        x = x[1:, ...]
        x = tf.reshape(x, (x_size[0], x_size[1] - 1, x_size[2], x_size[3]))
        x = x[:, 0:klen, :, :]
        # x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))

        return x
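
    # How rel_shift() above works (illustrative note): `x` comes in as [qlen, klen, bsz, n_head],
    # with the second index running over absolute key positions. Reshaping to [klen, qlen, ...],
    # dropping the first row and reshaping back shifts each query row so that the second index
    # becomes the relative distance between query and key; the final slice keeps the first klen
    # columns.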

    def rel_attn_core(self, inputs, training=False):
        """Core relative positional attention operations."""

        q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions = inputs

        # content based attention score
        ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)

        # position based attention score
        bd = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift(bd, klen=shape_list(ac)[1])

        # segment based attention score
        if seg_mat is None:
            ef = 0
        else:
            ef = tf.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
            ef = tf.einsum("ijbs,ibns->ijbn", seg_mat, ef)

        # merge attention scores and perform masking
        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
            if attn_mask.dtype == tf.float16:
                attn_score = attn_score - 65500 * attn_mask
            else:
                attn_score = attn_score - 1e30 * attn_mask

        # attention probability
        attn_prob = tf.nn.softmax(attn_score, axis=1)

        attn_prob = self.dropout(attn_prob, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # attention output
        attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)

        if cast_bool_to_primitive(output_attentions) is True:
            return attn_vec, attn_prob

        return attn_vec
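
    # Einsum index legend for the scores above (descriptive note): i = query position,
    # j = key position, b = batch, n = attention head, d = head dimension, and s indexes the
    # two segment-relation buckets (same vs. different segment). The final score is the
    # Transformer-XL style sum of a content term (ac), a relative position term (bd, after
    # rel_shift) and a segment term (ef), scaled by 1 / sqrt(d_head).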

    def post_attention(self, inputs, residual=True, training=False):
        """Post-attention processing."""
        # post-attention projection (back to `d_model`)
        h, attn_vec = inputs

        attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)

        attn_out = self.dropout(attn_out, training=training)

        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)

        return output

    def call(self, inputs, training=False):
        (h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions) = inputs

        if g is not None:
            # Two-stream attention with relative positional encoding.
            # content based attention score
            if mems is not None and len(shape_list(mems)) > 1:
                cat = tf.concat([mems, h], axis=0)
            else:
                cat = h

            # content-based key head
            k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)

            # content-based value head
            v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)

            # position-based key head
            k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)

            # h-stream
            # content-stream query head
            q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)

            # core attention ops
            attn_vec_h = self.rel_attn_core(
                [q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
                training=training,
            )

            if cast_bool_to_primitive(output_attentions) is True:
                attn_vec_h, attn_prob_h = attn_vec_h

            # post processing
            output_h = self.post_attention([h, attn_vec_h], training=training)

            # g-stream
            # query-stream query head
            q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)

            # core attention ops
            if target_mapping is not None:
                q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
                attn_vec_g = self.rel_attn_core(
                    [q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
                    training=training,
                )

                if cast_bool_to_primitive(output_attentions) is True:
                    attn_vec_g, attn_prob_g = attn_vec_g

                attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
            else:
                attn_vec_g = self.rel_attn_core(
                    [q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
                    training=training,
                )

                if cast_bool_to_primitive(output_attentions) is True:
                    attn_vec_g, attn_prob_g = attn_vec_g

            # post processing
            output_g = self.post_attention([g, attn_vec_g], training=training)

            if cast_bool_to_primitive(output_attentions) is True:
                attn_prob = attn_prob_h, attn_prob_g

        else:
            # Multi-head attention with relative positional encoding
            if mems is not None and len(shape_list(mems)) > 1:
                cat = tf.concat([mems, h], axis=0)
            else:
                cat = h

            # content heads
            q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
            k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
            v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)

            # positional heads
            k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)

            # core attention ops
            attn_vec = self.rel_attn_core(
                [q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
                training=training,
            )

            if cast_bool_to_primitive(output_attentions) is True:
                attn_vec, attn_prob = attn_vec

            # post processing
            output_h = self.post_attention([h, attn_vec], training=training)
            output_g = None

        outputs = (output_h, output_g)
        if cast_bool_to_primitive(output_attentions) is True:
            outputs = outputs + (attn_prob,)
        return outputs


class TFXLNetFeedForward(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.layer_1 = tf.keras.layers.Dense(
            config.d_inner, kernel_initializer=get_initializer(config.initializer_range), name="layer_1"
        )
        self.layer_2 = tf.keras.layers.Dense(
            config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2"
        )
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        if isinstance(config.ff_activation, str):
            self.activation_function = ACT2FN[config.ff_activation]
        else:
            self.activation_function = config.ff_activation

    def call(self, inp, training=False):
        output = inp
        output = self.layer_1(output)
        output = self.activation_function(output)
        output = self.dropout(output, training=training)
        output = self.layer_2(output)
        output = self.dropout(output, training=training)
        output = self.layer_norm(output + inp)
        return output


class TFXLNetLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.rel_attn = TFXLNetRelativeAttention(config, name="rel_attn")
        self.ff = TFXLNetFeedForward(config, name="ff")
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def call(self, inputs, training=False):
        outputs = self.rel_attn(inputs, training=training)
        output_h, output_g = outputs[:2]

        if output_g is not None:
            output_g = self.ff(output_g, training=training)
        output_h = self.ff(output_h, training=training)

        outputs = (output_h, output_g) + outputs[2:]  # Add again attentions if they are there
        return outputs


class TFXLNetLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias
        return hidden_states
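
    # Note on weight tying: calling the shared embedding in "linear" mode multiplies the hidden
    # states by the transposed embedding matrix, so the LM head reuses the input embedding
    # weights and only adds the per-token output bias defined in build().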


@keras_serializable
class TFXLNetMainLayer(tf.keras.layers.Layer):
    config_class = XLNetConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
        self.n_layer = config.n_layer
        self.use_bfloat16 = config.use_bfloat16
        self.initializer_range = config.initializer_range

        self.word_embedding = TFSharedEmbeddings(
            config.vocab_size, config.d_model, initializer_range=config.initializer_range, name="word_embedding"
        )
        self.layer = [TFXLNetLayer(config, name="layer_._{}".format(i)) for i in range(config.n_layer)]
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def get_input_embeddings(self):
        return self.word_embedding

    def set_input_embeddings(self, value):
        self.word_embedding.weight = value
        self.word_embedding.vocab_size = value.shape[0]

    def build(self, input_shape):
        initializer = get_initializer(self.initializer_range)
        self.mask_emb = self.add_weight(
            shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb"
        )

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

    def create_mask(self, qlen, mlen, dtype=tf.float32):
        """
        Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.

        Args:
            qlen: length of the current segment (number of query tokens).
            mlen: length of the cached memory (number of mem tokens prepended to the keys).

        ::

                  same_length=False:      same_length=True:
                  <mlen > <  qlen >       <mlen > <  qlen >
               ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
                 [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
            qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
                 [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]

        """
        attn_mask = tf.ones([qlen, qlen], dtype=dtype)
        mask_u = tf.linalg.band_part(attn_mask, 0, -1)
        mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
        attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
        ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
        if self.same_length:
            mask_l = tf.linalg.band_part(attn_mask, -1, 0)
            ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
        return ret
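
    # Concrete example (illustrative, with same_length=False): create_mask(qlen=3, mlen=2)
    # returns a [3, 5] float mask where 1.0 marks positions a query may NOT attend to:
    #   [[0. 0. 0. 1. 1.]
    #    [0. 0. 0. 0. 1.]
    #    [0. 0. 0. 0. 0.]]
    # i.e. every query sees all mems plus itself and earlier tokens of the current segment.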

    def cache_mem(self, curr_out, prev_mem):
        """cache hidden states into memory."""
        if self.reuse_len is not None and self.reuse_len > 0:
            curr_out = curr_out[: self.reuse_len]

        if prev_mem is None:
            new_mem = curr_out[-self.mem_len :]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len :]

        return tf.stop_gradient(new_mem)
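
    # Memory update sketch: the (optionally reuse_len-truncated) hidden states of the current
    # segment are appended to the previous memory and only the last `mem_len` positions are
    # kept, with gradients stopped so cached states are treated as constants.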

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = tf.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = tf.tile(pos_emb, [1, bsz, 1])

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
        """create relative positional encoding."""
        freq_seq = tf.range(0, self.d_model, 2.0)
        if dtype is not None and dtype != tf.float32:
            freq_seq = tf.cast(freq_seq, dtype=dtype)
        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))

        if self.attn_type == "bi":
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == "uni":
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError("Unknown `attn_type` {}.".format(self.attn_type))

        if self.bi_data:
            fwd_pos_seq = tf.range(beg, end, -1.0)
            bwd_pos_seq = tf.range(-beg, -end, 1.0)

            if dtype is not None and dtype != tf.float32:
                fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
                bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)

            if self.clamp_len > 0:
                fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
                bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)

            if bsz is not None:
                # With bi_data, the batch size should be divisible by 2.
                assert bsz % 2 == 0
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
        else:
            fwd_pos_seq = tf.range(beg, end, -1.0)
            if dtype is not None and dtype != tf.float32:
                fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
            if self.clamp_len > 0:
                fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        return pos_emb
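
    # Shape note: positional_embedding() returns [len(pos_seq), bsz (or 1), d_model]. With
    # bi_data the forward and backward encodings are built for relative positions running from
    # `beg` (= klen) down to `end` (= -qlen for "bi" attention) and then concatenated along the
    # batch axis, one half of the batch per direction.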

    def call(
        self,
        inputs,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            mems = inputs[2] if len(inputs) > 2 else mems
            perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
            target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
            token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
            input_mask = inputs[6] if len(inputs) > 6 else input_mask
            head_mask = inputs[7] if len(inputs) > 7 else head_mask
            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
            use_cache = inputs[9] if len(inputs) > 9 else use_cache
            output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
            output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
            assert len(inputs) <= 12, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            mems = inputs.get("mems", mems)
            perm_mask = inputs.get("perm_mask", perm_mask)
            target_mapping = inputs.get("target_mapping", target_mapping)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            input_mask = inputs.get("input_mask", input_mask)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 12, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states

        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so here we move the first dimension (batch) to the end

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = tf.transpose(input_ids, perm=(1, 0))
            qlen, bsz = shape_list(input_ids)[:2]
        elif inputs_embeds is not None:
            inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
            qlen, bsz = shape_list(inputs_embeds)[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
        input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
        attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
        perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
        target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None

        mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen

        dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32

        # Attention mask
        # causal attention mask
        if self.attn_type == "uni":
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == "bi":
            attn_mask = None
        else:
            raise ValueError("Unsupported attention type: {}".format(self.attn_type))

        # data mask: input mask & perm mask
        assert input_mask is None or attention_mask is None, (
            "You can only use one of input_mask (uses 1 for padding) "
            "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
        )
        if input_mask is None and attention_mask is not None:
            input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float)
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            if mlen > 0:
                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
                data_mask = tf.concat([mems_mask, data_mask], axis=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float)

        if attn_mask is not None:
            non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
            if mlen > 0:
                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
        else:
            non_tgt_mask = None

        # Word embeddings and prepare h & g hidden states
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k, training=training)
        if target_mapping is not None:
            word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
            # else:  # We removed the inp_q input which was same as target mapping
            #     inp_q_ext = inp_q[:, :, None]
            #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q, training=training)
        else:
            output_g = None

        # Segment embedding
        if token_type_ids is not None:
            # Convert `token_type_ids` to one-hot `seg_mat`
            if mlen > 0:
                mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
                cat_ids = tf.concat([mem_pad, token_type_ids], 0)
            else:
                cat_ids = token_type_ids

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
        else:
            seg_mat = None

        # Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
        pos_emb = self.dropout(pos_emb, training=training)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.n_layer

        new_mems = ()
        if mems is None:
            mems = [None] * len(self.layer)

        attentions = []
        hidden_states = []
        for i, layer_module in enumerate(self.layer):
            # cache new mems
            if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if cast_bool_to_primitive(output_hidden_states) is True:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)

            outputs = layer_module(
                [
                    output_h,
                    output_g,
                    non_tgt_mask,
                    attn_mask,
                    pos_emb,
                    seg_mat,
                    mems[i],
                    target_mapping,
                    head_mask[i],
                    output_attentions,
                ],
                training=training,
            )
            output_h, output_g = outputs[:2]
            if cast_bool_to_primitive(output_attentions) is True:
                attentions.append(outputs[2])

        # Add last hidden state
        if cast_bool_to_primitive(output_hidden_states) is True:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

        output = self.dropout(output_g if output_g is not None else output_h, training=training)

        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        outputs = (tf.transpose(output, perm=(1, 0, 2)),)

        if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
            outputs = outputs + (new_mems,)

        if cast_bool_to_primitive(output_hidden_states) is True:
            if output_g is not None:
                hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
            else:
                hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
            outputs = outputs + (hidden_states,)
        if cast_bool_to_primitive(output_attentions) is True:
            attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
            outputs = outputs + (attentions,)

        return outputs  # outputs, (new_mems), (hidden_states), (attentions)


class TFXLNetPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = XLNetConfig
    base_model_prefix = "transformer"


XLNET_START_DOCSTRING = r"""

    .. note::

        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

XLNET_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.XLNetTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
            given to this model should not be passed as input ids as they have already been computed.
        perm_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
            If ``perm_mask[k, i, j] = 0``, i attends to j in batch k;
            if ``perm_mask[k, i, j] = 1``, i does not attend to j in batch k.
            If None, each token attends to all the others (full bidirectional attention).
            Only used during pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_predict, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to indicate the output tokens to use.
            If ``target_mapping[k, i, j] = 1``, the i-th prediction in batch k is on the j-th token.
            Only used during pretraining for partial prediction or for sequential decoding (generation).
        token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        input_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding.
            Kept for compatibility with the original code base.
            You can only use one of `input_mask` and `attention_mask`.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED.
        head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        use_cache (:obj:`bool`):
            If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
    XLNET_START_DOCSTRING,
)
class TFXLNetModel(TFXLNetPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFXLNetMainLayer(config, name="transformer")

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
        last_hidden_state (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
        mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        outputs = self.transformer(inputs, **kwargs)
        return outputs
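
    # Minimal usage sketch (assumes the pretrained "xlnet-base-cased" checkpoint is available):
    #   tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    #   model = TFXLNetModel.from_pretrained("xlnet-base-cased")
    #   inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    #   last_hidden_state = model(inputs)[0]  # (batch_size, sequence_length, hidden_size)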


@add_start_docstrings(
    """XLNet Model with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    XLNET_START_DOCSTRING,
)
class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFXLNetMainLayer(config, name="transformer")
        self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name="lm_loss")

    def get_output_embeddings(self):
        return self.lm_loss.input_embeddings

    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
        # Add dummy token at the end (no attention on this one)

        # At every pass, the attention values for the new token and the two last generated tokens
        # are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have
        # offset = 1; offset = 2 seems to have slightly better computation.
        offset = 2

        effective_batch_size = inputs.shape[0]
        dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)

        if past:
            inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
        else:
            inputs = tf.concat([inputs, dummy_token], axis=1)

        # Build permutation mask so that previous tokens don't see last token
        sequence_length = inputs.shape[1]
        perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
        perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
        perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)

        # We'll only predict the last token
        target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
        target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
        target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

        inputs = {
            "inputs": inputs,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
            "use_cache": kwargs["use_cache"],
        }

        # if past is defined in model kwargs then use it for faster decoding
        if past:
            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)

        return inputs
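
    # Shape note for the tensors built above: after appending the dummy token the inputs have
    # sequence_length tokens, perm_mask is (batch, sequence_length, sequence_length) with ones
    # only in its last column (no token may attend to the dummy slot), and target_mapping is
    # (batch, 1, sequence_length) selecting that last position as the single prediction target.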

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
    def call(
        self,
        inputs,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the cross entropy language modeling loss.
            Indices should be in ``[0, ..., config.vocab_size - 1]``.

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
        prediction_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import tensorflow as tf
        import numpy as np
        from transformers import XLNetTokenizer, TFXLNetLMHeadModel

        tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
        model = TFXLNetLMHeadModel.from_pretrained('xlnet-large-cased')

        # We show how to set up inputs to predict a next token using a bi-directional context.
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=True))[None, :]  # We will predict the masked token

        perm_mask = np.zeros((1, input_ids.shape[1], input_ids.shape[1]))
        perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token

        target_mapping = np.zeros((1, 1, input_ids.shape[1]))  # Shape [1, 1, seq_length] => let's predict one token
        target_mapping[0, 0, -1] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)

        outputs = model(input_ids, perm_mask=tf.constant(perm_mask, dtype=tf.float32), target_mapping=tf.constant(target_mapping, dtype=tf.float32))

        next_token_logits = outputs[0]  # Output has shape [target_mapping.shape[0], target_mapping.shape[1], config.vocab_size]

        """
        if isinstance(inputs, (tuple, list)):
            labels = inputs[12] if len(inputs) > 12 else labels
            if len(inputs) > 12:
                inputs = inputs[:12]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        transformer_outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            training=training,
        )
        hidden_state = transformer_outputs[0]
        logits = self.lm_loss(hidden_state, training=training)

        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if they are in it

        if labels is not None:
            # shift labels to the left and cut last logit token
            logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # return logits, (mems), (hidden states), (attentions)
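
    # Worked example of the shift above: for labels [t1, t2, t3, t4] the loss compares
    # logits[:, :-1] (the first sequence_length - 1 positions) with labels[:, 1:] = [t2, t3, t4],
    # i.e. each position is trained to predict the following token.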


@add_start_docstrings(
    """XLNet Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    XLNET_START_DOCSTRING,
)
class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXLNetMainLayer(config, name="transformer")
        self.sequence_summary = TFSequenceSummary(
            config, initializer_range=config.initializer_range, name="sequence_summary"
        )
        self.logits_proj = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
        )

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
    def call(
        self,
        inputs=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
        logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        if isinstance(inputs, (tuple, list)):
            labels = inputs[12] if len(inputs) > 12 else labels
            if len(inputs) > 12:
                inputs = inputs[:12]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        transformer_outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        output = transformer_outputs[0]

        output = self.sequence_summary(output)
        logits = self.logits_proj(output)

        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if they are in it

        if labels is not None:
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (mems), (hidden_states), (attentions)
1124

1125

1126
@add_start_docstrings(
1127
    """XLNET Model with a multiple choice classification head on top (a linear layer on top of
1128
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
1129
    XLNET_START_DOCSTRING,
1130
)
1131
class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
1132
    def __init__(self, config, *inputs, **kwargs):
1133
        super().__init__(config, *inputs, **kwargs)
1134

1135
        self.transformer = TFXLNetMainLayer(config, name="transformer")
1136
        self.sequence_summary = TFSequenceSummary(
1137
            config, initializer_range=config.initializer_range, name="sequence_summary"
1138
        )
1139
        self.logits_proj = tf.keras.layers.Dense(
1140
            1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
1141
        )
1142

1143
    @property
1144
    def dummy_inputs(self):
1145
        """ Dummy inputs to build the network.
1146

1147
        Returns:
1148
            tf.Tensor with dummy inputs
1149
        """
1150
        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
    def call(
        self,
        inputs=None,
        token_type_ids=None,
        input_mask=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
        classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
            `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).

            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            mems = inputs[2] if len(inputs) > 2 else mems
            perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
            target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
            token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
            input_mask = inputs[6] if len(inputs) > 6 else input_mask
            head_mask = inputs[7] if len(inputs) > 7 else head_mask
            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
            use_cache = inputs[9] if len(inputs) > 9 else use_cache
            output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
            output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
            labels = inputs[12] if len(inputs) > 12 else labels
            assert len(inputs) <= 13, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            mems = inputs.get("mems", mems)
            perm_mask = inputs.get("perm_mask", perm_mask)
            target_mapping = inputs.get("target_mapping", target_mapping)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            input_mask = inputs.get("input_mask", input_mask)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            labels = inputs.get("labels", labels)
            assert len(inputs) <= 13, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )

        flat_inputs = [
            flat_input_ids,
            flat_attention_mask,
            mems,
            perm_mask,
            target_mapping,
            flat_token_type_ids,
            flat_input_mask,
            head_mask,
            flat_inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
        ]
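        # Added note (not in the original source): the inputs arrive with shape
        # (batch_size, num_choices, seq_length); the flat_* tensors above are reshaped to
        # (batch_size * num_choices, seq_length) -- e.g. (2, 4, 128) becomes (8, 128) -- so the
        # shared transformer scores every choice independently before the per-choice logits are
        # reshaped back to (batch_size, num_choices) below.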

        transformer_outputs = self.transformer(flat_inputs, training=training)
        output = transformer_outputs[0]
        logits = self.sequence_summary(output)
        logits = self.logits_proj(logits)
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        outputs = (reshaped_logits,) + transformer_outputs[1:]  # add hidden states and attention if they are here

        if labels is not None:
            loss = self.compute_loss(labels, reshaped_logits)
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (mems), (hidden states), (attentions)
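# Usage sketch (added note, not part of the original transformers source): the multiple-choice
# head expects input ids of shape (batch_size, num_choices, seq_length). The checkpoint name is
# the public "xlnet-base-cased"; the prompt and choices are made up and the padding mask is
# omitted for brevity:
#
#     import tensorflow as tf
#     from transformers import XLNetTokenizer, TFXLNetForMultipleChoice
#
#     tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
#     model = TFXLNetForMultipleChoice.from_pretrained("xlnet-base-cased")
#     prompt = "The sky is"
#     choices = ["blue.", "a potato."]
#     encoded = [tokenizer.encode(prompt + " " + c, add_special_tokens=True) for c in choices]
#     max_len = max(len(ids) for ids in encoded)
#     encoded = [ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids in encoded]
#     input_ids = tf.constant(encoded)[None, :, :]  # shape (1, num_choices, seq_length)
#     outputs = model(input_ids, labels=tf.constant([0]))
#     loss, classification_scores = outputs[:2]  # scores: (batch_size, num_choices)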


@add_start_docstrings(
    """XLNet Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XLNET_START_DOCSTRING,
)
class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXLNetMainLayer(config, name="transformer")
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
    def call(
        self,
        inputs=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
        logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        if isinstance(inputs, (tuple, list)):
            labels = inputs[12] if len(inputs) > 12 else labels
            if len(inputs) > 12:
                inputs = inputs[:12]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        transformer_outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            training=training,
        )
        output = transformer_outputs[0]

        logits = self.classifier(output)

        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states and attentions if they are there

        if labels is not None:
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (mems), (hidden_states), (attentions)
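# Usage sketch (added note, not part of the original transformers source), using the public
# "xlnet-base-cased" checkpoint; the all-zero labels are placeholders just to show the shape:
#
#     import tensorflow as tf
#     from transformers import XLNetTokenizer, TFXLNetForTokenClassification
#
#     tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
#     model = TFXLNetForTokenClassification.from_pretrained("xlnet-base-cased")
#     input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # batch size 1
#     labels = tf.zeros_like(input_ids)  # one label id per token, shape (batch_size, seq_length)
#     outputs = model(input_ids, labels=labels)
#     loss, logits = outputs[:2]  # logits: (batch_size, seq_length, config.num_labels)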


@add_start_docstrings(
    """XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLNET_START_DOCSTRING,
)
class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFXLNetMainLayer(config, name="transformer")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
    def call(
        self,
        inputs=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.

    Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
        loss (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[12] if len(inputs) > 12 else start_positions
            end_positions = inputs[13] if len(inputs) > 13 else end_positions
            if len(inputs) > 12:
                inputs = inputs[:12]
        elif isinstance(inputs, (dict, BatchEncoding)):
            start_positions = inputs.pop("start_positions", start_positions)
            end_positions = inputs.pop("end_positions", end_positions)

        transformer_outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            training=training,
        )

        sequence_output = transformer_outputs[0]

        logits = self.qa_outputs(sequence_output)
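        # Added note (not in the original source): `qa_outputs` projects each token's hidden state
        # to config.num_labels scores, which must be 2 here; the split/squeeze below turns the
        # (batch_size, seq_length, 2) tensor into start_logits and end_logits, each of shape
        # (batch_size, seq_length).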
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + transformer_outputs[1:]  # Keep mems, hidden states and attentions if they are there

        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.compute_loss(labels, outputs[:2])
            outputs = (loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)
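# Usage sketch (added note, not part of the original transformers source), using the public
# "xlnet-base-cased" checkpoint; question, context and span positions are arbitrary:
#
#     import tensorflow as tf
#     from transformers import XLNetTokenizer, TFXLNetForQuestionAnsweringSimple
#
#     tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
#     model = TFXLNetForQuestionAnsweringSimple.from_pretrained("xlnet-base-cased")
#     question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
#     input_ids = tf.constant(tokenizer.encode(question, context, add_special_tokens=True))[None, :]  # batch size 1
#     start_positions = tf.constant([1])
#     end_positions = tf.constant([3])
#     outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
#     loss, start_logits, end_logits = outputs[:3]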


# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
#     the hidden-states output to compute `span start logits` and `span end logits`). """,
#     XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
# class TFXLNetForQuestionAnswering(TFXLNetPreTrainedModel):
#     r"""
#     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
#         **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size, config.start_n_top)``
#             Log probabilities for the top config.start_n_top start token possibilities (beam-search).
#         **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size, config.start_n_top)``
#             Indices for the top config.start_n_top start token possibilities (beam-search).
#         **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
#             Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
#         **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
#             Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
#         **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size,)``
#             Log probabilities for the ``is_impossible`` label of the answers.
#         **mems**:
#             list of ``tf.Tensor`` (one for each layer):
#             that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
#             if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
#             See details in the docstring of the `mems` input above.
#         **hidden_states**: (`optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``)
#             list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
#             of shape ``(batch_size, sequence_length, hidden_size)``:
#             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
#         **attentions**: (`optional`, returned when ``output_attentions=True``)
#             list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
#             Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

#     Examples::

#         # For example purposes. Not runnable.
#         tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
#         model = TFXLNetForQuestionAnswering.from_pretrained('xlnet-large-cased')
#         input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
#         start_positions = tf.constant([1])
#         end_positions = tf.constant([3])
#         outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
#         loss, start_scores, end_scores = outputs[:3]

#     """
#     def __init__(self, config, *inputs, **kwargs):
#         super().__init__(config, *inputs, **kwargs)
#         self.start_n_top = config.start_n_top
#         self.end_n_top = config.end_n_top

#         self.transformer = TFXLNetMainLayer(config, name='transformer')
#         self.start_logits = TFPoolerStartLogits(config, name='start_logits')
#         self.end_logits = TFPoolerEndLogits(config, name='end_logits')
#         self.answer_class = TFPoolerAnswerClass(config, name='answer_class')

#     def call(self, inputs, training=False):
#         transformer_outputs = self.transformer(inputs, training=training)
#         hidden_states = transformer_outputs[0]
#         start_logits = self.start_logits(hidden_states, p_mask=p_mask)

#         outputs = transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

#         if start_positions is not None and end_positions is not None:
#             # If we are on multi-GPU, let's remove the dimension added by batch splitting
#             for x in (start_positions, end_positions, cls_index, is_impossible):
#                 if x is not None and x.dim() > 1:
#                     x.squeeze_(-1)

#             # during training, compute the end logits based on the ground truth of the start position
#             end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

#             loss_fct = CrossEntropyLoss()
#             start_loss = loss_fct(start_logits, start_positions)
#             end_loss = loss_fct(end_logits, end_positions)
#             total_loss = (start_loss + end_loss) / 2

#             if cls_index is not None and is_impossible is not None:
#                 # Predict answerability from the representation of CLS and START
#                 cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
#                 loss_fct_cls = nn.BCEWithLogitsLoss()
#                 cls_loss = loss_fct_cls(cls_logits, is_impossible)

#                 # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
#                 total_loss += cls_loss * 0.5

#             outputs = (total_loss,) + outputs

#         else:
#             # during inference, compute the end logits based on beam search
#             bsz, slen, hsz = hidden_states.size()
#             start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)

#             start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
#             start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
#             start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
#             start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)

#             hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
#             p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
#             end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
#             end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)

#             end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
#             end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
#             end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

#             start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)  # get the representation of START as weighted sum of hidden states
#             cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)  # Shape (batch size,): one single `cls_logits` for each sample

#             outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs

#         # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
#         # or (if labels are provided) (total_loss,)
#         return outputs
