CSS-LM

Форк
0
/
modeling_xlnet.py 
1983 строки · 86.8 Кб
1
# coding=utf-8
2
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
""" PyTorch XLNet model.
17
"""
18

19

20
import logging
21
from dataclasses import dataclass
22
from typing import List, Optional, Tuple
23

24
import torch
25
from torch import nn
26
from torch.nn import CrossEntropyLoss, MSELoss
27
from torch.nn import functional as F
28

29
from .activations import gelu_new, swish
30
from .configuration_xlnet import XLNetConfig
31
from .file_utils import (
32
    ModelOutput,
33
    add_code_sample_docstrings,
34
    add_start_docstrings,
35
    add_start_docstrings_to_callable,
36
    replace_return_docstrings,
37
)
38
from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
39

40

41
logger = logging.getLogger(__name__)

# Class names used when rendering documentation examples (presumably consumed by the
# docstring decorators imported from ``file_utils``; their use is not shown here).
_CONFIG_FOR_DOC = "XLNetConfig"
_TOKENIZER_FOR_DOC = "XLNetTokenizer"

# Identifiers of the pretrained XLNet checkpoints.
# See all XLNet models at https://huggingface.co/models?filter=xlnet
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = ["xlnet-base-cased", "xlnet-large-cased"]
51

52

53
def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
    """ Build a map from TensorFlow checkpoint variable names to PyTorch parameters.

    Using an explicit name map lets the PyTorch model keep a module layout as close as
    possible to the original TensorFlow implementation.

    Args:
        model: either a head model exposing a ``transformer`` attribute (its head
            parameters are mapped too), or the bare transformer itself.
        config: model configuration; ``finetuning_task`` and ``untie_r`` are read.
        tf_weights: mapping of TF variable names to arrays. It is only used for
            membership tests that decide whether optional head weights are present.
            ``None`` (the default) is treated as an empty mapping.

    Returns:
        dict mapping each TF variable name to a PyTorch parameter, or — for the
        relative-position biases and segment embeddings — to a list of parameters
        (one per layer when ``config.untie_r`` is set, otherwise a single element).
    """
    # The default used to crash on the membership tests below; treat None as "no weights".
    if tf_weights is None:
        tf_weights = {}

    tf_to_pt_map = {}

    if hasattr(model, "transformer"):
        if hasattr(model, "lm_loss"):
            # We will load also the output bias
            tf_to_pt_map["model/lm_loss/bias"] = model.lm_loss.bias
        # NOTE: "sequnece" (sic) must stay as written — the key has to match the variable
        # name stored in the TF checkpoint (presumably the original TF scope name).
        if hasattr(model, "sequence_summary") and "model/sequnece_summary/summary/kernel" in tf_weights:
            # We will load also the sequence summary
            tf_to_pt_map["model/sequnece_summary/summary/kernel"] = model.sequence_summary.summary.weight
            tf_to_pt_map["model/sequnece_summary/summary/bias"] = model.sequence_summary.summary.bias
        if (
            hasattr(model, "logits_proj")
            and config.finetuning_task is not None
            and "model/regression_{}/logit/kernel".format(config.finetuning_task) in tf_weights
        ):
            tf_to_pt_map["model/regression_{}/logit/kernel".format(config.finetuning_task)] = model.logits_proj.weight
            tf_to_pt_map["model/regression_{}/logit/bias".format(config.finetuning_task)] = model.logits_proj.bias

        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings and output
    tf_to_pt_map.update(
        {
            "model/transformer/word_embedding/lookup_table": model.word_embedding.weight,
            "model/transformer/mask_emb/mask_emb": model.mask_emb,
        }
    )

    # Transformer blocks
    for i, b in enumerate(model.layer):
        layer_str = "model/transformer/layer_%d/" % i
        tf_to_pt_map.update(
            {
                layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight,
                layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias,
                layer_str + "rel_attn/o/kernel": b.rel_attn.o,
                layer_str + "rel_attn/q/kernel": b.rel_attn.q,
                layer_str + "rel_attn/k/kernel": b.rel_attn.k,
                layer_str + "rel_attn/r/kernel": b.rel_attn.r,
                layer_str + "rel_attn/v/kernel": b.rel_attn.v,
                layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight,
                layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias,
                layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight,
                layer_str + "ff/layer_1/bias": b.ff.layer_1.bias,
                layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight,
                layer_str + "ff/layer_2/bias": b.ff.layer_2.bias,
            }
        )

    # Relative positioning biases: with untie_r each layer owns its biases, so the
    # single TF variable is mapped to a list of per-layer parameters.
    if config.untie_r:
        r_r_list = []
        r_w_list = []
        r_s_list = []
        seg_embed_list = []
        for b in model.layer:
            r_r_list.append(b.rel_attn.r_r_bias)
            r_w_list.append(b.rel_attn.r_w_bias)
            r_s_list.append(b.rel_attn.r_s_bias)
            seg_embed_list.append(b.rel_attn.seg_embed)
    else:
        # NOTE(review): assumes the transformer exposes shared r_*_bias / seg_embed
        # attributes when untie_r is False — confirm against XLNetModel.
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
        r_s_list = [model.r_s_bias]
        seg_embed_list = [model.seg_embed]
    tf_to_pt_map.update(
        {
            "model/transformer/r_r_bias": r_r_list,
            "model/transformer/r_w_bias": r_w_list,
            "model/transformer/r_s_bias": r_s_list,
            "model/transformer/seg_embed": seg_embed_list,
        }
    )
    return tf_to_pt_map
134

135

136
def _assert_same_shape(pointer, array):
    """Raise AssertionError carrying both shapes unless ``pointer`` and ``array`` match.

    An explicit raise is used instead of ``assert`` so the check is not stripped
    when Python runs with ``-O``.
    """
    if pointer.shape != array.shape:
        raise AssertionError(pointer.shape, array.shape)


def load_tf_weights_in_xlnet(model, config, tf_path):
    """ Load a TensorFlow XLNet checkpoint into a PyTorch model, in place.

    Args:
        model: PyTorch model whose parameters are overwritten.
        config: model configuration, forwarded to ``build_tf_xlnet_to_pytorch_map``.
        tf_path: path of the TensorFlow checkpoint to read.

    Returns:
        The same ``model`` object, with the matching checkpoint weights copied in.

    Raises:
        ImportError: if NumPy or TensorFlow cannot be imported.
        AssertionError: if a checkpoint array does not match the shape of the PyTorch
            parameter(s) it maps to; the mismatching shapes are in ``args``.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Read every variable stored in the TF checkpoint into memory.
    tf_weights = {}
    for name, shape in tf.train.list_variables(tf_path):
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        tf_weights[name] = tf.train.load_variable(tf_path, name)

    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)

    for name, pointer in tf_to_pt_map.items():
        logger.info("Importing {}".format(name))
        if name not in tf_weights:
            logger.info("{} not in tf pre-trained weights, skipping".format(name))
            continue
        array = tf_weights[name]
        # These kernels map onto nn.Linear weights, which are laid out (out, in);
        # the TF dense kernels are (in, out), so they are transposed.
        if "kernel" in name and ("ff" in name or "summary" in name or "logit" in name):
            logger.info("Transposing")
            array = np.transpose(array)
        if isinstance(pointer, list):
            # One stacked TF variable feeds several PyTorch parameters: split along axis 0.
            if len(pointer) != array.shape[0]:
                raise AssertionError(len(pointer), array.shape[0])
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                _assert_same_shape(p_i, arr_i)
                logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
                p_i.data = torch.from_numpy(arr_i)
        else:
            _assert_same_shape(pointer, array)
            logger.info("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array)
        # Drop the consumed variable and its Adam optimizer slots so that only
        # genuinely unused weights are reported below.
        tf_weights.pop(name, None)
        tf_weights.pop(name + "/Adam", None)
        tf_weights.pop(name + "/Adam_1", None)

    logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
    return model
196

197

198
# Lookup table from activation name to callable; XLNetFeedForward resolves a string
# ``config.ff_activation`` through it. Note that "gelu" maps to ``gelu_new``.
ACT2FN = {
    "gelu": gelu_new,
    "relu": torch.nn.functional.relu,
    "swish": swish,
}

# XLNet uses the stock PyTorch layer norm; the alias keeps a model-specific name.
XLNetLayerNorm = nn.LayerNorm
202

203

204
class XLNetRelativeAttention(nn.Module):
    """Multi-head attention with relative positional encoding, optionally two-stream.

    Projection weights are raw ``(d_model, n_head, d_head)`` parameters applied with
    ``einsum``. Hidden states are sequence-first — ``mems`` are concatenated with ``h``
    on dim 0 — and per-head tensors use the ``ibnd`` layout (position, batch, head,
    head_dim). Parameters are allocated uninitialized here; ``_init_weights`` of the
    pretrained-model base class (or a checkpoint) is expected to fill them.
    """

    def __init__(self, config):
        super().__init__()

        if config.d_model % config.n_head != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.d_model, config.n_head)
            )

        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        self.scale = 1 / (config.d_head ** 0.5)

        # Query/key/value/output/relative-position projections, (d_model, n_head, d_head).
        self.q = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.k = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.v = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.o = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.r = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))

        # Per-head query biases for the position (r_r), segment (r_s) and content (r_w)
        # score terms, plus the two learned segment embeddings.
        self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_s_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.seg_embed = nn.Parameter(torch.FloatTensor(2, self.n_head, self.d_head))

        self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def prune_heads(self, heads):
        raise NotImplementedError

    @staticmethod
    def rel_shift(x, klen=-1):
        """Perform relative shift to form the relative attention score.

        ``x`` is a 4-D tensor whose first two axes are shifted against each other. The
        result keeps the first ``klen`` entries of axis 1; a negative ``klen`` (the
        default) keeps all ``x.shape[1] - 1`` shifted entries.
        """
        x_size = x.shape

        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        # The isinstance guard leaves traced (tensor-valued) klen untouched.
        if isinstance(klen, int) and klen < 0:
            klen = x_size[1] - 1
        # x = x[:, 0:klen, :, :]
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))

        return x

    @staticmethod
    def rel_shift_bnij(x, klen=-1):
        """Relative shift for scores laid out as (batch, head, i, j).

        Keeps the first ``klen`` entries of the last axis after shifting; a negative
        ``klen`` (the default) keeps all ``x.shape[3] - 1`` shifted entries.
        """
        x_size = x.shape

        x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
        x = x[:, :, 1:, :]
        x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
        # The isinstance guard leaves traced (tensor-valued) klen untouched.
        if isinstance(klen, int) and klen < 0:
            klen = x_size[3] - 1
        # Note: the tensor-slice form was faster in my testing than torch.index_select
        #       However, tracing doesn't like the nature of the slice, and if klen changes
        #       during the run then it'll fail, whereas index_select will be fine.
        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
        # x = x[:, :, :, :klen]

        return x

    def rel_attn_core(
        self,
        q_head,
        k_head_h,
        v_head_h,
        k_head_r,
        seg_mat=None,
        attn_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Core relative positional attention operations.

        Args:
            q_head: queries, ``ibnd`` layout.
            k_head_h, v_head_h: content keys/values, ``jbnd`` layout.
            k_head_r: relative-position keys, ``jbnd`` layout (shifted to klen).
            seg_mat: optional segment matrix in ``ijbs`` layout; None disables the term.
            attn_mask: optional mask in ``ijbn`` layout; nonzero entries are pushed
                towards -inf before the softmax.
            head_mask: optional multiplicative mask applied to the probabilities.
            output_attentions: also return the probabilities in ``ijbn`` layout.

        Returns:
            The attention output (``ibnd``), or ``(output, probabilities)``.
        """

        # content based attention score
        ac = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_w_bias, k_head_h)

        # position based attention score
        bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift_bnij(bd, klen=ac.shape[3])

        # segment based attention score
        if seg_mat is None:
            ef = 0
        else:
            ef = torch.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
            ef = torch.einsum("ijbs,ibns->bnij", seg_mat, ef)

        # merge attention scores and perform masking
        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
            # 1e30 does not fit in float16, so a value near the fp16 maximum is used there.
            if attn_mask.dtype == torch.float16:
                attn_score = attn_score - 65500 * torch.einsum("ijbn->bnij", attn_mask)
            else:
                attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)

        # attention probability
        attn_prob = F.softmax(attn_score, dim=3)
        attn_prob = self.dropout(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * torch.einsum("ijbn->bnij", head_mask)

        # attention output
        attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)

        if output_attentions:
            return attn_vec, torch.einsum("bnij->ijbn", attn_prob)

        return attn_vec

    def post_attention(self, h, attn_vec, residual=True):
        """Project back to ``d_model``, apply dropout, optional residual and layer norm."""
        # post-attention projection (back to `d_model`)
        attn_out = torch.einsum("ibnd,hnd->ibh", attn_vec, self.o)

        attn_out = self.dropout(attn_out)
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)

        return output

    def forward(
        self,
        h,
        g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems=None,
        target_mapping=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Run the content stream ``h`` and, when ``g`` is given, the query stream ``g``.

        Returns ``(output_h, output_g)``; ``output_g`` is None when ``g`` is None. With
        ``output_attentions`` the attention probabilities are appended: a single tensor
        in single-stream mode, or an ``(h_probs, g_probs)`` pair in two-stream mode.
        """
        # Keys and values attend over the cached memory followed by the current segment.
        if mems is not None and mems.dim() > 1:
            cat = torch.cat([mems, h], dim=0)
        else:
            cat = h

        # Projections shared by both streams.
        k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)  # content-based key head
        v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)  # content-based value head
        k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)  # position-based key head

        # h-stream: content-stream queries attend with attn_mask_h. This is identical in
        # single- and two-stream mode, and runs first in both (dropout RNG order).
        q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
        attn_vec_h = self.rel_attn_core(
            q_head_h,
            k_head_h,
            v_head_h,
            k_head_r,
            seg_mat=seg_mat,
            attn_mask=attn_mask_h,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            attn_vec_h, attn_prob_h = attn_vec_h
        output_h = self.post_attention(h, attn_vec_h)

        if g is None:
            # Multi-head attention with relative positional encoding only.
            output_g = None
            if output_attentions:
                attn_prob = attn_prob_h
        else:
            # g-stream: query-stream queries, optionally gathered through target_mapping.
            q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
            if target_mapping is not None:
                q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
            attn_vec_g = self.rel_attn_core(
                q_head_g,
                k_head_h,
                v_head_h,
                k_head_r,
                seg_mat=seg_mat,
                attn_mask=attn_mask_g,
                head_mask=head_mask,
                output_attentions=output_attentions,
            )
            if output_attentions:
                attn_vec_g, attn_prob_g = attn_vec_g
            if target_mapping is not None:
                # scatter the results back to the original positions
                attn_vec_g = torch.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
            output_g = self.post_attention(g, attn_vec_g)
            if output_attentions:
                attn_prob = attn_prob_h, attn_prob_g

        outputs = (output_h, output_g)
        if output_attentions:
            outputs = outputs + (attn_prob,)
        return outputs
461

462

463
class XLNetFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection and post-norm.

    Computes ``layer_norm(x + dropout(layer_2(dropout(act(layer_1(x))))))``.
    """

    def __init__(self, config):
        super().__init__()
        self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.layer_1 = nn.Linear(config.d_model, config.d_inner)
        self.layer_2 = nn.Linear(config.d_inner, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        # ``ff_activation`` is either a key of ACT2FN or an already-callable activation.
        act = config.ff_activation
        self.activation_function = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, inp):
        hidden = self.dropout(self.activation_function(self.layer_1(inp)))
        hidden = self.dropout(self.layer_2(hidden))
        return self.layer_norm(hidden + inp)
484

485

486
class XLNetLayer(nn.Module):
    """One XLNet block: relative attention followed by the feed-forward sub-layer.

    The same feed-forward module is applied to both the content (h) and the query (g)
    streams.
    """

    def __init__(self, config):
        super().__init__()
        self.rel_attn = XLNetRelativeAttention(config)
        self.ff = XLNetFeedForward(config)
        # NOTE(review): not used by forward below.
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        output_h,
        output_g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems=None,
        target_mapping=None,
        head_mask=None,
        output_attentions=False,
    ):
        attn_outputs = self.rel_attn(
            output_h,
            output_g,
            attn_mask_h,
            attn_mask_g,
            r,
            seg_mat,
            mems=mems,
            target_mapping=target_mapping,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        new_h, new_g = attn_outputs[:2]
        extras = attn_outputs[2:]  # attention probabilities, present only when requested

        # The g-stream (when present) goes through ff before the h-stream.
        if new_g is not None:
            new_g = self.ff(new_g)
        new_h = self.ff(new_h)

        return (new_h, new_g) + extras
526

527

528
class XLNetPreTrainedModel(PreTrainedModel):
    """ Abstract base for XLNet models: weight initialization plus the hooks used to
        download and load pretrained checkpoints.
    """

    config_class = XLNetConfig
    load_tf_weights = load_tf_weights_in_xlnet
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """ Initialize the weights of ``module`` in place.

        Linear/embedding weights, attention parameters and the mask embedding are drawn
        from N(0, config.initializer_range); linear biases are zeroed; layer norms are
        reset to weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, XLNetLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, XLNetRelativeAttention):
            attention_params = [
                getattr(module, attr_name)
                for attr_name in ("q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed")
            ]
            for tensor in attention_params:
                tensor.data.normal_(mean=0.0, std=self.config.initializer_range)
            return

        if isinstance(module, XLNetModel):
            module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
564

565

566
@dataclass
class XLNetModelOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetModel`.

    Args:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, hidden_size)`):
            Hidden-states produced by the final layer of the model.

            ``num_predict`` is ``target_mapping.shape[1]`` when a ``target_mapping`` is given, and
            ``sequence_length`` otherwise.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Pre-computed hidden-states, which can be fed back (see the `mems` input) to speed up
            sequential decoding. Tokens whose past is supplied this way should not be passed again
            as input ids, since they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`:
            one for the embedding output and one for the output of each layer.

            These are the hidden-states of the model after every layer, plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one per layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the softmax, used to compute the weighted average in the
            self-attention heads.
    """

    last_hidden_state: torch.FloatTensor
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
598

599

600
@dataclass
class XLNetLMHeadModelOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetLMHeadModel`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss (for next-token prediction).
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

            ``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
            ``num_predict`` corresponds to ``sequence_length``.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
635

636

637
@dataclass
class XLNetForSequenceClassificationOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetForSequenceClassification`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
669

670

671
@dataclass
class XLNetForTokenClassificationOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetForTokenClassification`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
703

704

705
@dataclass
class XLNetForMultipleChoiceOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).

            Classification scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
739

740

741
@dataclass
class XLNetForQuestionAnsweringSimpleOutput(ModelOutput):
    """
    Base class for outputs of question answering models.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-start scores (before SoftMax).
        end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-end scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
776

777

778
@dataclass
class XLNetForQuestionAnsweringOutput(ModelOutput):
    """
    Base class for outputs of question answering models using a :obj:`SquadHead`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Indices for the top config.start_n_top start token possibilities (beam-search).
        end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the ``is_impossible`` label of the answers.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_top_log_probs: Optional[torch.FloatTensor] = None
    start_top_index: Optional[torch.LongTensor] = None
    end_top_log_probs: Optional[torch.FloatTensor] = None
    end_top_index: Optional[torch.LongTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
822

823

824
XLNET_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
835

836
XLNET_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.XLNetTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`, `optional`, defaults to :obj:`None`):
            Contains pre-computed hidden-states as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
            given to this model should not be passed as input ids as they have already been computed.
            `use_cache` has to be set to `True` for the model to return updated `mems` that can be fed to the next call.
        perm_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
            If ``perm_mask[k, i, j] = 0``, i attends to j in batch k;
            if ``perm_mask[k, i, j] = 1``, i does not attend to j in batch k.
            If None, each token attends to all the others (full bidirectional attention).
            Only used during pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to indicate the output tokens to use.
            If ``target_mapping[k, i, j] = 1``, the i-th prediction in batch k is on the j-th token.
            Only used during pretraining for partial prediction or for sequential decoding (generation).
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token. The classifier token should be represented by a ``2``.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        input_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Inverse of `attention_mask` (``input_mask = 1 - attention_mask``), i.e. with 0 for real tokens and 1 for padding.
            Kept for compatibility with the original code base.
            You can only use one of `input_mask` and `attention_mask`.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`).
            If not set, ``config.use_cache`` is used; it is always treated as True while the model is in training mode.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""
898

899

900
@add_start_docstrings(
    "The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
    XLNET_START_DOCSTRING,
)
class XLNetModel(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
        self.n_layer = config.n_layer

        self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Embedding broadcast over the prediction slots to seed the query stream (``output_g``) when
        # ``target_mapping`` is given. Allocated uninitialized here; presumably filled in by
        # ``init_weights()`` -- confirm in XLNetPreTrainedModel.
        self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, config.d_model))
        self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)

        self.init_weights()

    def get_input_embeddings(self):
        return self.word_embedding

    def set_input_embeddings(self, new_embeddings):
        self.word_embedding = new_embeddings

    def _prune_heads(self, heads_to_prune):
        # Head pruning is not implemented for XLNet.
        raise NotImplementedError

    def create_mask(self, qlen, mlen):
        """
        Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.

        Args:
            qlen: Sequence length (number of query positions).
            mlen: Memory length (number of cached positions that precede the current sequence).

        Returns:
            A float tensor of shape ``(qlen, mlen + qlen)`` placed on ``self.device``.

        ::

                  same_length=False:      same_length=True:
                  <mlen > <  qlen >       <mlen > <  qlen >
               ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
                 [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
            qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
                 [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]

        """
        attn_mask = torch.ones([qlen, qlen])
        # Strict upper triangle: a query may not attend to later positions of the current sequence.
        mask_up = torch.triu(attn_mask, diagonal=1)
        # Memory positions are always visible.
        attn_mask_pad = torch.zeros([qlen, mlen])
        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
        if self.same_length:
            # Additionally mask the strict lower triangle of the first qlen columns so that every
            # query attends to the same number of keys.
            mask_lo = torch.tril(attn_mask, diagonal=-1)
            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)

        ret = ret.to(self.device)
        return ret

    def cache_mem(self, curr_out, prev_mem):
        """
        Append the current layer input to the layer memory and keep only the most recent positions.

        Args:
            curr_out: Hidden states fed to the current layer, with the sequence on dimension 0.
            prev_mem: Previous memory for this layer (sequence on dimension 0), or ``None``.

        Returns:
            The new memory, detached from the autograd graph.
        """
        # Only the first ``reuse_len`` positions of the current segment are cached when it is set.
        if self.reuse_len is not None and self.reuse_len > 0:
            curr_out = curr_out[: self.reuse_len]

        if self.mem_len is None or self.mem_len == 0:
            # If `use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
            # and returns all of the past and current hidden states.
            cutoff = 0
        else:
            # If `use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
            # states. This is the preferred setting for training and long-form generation.
            cutoff = -self.mem_len
        if prev_mem is None:
            # No previous memory: the memory is built from the current hidden states alone.
            new_mem = curr_out[cutoff:]
        else:
            new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]

        return new_mem.detach()

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        """
        Sinusoidal embedding of the (relative) positions in ``pos_seq``.

        Returns a tensor of shape ``(len(pos_seq), 1, 2 * len(inv_freq))`` (sines followed by cosines),
        expanded as a view to ``(len(pos_seq), bsz, 2 * len(inv_freq))`` when ``bsz`` is given.
        """
        sinusoid_inp = torch.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        """
        Build the sinusoidal relative positional encoding for ``qlen`` queries over ``klen`` keys.

        Relative distances run from ``klen`` down to (excluding) ``-qlen`` when ``attn_type == "bi"``
        and down to (excluding) ``-1`` when ``attn_type == "uni"``, clamped to ``[-clamp_len, clamp_len]``
        when ``clamp_len > 0``. With ``bi_data``, half of the batch gets the forward distance sequence and
        the other half the backward one.

        Returns:
            The encoding, placed on ``self.device``.

        Raises:
            ValueError: if ``self.attn_type`` is neither ``"bi"`` nor ``"uni"``.
        """
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == "bi":
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == "uni":
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError("Unknown `attn_type` {}.".format(self.attn_type))

        if self.bi_data:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
        else:
            # Explicit float dtype, as in the bi_data branch, so the einsum with the float32 inv_freq
            # does not depend on the global default dtype.
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(self.device)
        return pos_emb

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # NOTE: intentionally no docstring here -- the decorators above presumably assemble
        # ``forward.__doc__`` from XLNET_INPUTS_DOCSTRING and the code sample.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Memories are always produced while training.
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
        perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None

        mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen

        dtype_float = self.dtype
        device = self.device

        # Attention mask
        # causal attention mask: (qlen, klen, 1, 1) for "uni", none for "bi"
        if self.attn_type == "uni":
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == "bi":
            attn_mask = None
        else:
            raise ValueError("Unsupported attention type: {}".format(self.attn_type))

        # data mask: input mask & perm mask
        # The message is parenthesized so both literals form one assertion message (previously the second
        # literal was a separate no-op statement and the message was truncated).
        assert input_mask is None or attention_mask is None, (
            "You can only use one of input_mask (uses 1 for padding) "
            "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
        )
        if input_mask is None and attention_mask is not None:
            input_mask = 1.0 - attention_mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            if mlen > 0:
                mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
                data_mask = torch.cat([mems_mask, data_mask], dim=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                # Out-of-place on purpose: the causal mask is (qlen, klen, 1, 1) and only the sum
                # broadcasts over the batch dimension of data_mask, which an in-place ``+=`` cannot store.
                attn_mask = attn_mask + data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = (attn_mask > 0).to(dtype_float)

        if attn_mask is not None:
            # Mask for the content stream (h): same as attn_mask except that each query may always attend
            # to its own position (the identity is subtracted on the current-sequence block before thresholding).
            non_tgt_mask = -torch.eye(qlen).to(attn_mask)
            if mlen > 0:
                non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
        else:
            non_tgt_mask = None

        # Word embeddings and prepare h & g hidden states
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k)
        if target_mapping is not None:
            word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
            # else:  # We removed the inp_q input which was same as target mapping
            #     inp_q_ext = inp_q[:, :, None]
            #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None

        # Segment embedding
        if token_type_ids is not None:
            # Convert `token_type_ids` to one-hot `seg_mat`
            if mlen > 0:
                mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
                cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
            else:
                cat_ids = token_type_ids

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
            seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
        else:
            seg_mat = None

        # Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout(pos_emb)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        new_mems = ()
        if mems is None:
            mems = [None] * len(self.layer)

        attentions = [] if output_attentions else None
        hidden_states = [] if output_hidden_states else None
        for i, layer_module in enumerate(self.layer):
            if use_cache:
                # cache new mems (the memory of layer i is built from the input of layer i)
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if output_hidden_states:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)

            outputs = layer_module(
                output_h,
                output_g,
                attn_mask_h=non_tgt_mask,
                attn_mask_g=attn_mask,
                r=pos_emb,
                seg_mat=seg_mat,
                mems=mems[i],
                target_mapping=target_mapping,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
            )
            output_h, output_g = outputs[:2]
            if output_attentions:
                attentions.append(outputs[2])

        # Add last hidden state
        if output_hidden_states:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

        output = self.dropout(output_g if output_g is not None else output_h)

        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        output = output.permute(1, 0, 2).contiguous()

        # TODO Teven: fix this test to only use use_cache.
        if not use_cache:
            new_mems = None

        if output_hidden_states:
            if output_g is not None:
                # (h, g) pairs are flattened into a single tuple
                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
            else:
                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)

        if output_attentions:
            if target_mapping is not None:
                # when target_mapping is provided, there are 2-tuple of attentions
                attentions = tuple(
                    tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions
                )
            else:
                attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)

        if not return_dict:
            return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)

        return XLNetModelOutput(
            last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
        )
1249

1250

1251
@add_start_docstrings(
    """XLNet Model with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    XLNET_START_DOCSTRING,
)
class XLNetLMHeadModel(XLNetPreTrainedModel):
    # XLNet transformer body followed by a d_model -> vocab_size projection (``lm_loss``).
    # ``lm_loss`` is returned by ``get_output_embeddings`` so the pretrained-model machinery can access it.

    def __init__(self, config):
        super().__init__(config)
        self.attn_type = config.attn_type
        self.same_length = config.same_length

        self.transformer = XLNetModel(config)
        self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the output projection layer (``lm_loss``)."""
        return self.lm_loss

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        """Build the ``forward`` keyword arguments for one auto-regressive generation step.

        A dummy token is appended to ``input_ids``. ``perm_mask`` hides that token from every position,
        and ``target_mapping`` selects it as the single prediction target of every sequence in the batch.

        Args:
            input_ids: ``torch.LongTensor`` of shape ``(batch_size, seq_len)`` holding the tokens so far.
            past: memories from the previous step, or a falsy value on the first step. When present,
                only the last ``offset`` tokens are re-fed and the memories are trimmed accordingly.
            **kwargs: may contain ``use_cache``. When it is absent, ``None`` is forwarded so that
                ``forward`` falls back to ``config.use_cache``.

        Returns:
            dict with ``input_ids``, ``perm_mask``, ``target_mapping``, ``use_cache`` and,
            when ``past`` is given, ``mems``.
        """
        # Add dummy token at the end (no attention on this one)
        effective_batch_size = input_ids.shape[0]
        dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long, device=input_ids.device)

        # At every pass, the attention values for the new token and the two last generated tokens
        # are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have
        # offset = 1; offset = 2 seems to have slightly better computation.
        offset = 2

        if past:
            input_ids = torch.cat([input_ids[:, -offset:], dummy_token], dim=1)
        else:
            input_ids = torch.cat([input_ids, dummy_token], dim=1)

        # Build permutation mask so that previous tokens don't see last token
        sequence_length = input_ids.shape[1]
        perm_mask = torch.zeros(
            (effective_batch_size, sequence_length, sequence_length), dtype=torch.float, device=input_ids.device
        )
        perm_mask[:, :, -1] = 1.0

        # We'll only predict the last token -- for EVERY sequence in the batch. Indexing only batch
        # row 0 here would leave the target mapping of all other rows all-zero.
        target_mapping = torch.zeros(
            (effective_batch_size, 1, sequence_length), dtype=torch.float, device=input_ids.device
        )
        target_mapping[:, 0, -1] = 1.0

        inputs = {
            "input_ids": input_ids,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
            # .get(): a missing key means "use the config default" instead of raising KeyError.
            "use_cache": kwargs.get("use_cache"),
        }

        # if past is defined in model kwargs then use it for faster decoding; drop the trailing
        # `offset` entries because those tokens are recomputed from `input_ids` above.
        if past:
            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)

        return inputs

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=XLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`, defaults to :obj:`None`):
            Labels for masked language modeling.
            `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`.
            The labels should correspond to the masked input words that should be predicted and depend on `target_mapping`. Note that in order to perform standard auto-regressive language modeling a `<mask>` token has to be added to the `input_ids` (see `prepare_inputs_for_generation` fn and examples below).
            Indices are selected in ``[-100, 0, ..., config.vocab_size - 1]``
            All labels set to ``-100`` are ignored, the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``

    Return:

    Examples::

        from transformers import XLNetTokenizer, XLNetLMHeadModel
        import torch

        tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
        model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', return_dict=True)

        # We show how to setup inputs to predict a next token using a bi-directional context.
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)).unsqueeze(0)  # We will predict the masked token
        perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
        perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
        target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # Shape [1, 1, seq_length] => let's predict one token
        target_mapping[0, 0, -1] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)

        outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
        next_token_logits = outputs[0]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]

        # XLNetLMHeadModel can be trained with standard auto-regressive language modeling in the same way.
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)).unsqueeze(0)  # We will predict the masked token
        labels = torch.tensor(tokenizer.encode("cute", add_special_tokens=False)).unsqueeze(0)
        assert labels.shape[1] == 1, 'only one word will be predicted'
        perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
        perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token as is done in standard auto-regressive lm training
        target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # Shape [1, 1, seq_length] => let's predict one token
        target_mapping[0, 0, -1] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)

        outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=labels)
        loss = outputs.loss
        next_token_logits = outputs.logits  # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # In training mode memories are always requested; otherwise honour the argument / config default.
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_loss(transformer_outputs[0])

        loss = None
        if labels is not None:
            # Flatten the tokens; CrossEntropyLoss ignores positions labelled -100 by default.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetLMHeadModelOutput(
            loss=loss,
            logits=logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1411

1412

1413
@add_start_docstrings(
    """XLNet Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    XLNET_START_DOCSTRING,
)
class XLNetForSequenceClassification(XLNetPreTrainedModel):
    # XLNet transformer -> SequenceSummary pooling -> linear projection to ``config.num_labels`` outputs.
    # With ``num_labels == 1`` the head is trained as a regressor (MSE), otherwise as a classifier (CE).

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
        self.logits_proj = nn.Linear(config.d_model, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetForSequenceClassificationOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # In training mode memories are always requested; otherwise honour the argument / config default.
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        output = transformer_outputs[0]

        # Pool the sequence into one vector per example, then project to the label space.
        output = self.sequence_summary(output)
        logits = self.logits_proj(output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression. MSELoss needs a floating-point target of the same dtype as
                # its input, so cast the labels (documented as a LongTensor) to the logits' dtype.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetForSequenceClassificationOutput(
            loss=loss,
            logits=logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1504

1505

1506
@add_start_docstrings(
    """XLNet Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XLNET_START_DOCSTRING,
)
class XLNetForTokenClassification(XLNetPreTrainedModel):
    # XLNet transformer followed by a per-token linear classifier with ``config.num_labels`` outputs.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetForTokenClassificationOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            When ``attention_mask`` is given, positions where it is not ``1`` are ignored by the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # In training mode memories are always requested; otherwise honour the argument / config default.
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss: masked-out positions get the loss's ignore_index label.
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetForTokenClassificationOutput(
            loss=loss,
            logits=logits,
            mems=outputs.mems,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1599

1600

1601
@add_start_docstrings(
1602
    """XLNet Model with a multiple choice classification head on top (a linear layer on top of
1603
    the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
1604
    XLNET_START_DOCSTRING,
1605
)
1606
class XLNetForMultipleChoice(XLNetPreTrainedModel):
1607
    def __init__(self, config):
1608
        super().__init__(config)
1609

1610
        self.transformer = XLNetModel(config)
1611
        self.sequence_summary = SequenceSummary(config)
1612
        self.logits_proj = nn.Linear(config.d_model, 1)
1613

1614
        self.init_weights()
1615

1616
    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
1617
    @add_code_sample_docstrings(
1618
        tokenizer_class=_TOKENIZER_FOR_DOC,
1619
        checkpoint="xlnet-base-cased",
1620
        output_type=XLNetForMultipleChoiceOutput,
1621
        config_class=_CONFIG_FOR_DOC,
1622
    )
1623
    def forward(
1624
        self,
1625
        input_ids=None,
1626
        token_type_ids=None,
1627
        input_mask=None,
1628
        attention_mask=None,
1629
        mems=None,
1630
        perm_mask=None,
1631
        target_mapping=None,
1632
        head_mask=None,
1633
        inputs_embeds=None,
1634
        labels=None,
1635
        use_cache=None,
1636
        output_attentions=None,
1637
        output_hidden_states=None,
1638
        return_dict=None,
1639
    ):
1640
        r"""
1641
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1642
            Labels for computing the multiple choice classification loss.
1643
            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
1644
            of the input tensors. (see `input_ids` above)
1645
        """
1646
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1647
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
1648
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1649

1650
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1651
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1652
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1653
        flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
1654
        flat_inputs_embeds = (
1655
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1656
            if inputs_embeds is not None
1657
            else None
1658
        )
1659

1660
        transformer_outputs = self.transformer(
1661
            flat_input_ids,
1662
            token_type_ids=flat_token_type_ids,
1663
            input_mask=flat_input_mask,
1664
            attention_mask=flat_attention_mask,
1665
            mems=mems,
1666
            perm_mask=perm_mask,
1667
            target_mapping=target_mapping,
1668
            head_mask=head_mask,
1669
            inputs_embeds=flat_inputs_embeds,
1670
            use_cache=use_cache,
1671
            output_attentions=output_attentions,
1672
            output_hidden_states=output_hidden_states,
1673
            return_dict=return_dict,
1674
        )
1675

1676
        output = transformer_outputs[0]
1677

1678
        output = self.sequence_summary(output)
1679
        logits = self.logits_proj(output)
1680
        reshaped_logits = logits.view(-1, num_choices)
1681

1682
        loss = None
1683
        if labels is not None:
1684
            loss_fct = CrossEntropyLoss()
1685
            loss = loss_fct(reshaped_logits, labels.view(-1))
1686

1687
        if not return_dict:
1688
            output = (reshaped_logits,) + transformer_outputs[1:]
1689
            return ((loss,) + output) if loss is not None else output
1690

1691
        return XLNetForMultipleChoiceOutput(
1692
            loss=loss,
1693
            logits=reshaped_logits,
1694
            mems=transformer_outputs.mems,
1695
            hidden_states=transformer_outputs.hidden_states,
1696
            attentions=transformer_outputs.attentions,
1697
        )
1698

1699

1700
@add_start_docstrings(
    """XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLNET_START_DOCSTRING,
)
class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
    # XLNet transformer followed by a per-token linear layer. The code splits its output into exactly two
    # channels (start / end logits), so ``config.num_labels`` is expected to be 2.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetForQuestionAnsweringSimpleOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # In training mode memories are always requested; otherwise honour the argument / config default.
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms.
            # Out-of-place clamp: the caller's label tensors must not be modified.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return XLNetForQuestionAnsweringSimpleOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            mems=outputs.mems,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1805

1806

1807
@add_start_docstrings(
    """XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLNET_START_DOCSTRING,
)
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
    # Beam-search span-extraction head on top of XLNetModel.
    # - Training (start_positions and end_positions given): averaged start/end cross-entropy loss,
    #   plus 0.5 * BCE answerability loss when cls_index and is_impossible are also given.
    # - Inference: top `start_n_top` start candidates, top `end_n_top` end candidates for each of
    #   them, and one answerability logit per example.
    # (Documented with comments rather than a class docstring because `add_start_docstrings`
    # builds the public class docstring.)

    def __init__(self, config):
        """Build the XLNet backbone and the start / end / answerability pooling heads.

        Args:
            config (:class:`XLNetConfig`): model configuration; ``start_n_top`` and ``end_n_top``
                are the beam sizes used by :meth:`forward` at inference time.
        """
        super().__init__(config)
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.transformer = XLNetModel(config)
        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

        self.init_weights()

    @staticmethod
    def _drop_split_dim(tensor):
        """Squeeze the trailing dimension of ``tensor`` when it has more than one dimension.

        ``None`` is returned unchanged. The squeeze is out-of-place, so the caller's tensor is
        not modified.
        """
        if tensor is not None and tensor.dim() > 1:
            return tensor.squeeze(-1)
        return tensor

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=XLNetForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        is_impossible=None,
        cls_index=None,
        p_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are not clamped by this model, so they must index into the sequence
            (``0 <= position < sequence_length``).
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are not clamped by this model, so they must index into the sequence
            (``0 <= position < sequence_length``).
        is_impossible (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels whether a question has an answer or no answer (SQuAD 2.0).
            Integer or float labels are accepted; they are cast to the dtype of the answerability logits.
        cls_index (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
        p_mask (``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...).
            1.0 means token should be masked. 0.0 mean token is not masked.

    Returns:

    Example::

        >>> from transformers import XLNetTokenizer, XLNetForQuestionAnswering
        >>> import torch

        >>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
        >>> model = XLNetForQuestionAnswering.from_pretrained('xlnet-base-cased', return_dict=True)

        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])
        >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)

        >>> loss = outputs.loss
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # In training mode mems are always requested; otherwise follow the argument / config default.
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, remove the dimension added by batch splitting.
            # Done out-of-place so the caller's label tensors are left untouched.
            start_positions, end_positions, cls_index, is_impossible = (
                self._drop_split_dim(t) for t in (start_positions, end_positions, cls_index, is_impossible)
            )

            # During training, compute the end logits conditioned on the ground-truth start position.
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START.
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                # BCEWithLogitsLoss requires a floating-point target; integer labels would fail.
                cls_loss = loss_fct_cls(cls_logits, is_impossible.to(cls_logits.dtype))

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss = total_loss + cls_loss * 0.5

            if not return_dict:
                return (total_loss,) + transformer_outputs[1:]
            return XLNetForQuestionAnsweringOutput(
                loss=total_loss,
                mems=transformer_outputs.mems,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        # During inference, compute the end logits with a beam search over start candidates.
        _, slen, hsz = hidden_states.size()
        # NOTE: despite the `log_probs` names (kept for output compatibility), these are softmax probabilities.
        start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

        start_top_log_probs, start_top_index = torch.topk(
            start_log_probs, self.start_n_top, dim=-1
        )  # shape (bsz, start_n_top)
        start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
        start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
        start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

        hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
            start_states
        )  # shape (bsz, slen, start_n_top, hsz)
        p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
        end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
        end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

        end_top_log_probs, end_top_index = torch.topk(
            end_log_probs, self.end_n_top, dim=1
        )  # shape (bsz, end_n_top, start_n_top)
        end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
        end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

        # Representation of START as the probability-weighted sum of hidden states.
        start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
        cls_logits = self.answer_class(
            hidden_states, start_states=start_states, cls_index=cls_index
        )  # Shape (batch size,): one single `cls_logits` for each sample

        if not return_dict:
            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
            return outputs + transformer_outputs[1:]
        return XLNetForQuestionAnsweringOutput(
            start_top_log_probs=start_top_log_probs,
            start_top_index=start_top_index,
            end_top_log_probs=end_top_log_probs,
            end_top_index=end_top_index,
            cls_logits=cls_logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1984

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.